From: Jérôme Benoit
Date: Sat, 8 Mar 2025 20:29:36 +0000 (+0100)
Subject: fix(reforcexy): revert incorrect observation stacking
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=992a6695a539effd729bf7e301ca2974439b2e08;p=freqai-strategies.git

fix(reforcexy): revert incorrect observation stacking

Signed-off-by: Jérôme Benoit
---

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 3ab6f4d..a65682e 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -1,4 +1,3 @@
-from collections import deque
 import copy
 import gc
 import json
@@ -106,7 +105,6 @@ class ReforceXY(BaseReinforcementLearningModel):
            raise ValueError(
                "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
            )
-        self.observations_buffer: Dict[str, deque] = {}
         self.is_maskable: bool = (
             self.model_type == "MaskablePPO"
         )  # Enable action masking
@@ -266,8 +264,6 @@
         rollout_plot_callback = None
         verbose = self.model_training_parameters.get("verbose", 0)
 
-        eval_freq //= self.n_envs
-
         if self.plot_new_best:
             rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
 
@@ -377,7 +373,9 @@
             model = self.dd.model_dictionary[dk.pair]
             model.set_env(self.train_env)
 
-        callbacks = self.get_callbacks(train_timesteps, str(dk.data_path))
+        callbacks = self.get_callbacks(
+            train_timesteps // self.n_envs, str(dk.data_path)
+        )
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         finally:
@@ -429,12 +427,6 @@
        :param dk: FreqaiDatakitchen = data kitchen for the current pair
        :param model: Any = the trained model used to inference the features.
""" - if not self.observations_buffer.get(dk.pair): - buffer_size = max(1, self.frame_stacking) - initial_observation = dataframe.iloc[0].to_numpy(dtype=np.float32) - self.observations_buffer[dk.pair] = deque( - [initial_observation] * buffer_size, maxlen=buffer_size - ) def _is_valid(action: int, position: float) -> bool: """ @@ -471,14 +463,16 @@ class ReforceXY(BaseReinforcementLearningModel): observation = observation.to_numpy(dtype=np.float32) - self.observations_buffer[dk.pair].append(observation) - - stacked_observations = np.concatenate( - self.observations_buffer[dk.pair], axis=1 - ) + if self.frame_stacking: + # FIXME: proper observation stacking need more work + observations = np.repeat( + observation, axis=1, repeats=self.frame_stacking + ) + else: + observations = observation action, _ = model.predict( - stacked_observations, deterministic=True, **action_masks_param + observations, deterministic=True, **action_masks_param ) return action @@ -515,7 +509,7 @@ class ReforceXY(BaseReinforcementLearningModel): else: study_name = identifier storage = self.get_storage() - eval_freq = len(train_df) // self.n_envs + eval_freq = max(1, len(train_df) // self.n_envs) study: Study = create_study( study_name=study_name, sampler=TPESampler( @@ -525,7 +519,7 @@ class ReforceXY(BaseReinforcementLearningModel): ), pruner=HyperbandPruner( min_resource=3, - max_resource=total_timesteps // eval_freq, + max_resource=(total_timesteps + eval_freq - 1) // eval_freq, reduction_factor=3, ), direction=StudyDirection.MAXIMIZE, @@ -652,7 +646,9 @@ class ReforceXY(BaseReinforcementLearningModel): **params, ) - callbacks = self.get_callbacks(len(train_df), str(dk.data_path), trial) + callbacks = self.get_callbacks( + len(train_df) // self.n_envs, str(dk.data_path), trial + ) try: model.learn(total_timesteps=total_timesteps, callback=callbacks) except AssertionError: @@ -734,11 +730,11 @@ class ReforceXY(BaseReinforcementLearningModel): """ Reset is called at the beginning of every episode """ - _, history = super().reset(seed, **kwargs) + observation, history = super().reset(seed, **kwargs) self._force_action: Optional[ForceActions] = None self._last_closed_position: Positions = None self._last_closed_trade_tick: int = 0 - return self._get_observation(), history + return observation, history def _get_reward_factor_at_trade_exit( self,