Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): revert incorrect observation stacking
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 8 Mar 2025 20:29:36 +0000 (21:29 +0100)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 8 Mar 2025 20:29:36 +0000 (21:29 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 3ab6f4dd748d563f800105f8ccafba71bb03297e..a65682eb4f76fac67de02e45303442235de7f3fb 100644 (file)
@@ -1,4 +1,3 @@
-from collections import deque
 import copy
 import gc
 import json
@@ -106,7 +105,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise ValueError(
                 "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
             )
-        self.observations_buffer: Dict[str, deque] = {}
         self.is_maskable: bool = (
             self.model_type == "MaskablePPO"
         )  # Enable action masking
@@ -266,8 +264,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         rollout_plot_callback = None
         verbose = self.model_training_parameters.get("verbose", 0)
 
-        eval_freq //= self.n_envs
-
         if self.plot_new_best:
             rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
 
@@ -377,7 +373,9 @@ class ReforceXY(BaseReinforcementLearningModel):
             model = self.dd.model_dictionary[dk.pair]
             model.set_env(self.train_env)
 
-        callbacks = self.get_callbacks(train_timesteps, str(dk.data_path))
+        callbacks = self.get_callbacks(
+            train_timesteps // self.n_envs, str(dk.data_path)
+        )
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         finally:
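
Stable-Baselines3's EvalCallback counts eval_freq in calls per environment, so with n_envs vectorized environments an evaluation only fires every eval_freq * n_envs total timesteps; the two hunks above therefore move the division by self.n_envs out of get_callbacks() and into its call sites. A minimal sketch of that scaling, assuming SB3's EvalCallback and a hypothetical make_eval_callback helper that is not part of ReforceXY:

from stable_baselines3.common.callbacks import EvalCallback

def make_eval_callback(eval_env, train_timesteps: int, n_envs: int) -> EvalCallback:
    # eval_freq is counted per environment: with n_envs parallel environments,
    # one evaluation happens every eval_freq * n_envs total timesteps, so the
    # intended interval (in total timesteps) is divided by n_envs.
    eval_freq = max(1, train_timesteps // n_envs)
    return EvalCallback(eval_env, eval_freq=eval_freq, deterministic=True)
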
@@ -429,12 +427,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param dk: FreqaiDatakitchen = data kitchen for the current pair
         :param model: Any = the trained model used to inference the features.
         """
-        if not self.observations_buffer.get(dk.pair):
-            buffer_size = max(1, self.frame_stacking)
-            initial_observation = dataframe.iloc[0].to_numpy(dtype=np.float32)
-            self.observations_buffer[dk.pair] = deque(
-                [initial_observation] * buffer_size, maxlen=buffer_size
-            )
 
         def _is_valid(action: int, position: float) -> bool:
             """
@@ -471,14 +463,16 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             observation = observation.to_numpy(dtype=np.float32)
 
-            self.observations_buffer[dk.pair].append(observation)
-
-            stacked_observations = np.concatenate(
-                self.observations_buffer[dk.pair], axis=1
-            )
+            if self.frame_stacking:
+                # FIXME: proper observation stacking needs more work
+                observations = np.repeat(
+                    observation, axis=1, repeats=self.frame_stacking
+                )
+            else:
+                observations = observation
 
             action, _ = model.predict(
-                stacked_observations, deterministic=True, **action_masks_param
+                observations, deterministic=True, **action_masks_param
             )
             return action
 
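
When frame stacking is enabled during training, the model expects an observation whose feature axis is self.frame_stacking times wider than a single dataframe row, which is why the reverted code concatenated a per-pair deque of past observations and why this interim version simply tiles the current row with np.repeat until proper stacking is reworked. A minimal sketch of the shape relationship, using made-up dimensions rather than values from the diff:

import numpy as np

n_features, frame_stacking = 4, 3  # illustrative sizes only
observation = np.arange(n_features, dtype=np.float32).reshape(1, -1)

# Interim placeholder from this commit: tile the current observation so the
# feature axis matches the width a frame-stacked model expects.
repeated = np.repeat(observation, repeats=frame_stacking, axis=1)

# What the reverted buffer attempted: concatenate the last frame_stacking
# observations along the feature axis (here the frames happen to be identical).
stacked = np.concatenate([observation] * frame_stacking, axis=1)

assert repeated.shape == stacked.shape == (1, n_features * frame_stacking)
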
@@ -515,7 +509,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         else:
             study_name = identifier
             storage = self.get_storage()
-        eval_freq = len(train_df) // self.n_envs
+        eval_freq = max(1, len(train_df) // self.n_envs)
         study: Study = create_study(
             study_name=study_name,
             sampler=TPESampler(
@@ -525,7 +519,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             ),
             pruner=HyperbandPruner(
                 min_resource=3,
-                max_resource=total_timesteps // eval_freq,
+                max_resource=(total_timesteps + eval_freq - 1) // eval_freq,
                 reduction_factor=3,
             ),
             direction=StudyDirection.MAXIMIZE,
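
The max_resource change replaces floor division with ceiling division so that a partial final evaluation window still counts as one resource step for the HyperbandPruner. A small worked example with illustrative numbers that are not taken from the diff:

def ceil_div(a: int, b: int) -> int:
    # (a + b - 1) // b equals math.ceil(a / b) for positive integers.
    return (a + b - 1) // b

total_timesteps, eval_freq = 10_500, 1_000
assert total_timesteps // eval_freq == 10          # floor drops the partial window
assert ceil_div(total_timesteps, eval_freq) == 11  # ceiling counts it
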
@@ -652,7 +646,9 @@ class ReforceXY(BaseReinforcementLearningModel):
             **params,
         )
 
-        callbacks = self.get_callbacks(len(train_df), str(dk.data_path), trial)
+        callbacks = self.get_callbacks(
+            len(train_df) // self.n_envs, str(dk.data_path), trial
+        )
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         except AssertionError:
@@ -734,11 +730,11 @@ class ReforceXY(BaseReinforcementLearningModel):
             """
             Reset is called at the beginning of every episode
             """
-            _, history = super().reset(seed, **kwargs)
+            observation, history = super().reset(seed, **kwargs)
             self._force_action: Optional[ForceActions] = None
             self._last_closed_position: Positions = None
             self._last_closed_trade_tick: int = 0
-            return self._get_observation(), history
+            return observation, history
 
         def _get_reward_factor_at_trade_exit(
             self,
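
The reset() hunk above relies on the Gymnasium-style contract that reset() returns an (observation, info) pair: with the custom stacking gone, the observation produced by super().reset() can be forwarded as-is instead of being rebuilt via self._get_observation(). A minimal sketch of that pattern, using a hypothetical wrapper rather than ReforceXY's actual environment:

import gymnasium as gym

class BookkeepingEnv(gym.Wrapper):
    # Hypothetical wrapper: it only resets its own episode bookkeeping and
    # forwards the parent's observation unchanged, as the hunk above does.
    def reset(self, *, seed=None, options=None):
        observation, info = super().reset(seed=seed, options=options)
        self._last_closed_trade_tick = 0  # subclass state only
        return observation, info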