Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): properly stack observations
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 8 Mar 2025 17:43:43 +0000 (18:43 +0100)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 8 Mar 2025 17:43:43 +0000 (18:43 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/config-template.json
ReforceXY/user_data/freqaimodels/ReforceXY.py
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py
quickadapter/user_data/strategies/QuickAdapterV3.py
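Note: a policy trained behind VecFrameStack sees n_stack consecutive observations concatenated along the feature axis, so prediction must be fed the last n_stack observations stacked the same way, not one observation repeated to fake the shape (which is what the removed np.repeat call did). A minimal sketch of that shape contract, assuming 3 features and n_stack = 4 (illustrative values, not code from this commit):

    import numpy as np

    n_stack = 4
    obs_t = np.array([0.1, 0.2, 0.3], dtype=np.float32)   # latest observation, shape (3,)
    history = [obs_t] * n_stack                            # oldest -> newest frames
    stacked = np.concatenate(history, axis=0)              # shape (12,) == (3 * n_stack,)
    assert stacked.shape == (obs_t.shape[0] * n_stack,)
    # A model trained on VecFrameStack-ed environments expects `stacked`,
    # not `obs_t`, at prediction time.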

diff --git a/ReforceXY/user_data/config-template.json b/ReforceXY/user_data/config-template.json
index 92cbf82bea01f1790788d12726fb26b3470ce8d0..15237b020d21b0e7ba03dea479f3c8e657ea0630 100644
         "profit_aim": 0.025,
         "win_reward_factor": 2
       },
-      "train_cycles": 250,
+      "train_cycles": 25,
       "add_state_info": true,
       "cpu_count": 6,
       "max_training_drawdown_pct": 0.02,
       "max_trade_duration_candles": 96, // Timeout exit value used with force_actions
       "force_actions": false, // Utilize minimal_roi, stoploss, and max_trade_duration_candles as TP/SL/Timeout in the environment
       "n_envs": 32, // Number of DummyVecEnv environments
-      "frame_staking": 4, // Number of VecFrameStack stacks (set > 1 to use)
+      "frame_stacking": 4, // Number of VecFrameStack stacks (set > 1 to use)
       "lr_schedule": false, // Enable learning rate linear schedule
       "cr_schedule": false, // Enable clip range linear schedule
       "max_no_improvement_evals": 0, // Maximum consecutive evaluations without a new best model
diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 492b9c9cc4aeabef59f20d690eba84b594a9072e..3ab6f4dd748d563f800105f8ccafba71bb03297e 100644
@@ -1,3 +1,4 @@
+from collections import deque
 import copy
 import gc
 import json
@@ -74,7 +75,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 "max_trade_duration_candles": 96,   // Timeout exit value used with force_actions
                 "force_actions": false,             // Utilize minimal_roi, stoploss, and max_trade_duration_candles as TP/SL/Timeout in the environment
                 "n_envs": 1,                        // Number of DummyVecEnv environments
-                "frame_staking": 0,                 // Number of VecFrameStack stacks (set > 1 to use)
+                "frame_stacking": 0,                // Number of VecFrameStack stacks (set > 1 to use)
                 "lr_schedule": false,               // Enable learning rate linear schedule
                 "cr_schedule": false,               // Enable clip range linear schedule
                 "max_no_improvement_evals": 0,      // Maximum consecutive evaluations without a new best model
@@ -105,14 +106,14 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise ValueError(
                 "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
             )
+        self.observations_buffer: Dict[str, deque] = {}
         self.is_maskable: bool = (
             self.model_type == "MaskablePPO"
         )  # Enable action masking
         self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
         self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
         self.n_envs: int = self.rl_config.get("n_envs", 1)
-        self.frame_staking: int = self.rl_config.get("frame_staking", 0)
-        self.frame_staking += 1 if self.frame_staking == 1 else 0
+        self.frame_stacking: int = self.rl_config.get("frame_stacking", 0)
         self.max_no_improvement_evals: int = self.rl_config.get(
             "max_no_improvement_evals", 0
         )
@@ -140,9 +141,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         If user has activated any custom function that may conflict, this
         function will set them to false and warn them
         """
-        if self.continual_learning and self.frame_staking:
+        if self.continual_learning and self.frame_stacking:
             logger.warning(
-                "User tried to use continual_learning with frame_staking. \
+                "User tried to use continual_learning with frame_stacking. \
                 Deactivating continual_learning"
             )
             self.continual_learning = False
@@ -206,12 +207,14 @@ class ReforceXY(BaseReinforcementLearningModel):
                 for i in range(self.n_envs)
             ]
         )
-        if self.frame_staking:
-            logger.info("Frame staking: %s", self.frame_staking)
-            train_env = VecFrameStack(train_env, n_stack=self.frame_staking)
-            eval_env = VecFrameStack(eval_env, n_stack=self.frame_staking)
+        if self.frame_stacking:
+            logger.info("Frame stacking: %s", self.frame_stacking)
+            train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
+            eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)
 
         self.train_env = VecMonitor(train_env)
+        if self.frame_stacking and not self.train_env.observation_space.shape:
+            raise ValueError("Frame stacking requires predefined observation shape")
         self.eval_env = VecMonitor(eval_env)
 
     def get_model_params(self) -> Dict:
@@ -254,7 +257,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def get_callbacks(
         self, eval_freq: int, data_path: str, trial: Trial = None
-    ) -> list:
+    ) -> list[BaseCallback]:
         """
         Get the model specific callbacks
         """
@@ -263,8 +266,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         rollout_plot_callback = None
         verbose = self.model_training_parameters.get("verbose", 0)
 
-        if self.n_envs > 1:
-            eval_freq //= self.n_envs
+        eval_freq //= self.n_envs
 
         if self.plot_new_best:
             rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
@@ -325,17 +327,18 @@ class ReforceXY(BaseReinforcementLearningModel):
         train_df = data_dictionary["train_features"]
         train_timesteps = len(train_df)
         test_timesteps = len(data_dictionary["test_features"])
-        train_cycles = int(self.rl_config.get("train_cycles", 250))
-        total_timesteps = train_timesteps * train_cycles
+        train_cycles = max(1, int(self.rl_config.get("train_cycles", 25)))
+        total_timesteps = train_timesteps * train_cycles * self.n_envs
         train_days = steps_to_days(train_timesteps, self.config["timeframe"])
         total_days = steps_to_days(total_timesteps, self.config["timeframe"])
 
         logger.info("Action masking: %s", self.is_maskable)
         logger.info(
-            "Train: %s steps (%s days) * %s cycles = Total %s (%s days)",
+            "Train: %s steps (%s days) * %s cycles * %s environments = Total %s (%s days)",
             train_timesteps,
             train_days,
             train_cycles,
+            self.n_envs,
             total_timesteps,
             total_days,
         )
@@ -380,6 +383,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         finally:
             if self.progressbar_callback:
                 self.progressbar_callback.on_training_end()
+            self.close_envs()
+            model.env.close()
         time_spent = time.time() - start
         self.dd.update_metric_tracker("fit_time", time_spent, dk.pair)
 
@@ -387,8 +392,13 @@ class ReforceXY(BaseReinforcementLearningModel):
         model_path = Path(dk.data_path / f"{model_filename}_model.zip")
         if model_path.is_file():
             logger.info(f"Callback found a best model: {model_path}.")
-            best_model = self.MODELCLASS.load(dk.data_path / f"{model_filename}_model")
-            return best_model
+            try:
+                best_model = self.MODELCLASS.load(
+                    dk.data_path / f"{model_filename}_model"
+                )
+                return best_model
+            except Exception as e:
+                logger.error(f"Error loading best model: {e}", exc_info=True)
 
         logger.info("Couldn't find best model, using final model instead.")
 
@@ -419,6 +429,12 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param dk: FreqaiDatakitchen = data kitchen for the current pair
         :param model: Any = the trained model used to inference the features.
         """
+        if not self.observations_buffer.get(dk.pair):
+            buffer_size = max(1, self.frame_stacking)
+            initial_observation = dataframe.iloc[0].to_numpy(dtype=np.float32)
+            self.observations_buffer[dk.pair] = deque(
+                [initial_observation] * buffer_size, maxlen=buffer_size
+            )
 
         def _is_valid(action: int, position: float) -> bool:
             """
@@ -440,28 +456,29 @@ class ReforceXY(BaseReinforcementLearningModel):
             return [_is_valid(action.value, position) for action in Actions]
 
         def _predict(window):
-            observations: DataFrame = dataframe.iloc[window.index]
+            observation: DataFrame = dataframe.iloc[window.index]
             action_masks_param: dict = {}
 
             if self.live and self.rl_config.get("add_state_info", False):
                 position, pnl, trade_duration = self.get_state_info(dk.pair)
                 # STATE_INFO
-                observations["pnl"] = pnl
-                observations["position"] = position
-                observations["trade_duration"] = trade_duration
+                observation["pnl"] = pnl
+                observation["position"] = position
+                observation["trade_duration"] = trade_duration
 
                 if self.is_maskable:
                     action_masks_param = {"action_masks": _action_masks(position)}
 
-            observations = observations.to_numpy(dtype=np.float32)
+            observation = observation.to_numpy(dtype=np.float32)
 
-            if self.frame_staking:
-                observations = np.repeat(
-                    observations, axis=1, repeats=self.frame_staking
-                )
+            self.observations_buffer[dk.pair].append(observation)
+
+            stacked_observations = np.concatenate(
+                self.observations_buffer[dk.pair], axis=1
+            )
 
             action, _ = model.predict(
-                observations, deterministic=True, **action_masks_param
+                stacked_observations, deterministic=True, **action_masks_param
             )
             return action
 
@@ -469,7 +486,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
         return output
 
-    def get_storage(self, pair: str | None = None) -> BaseStorage:
+    def get_storage(self, pair: str | None = None) -> BaseStorage | None:
         """
         Get the storage for Optuna
         """
@@ -498,6 +515,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         else:
             study_name = identifier
             storage = self.get_storage()
+        eval_freq = len(train_df) // self.n_envs
         study: Study = create_study(
             study_name=study_name,
             sampler=TPESampler(
@@ -506,7 +524,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 group=True,
             ),
             pruner=HyperbandPruner(
-                min_resource=1, max_resource=self.optuna_n_trials, reduction_factor=3
+                min_resource=3,
+                max_resource=total_timesteps // eval_freq,
+                reduction_factor=3,
             ),
             direction=StudyDirection.MAXIMIZE,
             storage=storage,
@@ -601,6 +621,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         """
         if "PPO" in self.model_type:
             params = sample_params_ppo(trial)
+            if params.get("n_steps", 0) > total_timesteps:
+                raise TrialPruned("n_steps exceeds total_timesteps")
         elif "QRDQN" in self.model_type:
             params = sample_params_qrdqn(trial)
         elif "DQN" in self.model_type:
@@ -629,8 +651,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             tensorboard_log=tensorboard_log_path,
             **params,
         )
-        callbacks = self.get_callbacks(len(train_df), str(dk.data_path), trial)
 
+        callbacks = self.get_callbacks(len(train_df), str(dk.data_path), trial)
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         except AssertionError:
@@ -1068,6 +1090,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             and falling prices in Short positions.
             The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
             """
+            if self._current_tick <= 0:
+                return 0.0
             if self._position == Positions.Long:
                 current_price = self.prices.iloc[self._current_tick].open
                 previous_price = self.prices.iloc[self._current_tick - 1].open
@@ -1131,7 +1155,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             _rollout_history = _history_df.merge(
                 _trade_history_df, on="tick", how="left"
-            )
+            ).fillna(method="ffill")
             _price_history = (
                 self.prices.iloc[_rollout_history.tick].copy().reset_index()
             )
@@ -1425,13 +1449,13 @@ def get_net_arch(
             "medium": {"pi": [256, 256], "vf": [256, 256]},
             "large": {"pi": [512, 512], "vf": [512, 512]},
             "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
-        }[net_arch_type]
+        }.get(net_arch_type, {"pi": [128, 128], "vf": [128, 128]})
     return {
         "small": [128, 128],
         "medium": [256, 256],
         "large": [512, 512],
         "extra_large": [1024, 1024],
-    }[net_arch_type]
+    }.get(net_arch_type, [128, 128])
 
 
 def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]:
@@ -1443,7 +1467,7 @@ def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]:
         "relu": th.nn.ReLU,
         "elu": th.nn.ELU,
         "leaky_relu": th.nn.LeakyReLU,
-    }[activation_fn_name]
+    }.get(activation_fn_name, th.nn.ReLU)
 
 
 def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]:
@@ -1453,7 +1477,7 @@ def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]:
     return {
         "adam": th.optim.Adam,
         "rmsprop": th.optim.RMSprop,
-    }[optimizer_class_name]
+    }.get(optimizer_class_name, th.optim.Adam)
 
 
 def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
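Note: the rl_model_predict changes above keep a per-pair deque of recent observations and concatenate it along the feature axis before calling model.predict, so live inference reproduces the VecFrameStack layout used during training. A minimal standalone sketch of that buffering, with illustrative names and an assumed (window, n_features) observation shape, not the exact code from the diff:

    from collections import deque

    import numpy as np

    def make_buffer(first_observation: np.ndarray, n_stack: int) -> deque:
        # Pre-fill with the first observation so the very first prediction
        # already has a full stack of n_stack frames.
        size = max(1, n_stack)
        return deque([first_observation] * size, maxlen=size)

    def stack(buffer: deque, observation: np.ndarray) -> np.ndarray:
        # Append the newest frame (maxlen drops the oldest automatically) and
        # concatenate along the feature axis, matching VecFrameStack's
        # channel-wise stacking.
        buffer.append(observation)
        return np.concatenate(buffer, axis=1)

    obs = np.zeros((5, 3), dtype=np.float32)   # window of 5 rows, 3 features
    buf = make_buffer(obs, n_stack=4)
    assert stack(buf, obs).shape == (5, 12)    # n_features * n_stack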
diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
index 6e30d7302f0ba492dff9ef31284455e5c8e6d281..5bb6ebc5ac7f2f9404f8dfb10471123cf18a5fe6 100644
@@ -251,7 +251,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
 
         return eval_set, eval_weights
 
-    def optuna_storage(self, pair: str) -> optuna.storages.BaseStorage:
+    def optuna_storage(self, pair: str) -> optuna.storages.BaseStorage | None:
         storage_dir = str(self.full_path)
         storage_filename = f"optuna-{pair.split('/')[0]}"
         storage_backend = self.__optuna_config.get("storage", "file")
@@ -278,7 +278,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
             "quantile": self.quantile_min_max_pred,
             "mean": mean_min_max_pred,
             "median": median_min_max_pred,
-        }[prediction_thresholds_smoothing](
+        }.get(prediction_thresholds_smoothing, mean_min_max_pred)(
             pred_df, fit_live_predictions_candles, label_period_candles
         )
 
diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py
index 5330cc994614911ad06669fd874c2409281f6dad..010a3b921c7fd2d1da99bd1feb27c5b7370c4980 100644
@@ -252,7 +252,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
 
         return eval_set, eval_weights
 
-    def optuna_storage(self, pair: str) -> optuna.storages.BaseStorage:
+    def optuna_storage(self, pair: str) -> optuna.storages.BaseStorage | None:
         storage_dir = str(self.full_path)
         storage_filename = f"optuna-{pair.split('/')[0]}"
         storage_backend = self.__optuna_config.get("storage", "file")
@@ -279,7 +279,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
             "quantile": self.quantile_min_max_pred,
             "mean": mean_min_max_pred,
             "median": median_min_max_pred,
-        }[prediction_thresholds_smoothing](
+        }.get(prediction_thresholds_smoothing, mean_min_max_pred)(
             pred_df, fit_live_predictions_candles, label_period_candles
         )
 
diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py
index d65c04a408a11b9346a77fac23379402b2dac1a7..cede49c002bc95579403aef23c38d4d601904741 100644
@@ -429,7 +429,12 @@ class QuickAdapterV3(IStrategy):
             ),
             "ewma": series.ewm(span=window).mean(),
             "zlewma": zlewma(series, length=window),
-        }[extrema_smoothing]
+        }.get(
+            extrema_smoothing,
+            series.rolling(window=window, win_type="gaussian", center=center).mean(
+                std=std
+            ),
+        )
 
 
 def top_percent_change(dataframe: DataFrame, length: int) -> Series:
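Note: the commit also replaces several hard dictionary lookups with .get(key, default) so that an unrecognised option name falls back to a sensible default instead of raising KeyError. A minimal sketch of the pattern, using illustrative smoothing functions rather than the strategy's real ones:

    import numpy as np

    def _mean(values: np.ndarray) -> float:
        return float(np.mean(values))

    def _median(values: np.ndarray) -> float:
        return float(np.median(values))

    def smooth(values: np.ndarray, method: str) -> float:
        # An unknown `method` silently falls back to the mean instead of raising KeyError.
        return {
            "mean": _mean,
            "median": _median,
        }.get(method, _mean)(values)

    print(smooth(np.array([1.0, 2.0, 10.0]), "median"))    # 2.0
    print(smooth(np.array([1.0, 2.0, 10.0]), "unknown"))   # ~4.33, mean fallback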