From 18ca2d680c59c4500a73453927d57cadb7e126f2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sat, 8 Mar 2025 18:43:43 +0100 Subject: [PATCH] fix(reforcexy): properly stack observations MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/config-template.json | 4 +- ReforceXY/user_data/freqaimodels/ReforceXY.py | 94 ++++++++++++------- .../LightGBMRegressorQuickAdapterV35.py | 4 +- .../XGBoostRegressorQuickAdapterV35.py | 4 +- .../user_data/strategies/QuickAdapterV3.py | 7 +- 5 files changed, 71 insertions(+), 42 deletions(-) diff --git a/ReforceXY/user_data/config-template.json b/ReforceXY/user_data/config-template.json index 92cbf82..15237b0 100644 --- a/ReforceXY/user_data/config-template.json +++ b/ReforceXY/user_data/config-template.json @@ -164,14 +164,14 @@ "profit_aim": 0.025, "win_reward_factor": 2 }, - "train_cycles": 250, + "train_cycles": 25, "add_state_info": true, "cpu_count": 6, "max_training_drawdown_pct": 0.02, "max_trade_duration_candles": 96, // Timeout exit value used with force_actions "force_actions": false, // Utilize minimal_roi, stoploss, and max_trade_duration_candles as TP/SL/Timeout in the environment "n_envs": 32, // Number of DummyVecEnv environments - "frame_staking": 4, // Number of VecFrameStack stacks (set > 1 to use) + "frame_stacking": 4, // Number of VecFrameStack stacks (set > 1 to use) "lr_schedule": false, // Enable learning rate linear schedule "cr_schedule": false, // Enable clip range linear schedule "max_no_improvement_evals": 0, // Maximum consecutive evaluations without a new best model diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 492b9c9..3ab6f4d 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -1,3 +1,4 @@ +from collections import deque import copy import gc import json @@ -74,7 +75,7 @@ class ReforceXY(BaseReinforcementLearningModel): "max_trade_duration_candles": 96, // Timeout exit value used with force_actions "force_actions": false, // Utilize minimal_roi, stoploss, and max_trade_duration_candles as TP/SL/Timeout in the environment "n_envs": 1, // Number of DummyVecEnv environments - "frame_staking": 0, // Number of VecFrameStack stacks (set > 1 to use) + "frame_stacking": 0, // Number of VecFrameStack stacks (set > 1 to use) "lr_schedule": false, // Enable learning rate linear schedule "cr_schedule": false, // Enable clip range linear schedule "max_no_improvement_evals": 0, // Maximum consecutive evaluations without a new best model @@ -105,14 +106,14 @@ class ReforceXY(BaseReinforcementLearningModel): raise ValueError( "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration" ) + self.observations_buffer: Dict[str, deque] = {} self.is_maskable: bool = ( self.model_type == "MaskablePPO" ) # Enable action masking self.lr_schedule: bool = self.rl_config.get("lr_schedule", False) self.cr_schedule: bool = self.rl_config.get("cr_schedule", False) self.n_envs: int = self.rl_config.get("n_envs", 1) - self.frame_staking: int = self.rl_config.get("frame_staking", 0) - self.frame_staking += 1 if self.frame_staking == 1 else 0 + self.frame_stacking: int = self.rl_config.get("frame_stacking", 0) self.max_no_improvement_evals: int = self.rl_config.get( "max_no_improvement_evals", 0 ) @@ -140,9 +141,9 @@ class 
ReforceXY(BaseReinforcementLearningModel): If user has activated any custom function that may conflict, this function will set them to false and warn them """ - if self.continual_learning and self.frame_staking: + if self.continual_learning and self.frame_stacking: logger.warning( - "User tried to use continual_learning with frame_staking. \ + "User tried to use continual_learning with frame_stacking. \ Deactivating continual_learning" ) self.continual_learning = False @@ -206,12 +207,14 @@ class ReforceXY(BaseReinforcementLearningModel): for i in range(self.n_envs) ] ) - if self.frame_staking: - logger.info("Frame staking: %s", self.frame_staking) - train_env = VecFrameStack(train_env, n_stack=self.frame_staking) - eval_env = VecFrameStack(eval_env, n_stack=self.frame_staking) + if self.frame_stacking: + logger.info("Frame stacking: %s", self.frame_stacking) + train_env = VecFrameStack(train_env, n_stack=self.frame_stacking) + eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking) self.train_env = VecMonitor(train_env) + if self.frame_stacking and not self.train_env.observation_space.shape: + raise ValueError("Frame stacking requires predefined observation shape") self.eval_env = VecMonitor(eval_env) def get_model_params(self) -> Dict: @@ -254,7 +257,7 @@ class ReforceXY(BaseReinforcementLearningModel): def get_callbacks( self, eval_freq: int, data_path: str, trial: Trial = None - ) -> list: + ) -> list[BaseCallback]: """ Get the model specific callbacks """ @@ -263,8 +266,7 @@ class ReforceXY(BaseReinforcementLearningModel): rollout_plot_callback = None verbose = self.model_training_parameters.get("verbose", 0) - if self.n_envs > 1: - eval_freq //= self.n_envs + eval_freq //= self.n_envs if self.plot_new_best: rollout_plot_callback = RolloutPlotCallback(verbose=verbose) @@ -325,17 +327,18 @@ class ReforceXY(BaseReinforcementLearningModel): train_df = data_dictionary["train_features"] train_timesteps = len(train_df) test_timesteps = len(data_dictionary["test_features"]) - train_cycles = int(self.rl_config.get("train_cycles", 250)) - total_timesteps = train_timesteps * train_cycles + train_cycles = max(1, int(self.rl_config.get("train_cycles", 25))) + total_timesteps = train_timesteps * train_cycles * self.n_envs train_days = steps_to_days(train_timesteps, self.config["timeframe"]) total_days = steps_to_days(total_timesteps, self.config["timeframe"]) logger.info("Action masking: %s", self.is_maskable) logger.info( - "Train: %s steps (%s days) * %s cycles = Total %s (%s days)", + "Train: %s steps (%s days) * %s cycles * %s environments = Total %s (%s days)", train_timesteps, train_days, train_cycles, + self.n_envs, total_timesteps, total_days, ) @@ -380,6 +383,8 @@ class ReforceXY(BaseReinforcementLearningModel): finally: if self.progressbar_callback: self.progressbar_callback.on_training_end() + self.close_envs() + model.env.close() time_spent = time.time() - start self.dd.update_metric_tracker("fit_time", time_spent, dk.pair) @@ -387,8 +392,13 @@ class ReforceXY(BaseReinforcementLearningModel): model_path = Path(dk.data_path / f"{model_filename}_model.zip") if model_path.is_file(): logger.info(f"Callback found a best model: {model_path}.") - best_model = self.MODELCLASS.load(dk.data_path / f"{model_filename}_model") - return best_model + try: + best_model = self.MODELCLASS.load( + dk.data_path / f"{model_filename}_model" + ) + return best_model + except Exception as e: + logger.error(f"Error loading best model: {e}", exc_info=True) logger.info("Couldn't find best model, using 
final model instead.") @@ -419,6 +429,12 @@ class ReforceXY(BaseReinforcementLearningModel): :param dk: FreqaiDatakitchen = data kitchen for the current pair :param model: Any = the trained model used to inference the features. """ + if not self.observations_buffer.get(dk.pair): + buffer_size = max(1, self.frame_stacking) + initial_observation = dataframe.iloc[0].to_numpy(dtype=np.float32) + self.observations_buffer[dk.pair] = deque( + [initial_observation] * buffer_size, maxlen=buffer_size + ) def _is_valid(action: int, position: float) -> bool: """ @@ -440,28 +456,29 @@ class ReforceXY(BaseReinforcementLearningModel): return [_is_valid(action.value, position) for action in Actions] def _predict(window): - observations: DataFrame = dataframe.iloc[window.index] + observation: DataFrame = dataframe.iloc[window.index] action_masks_param: dict = {} if self.live and self.rl_config.get("add_state_info", False): position, pnl, trade_duration = self.get_state_info(dk.pair) # STATE_INFO - observations["pnl"] = pnl - observations["position"] = position - observations["trade_duration"] = trade_duration + observation["pnl"] = pnl + observation["position"] = position + observation["trade_duration"] = trade_duration if self.is_maskable: action_masks_param = {"action_masks": _action_masks(position)} - observations = observations.to_numpy(dtype=np.float32) + observation = observation.to_numpy(dtype=np.float32) - if self.frame_staking: - observations = np.repeat( - observations, axis=1, repeats=self.frame_staking - ) + self.observations_buffer[dk.pair].append(observation) + + stacked_observations = np.concatenate( + self.observations_buffer[dk.pair], axis=1 + ) action, _ = model.predict( - observations, deterministic=True, **action_masks_param + stacked_observations, deterministic=True, **action_masks_param ) return action @@ -469,7 +486,7 @@ class ReforceXY(BaseReinforcementLearningModel): output = output.rolling(window=self.CONV_WIDTH).apply(_predict) return output - def get_storage(self, pair: str | None = None) -> BaseStorage: + def get_storage(self, pair: str | None = None) -> BaseStorage | None: """ Get the storage for Optuna """ @@ -498,6 +515,7 @@ class ReforceXY(BaseReinforcementLearningModel): else: study_name = identifier storage = self.get_storage() + eval_freq = len(train_df) // self.n_envs study: Study = create_study( study_name=study_name, sampler=TPESampler( @@ -506,7 +524,9 @@ class ReforceXY(BaseReinforcementLearningModel): group=True, ), pruner=HyperbandPruner( - min_resource=1, max_resource=self.optuna_n_trials, reduction_factor=3 + min_resource=3, + max_resource=total_timesteps // eval_freq, + reduction_factor=3, ), direction=StudyDirection.MAXIMIZE, storage=storage, @@ -601,6 +621,8 @@ class ReforceXY(BaseReinforcementLearningModel): """ if "PPO" in self.model_type: params = sample_params_ppo(trial) + if params.get("n_steps", 0) > total_timesteps: + raise TrialPruned("n_steps exceeds total_timesteps") elif "QRDQN" in self.model_type: params = sample_params_qrdqn(trial) elif "DQN" in self.model_type: @@ -629,8 +651,8 @@ class ReforceXY(BaseReinforcementLearningModel): tensorboard_log=tensorboard_log_path, **params, ) - callbacks = self.get_callbacks(len(train_df), str(dk.data_path), trial) + callbacks = self.get_callbacks(len(train_df), str(dk.data_path), trial) try: model.learn(total_timesteps=total_timesteps, callback=callbacks) except AssertionError: @@ -1068,6 +1090,8 @@ class ReforceXY(BaseReinforcementLearningModel): and falling prices in Short positions. 
The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee. """ + if self._current_tick <= 0: + return 0.0 if self._position == Positions.Long: current_price = self.prices.iloc[self._current_tick].open previous_price = self.prices.iloc[self._current_tick - 1].open @@ -1131,7 +1155,7 @@ class ReforceXY(BaseReinforcementLearningModel): _rollout_history = _history_df.merge( _trade_history_df, on="tick", how="left" - ) + ).fillna(method="ffill") _price_history = ( self.prices.iloc[_rollout_history.tick].copy().reset_index() ) @@ -1425,13 +1449,13 @@ def get_net_arch( "medium": {"pi": [256, 256], "vf": [256, 256]}, "large": {"pi": [512, 512], "vf": [512, 512]}, "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]}, - }[net_arch_type] + }.get(net_arch_type, {"pi": [128, 128], "vf": [128, 128]}) return { "small": [128, 128], "medium": [256, 256], "large": [512, 512], "extra_large": [1024, 1024], - }[net_arch_type] + }.get(net_arch_type, [128, 128]) def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]: @@ -1443,7 +1467,7 @@ def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]: "relu": th.nn.ReLU, "elu": th.nn.ELU, "leaky_relu": th.nn.LeakyReLU, - }[activation_fn_name] + }.get(activation_fn_name, th.nn.ReLU) def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]: @@ -1453,7 +1477,7 @@ def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]: return { "adam": th.optim.Adam, "rmsprop": th.optim.RMSprop, - }[optimizer_class_name] + }.get(optimizer_class_name, th.optim.Adam) def sample_params_ppo(trial: Trial) -> Dict[str, Any]: diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index 6e30d73..5bb6ebc 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -251,7 +251,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): return eval_set, eval_weights - def optuna_storage(self, pair: str) -> optuna.storages.BaseStorage: + def optuna_storage(self, pair: str) -> optuna.storages.BaseStorage | None: storage_dir = str(self.full_path) storage_filename = f"optuna-{pair.split('/')[0]}" storage_backend = self.__optuna_config.get("storage", "file") @@ -278,7 +278,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): "quantile": self.quantile_min_max_pred, "mean": mean_min_max_pred, "median": median_min_max_pred, - }[prediction_thresholds_smoothing]( + }.get(prediction_thresholds_smoothing, mean_min_max_pred)( pred_df, fit_live_predictions_candles, label_period_candles ) diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 5330cc9..010a3b9 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -252,7 +252,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): return eval_set, eval_weights - def optuna_storage(self, pair: str) -> optuna.storages.BaseStorage: + def optuna_storage(self, pair: str) -> optuna.storages.BaseStorage | None: storage_dir = str(self.full_path) storage_filename = f"optuna-{pair.split('/')[0]}" storage_backend = self.__optuna_config.get("storage", "file") @@ -279,7 +279,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): 
"quantile": self.quantile_min_max_pred, "mean": mean_min_max_pred, "median": median_min_max_pred, - }[prediction_thresholds_smoothing]( + }.get(prediction_thresholds_smoothing, mean_min_max_pred)( pred_df, fit_live_predictions_candles, label_period_candles ) diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index d65c04a..cede49c 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -429,7 +429,12 @@ class QuickAdapterV3(IStrategy): ), "ewma": series.ewm(span=window).mean(), "zlewma": zlewma(series, length=window), - }[extrema_smoothing] + }.get( + extrema_smoothing, + series.rolling(window=window, win_type="gaussian", center=center).mean( + std=std + ), + ) def top_percent_change(dataframe: DataFrame, length: int) -> Series: -- 2.43.0