From: Jérôme Benoit
Date: Tue, 16 Sep 2025 15:02:45 +0000 (+0200)
Subject: perf(reforcexy): idle duration aware reward
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=46abc6be5b41a67cad5225dab231d1a601390e3e;p=freqai-strategies.git

perf(reforcexy): idle duration aware reward

Scale the idle penalty with the number of candles the agent has stayed
idle instead of returning the flat, configurable "inaction" value, so
prolonged inactivity is discouraged progressively.

Signed-off-by: Jérôme Benoit
---
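
Note: the new penalty replaces the flat "inaction" reward of -1.0 with
-0.01 * idle_duration**1.05, so the cost of doing nothing ramps up
smoothly with time spent flat. A minimal sketch of the curve, assuming
get_idle_duration() returns the number of candles spent in
Positions.Neutral (its exact definition lives elsewhere in the class):

    def idle_penalty(idle_duration: int) -> float:
        # Mildly superlinear: doubling the idle time slightly more than
        # doubles the penalty, since (2*d)**1.05 = 2**1.05 * d**1.05
        # and 2**1.05 is about 2.07.
        return float(-0.01 * idle_duration**1.05)

    # idle_penalty(1)   -> -0.01
    # idle_penalty(10)  -> about -0.11
    # idle_penalty(80)  -> about -1.0, matching the old flat penalty
    # idle_penalty(100) -> about -1.26

Short idle stretches are now nearly free, so the agent is not punished
at the full flat rate for waiting a few candles. Note that the
model_reward_parameters "inaction" knob is no longer read.
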
diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index bc2d93d..d8b099b 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -452,19 +452,19 @@ class ReforceXY(BaseReinforcementLearningModel):
         else:
             tensorboard_log_path = None
 
-        if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
+        model = self.get_init_model(dk.pair)
+        if model is not None:
+            logger.info(
+                "Continual training activated: starting training from previously trained model state"
+            )
+            model.set_env(self.train_env)
+        else:
             model = self.MODELCLASS(
                 self.policy_type,
                 self.train_env,
                 tensorboard_log=tensorboard_log_path,
                 **model_params,
             )
-        else:
-            logger.info(
-                "Continual training activated: starting training from previously trained agent"
-            )
-            model = self.dd.model_dictionary[dk.pair]
-            model.set_env(self.train_env)
         callbacks = self.get_callbacks(
             max(1, train_timesteps // self.n_envs), str(dk.data_path)
         )
@@ -481,35 +481,23 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.dd.update_metric_tracker("fit_time", time_spent, dk.pair)
 
         model_filename = dk.model_filename if dk.model_filename else "best"
-        model_path = Path(dk.data_path / f"{model_filename}_model.zip")
-        if model_path.is_file():
-            logger.info(f"Callback found a best model: {model_path}")
+        model_filepath = Path(dk.data_path / f"{model_filename}_model.zip")
+        if model_filepath.is_file():
+            logger.info("Found best model at %s", model_filepath)
             try:
                 best_model = self.MODELCLASS.load(
                     dk.data_path / f"{model_filename}_model"
                 )
                 return best_model
             except Exception as e:
-                logger.error(f"Error loading best model: {repr(e)}", exc_info=True)
+                logger.error("Error loading best model: %s", repr(e), exc_info=True)
 
-        logger.info("Couldn't find best model, using final model instead")
+        logger.info(
+            "Could not find best model at %s, using final model instead", model_filepath
+        )
 
         return model
 
-    def get_state_info(self, pair: str) -> Tuple[float, float, int]:
-        """
-        State info during dry/live (not backtesting) which is fed back
-        into the model.
-        :param pair: str = COIN/STAKE to get the environment information for
-        :return:
-        :market_side: float = representing short, long, or neutral for pair
-        :current_profit: float = unrealized profit of the current trade
-        :trade_duration: int = the number of candles that the trade has been open for
-        """
-        # STATE_INFO
-        position, pnl, trade_duration = super().get_state_info(pair)
-        return position, pnl, int(trade_duration)
-
     def rl_model_predict(
         self, dataframe: DataFrame, dk: FreqaiDataKitchen, model: Any
     ) -> DataFrame:
@@ -600,8 +588,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             return int(action)
 
         actions = dataframe.iloc[:, 0].rolling(window=self.CONV_WIDTH).apply(_predict)
-        output = DataFrame({label: actions for label in dk.label_list})
-        return output
+        return DataFrame({label: actions for label in dk.label_list})
 
     def get_storage(self, pair: Optional[str] = None) -> BaseStorage:
         """
@@ -757,7 +744,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         log_msg: str = (
             f"{pair}: saving best params to {best_trial_params_path} JSON file"
             if pair
-            else f"saving best params to {best_trial_params_path} JSON file"
+            else f"Saving best params to {best_trial_params_path} JSON file"
         )
         logger.info(log_msg)
         try:
@@ -787,7 +774,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         log_msg: str = (
             f"{pair}: loading best params from {best_trial_params_path} JSON file"
             if pair
-            else f"loading best params from {best_trial_params_path} JSON file"
+            else f"Loading best params from {best_trial_params_path} JSON file"
         )
         if best_trial_params_path.is_file():
             logger.info(log_msg)
@@ -807,7 +794,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         Defines a single trial for hyperparameter optimization using Optuna
         """
         if self.train_env is None or self.eval_env is None:
-            raise RuntimeError("Environments not set. Cannot run HPO model training")
+            raise RuntimeError("Environments not set. Cannot run HPO trial")
         if "PPO" in self.model_type:
             params = sample_params_ppo(trial, self.n_envs)
             if params.get("n_steps", 0) > total_timesteps:
@@ -889,6 +876,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
+        self._set_observation_space()
         self.force_actions: bool = self.rl_config.get("force_actions", False)
         self._force_action: Optional[ForceActions] = None
         self.take_profit: float = self.config.get("minimal_roi", {}).get("0", 0.03)
@@ -907,6 +895,24 @@ class ReforceXY(BaseReinforcementLearningModel):
             self.observation_space,
         )
 
+    def _set_observation_space(self) -> None:
+        """
+        Set the observation space
+        """
+        signal_features = self.signal_features.shape[1]
+        if self.add_state_info:
+            # STATE_INFO
+            self.state_features = ["pnl", "position", "trade_duration"]
+            self.total_features = signal_features + len(self.state_features)
+        else:
+            self.state_features = []
+            self.total_features = signal_features
+
+        self.shape = (self.window_size, self.total_features)
+        self.observation_space = Box(
+            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
+        )
+
     def reset_env(
         self,
         df: DataFrame,
@@ -919,19 +925,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         Resets the environment when the agent fails
         """
         super().reset_env(df, prices, window_size, reward_kwargs, starting_point)
-        base_features = self.signal_features.shape[1]
-        if self.add_state_info:
-            # STATE_INFO
-            self.state_features = ["pnl", "position", "trade_duration"]
-            self.total_features = base_features + len(self.state_features)
-        else:
-            self.state_features = []
-            self.total_features = base_features
-
-        self.shape = (window_size, self.total_features)
-        self.observation_space = Box(
-            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
-        )
+        self._set_observation_space()
 
     def reset(
         self, seed=None, **kwargs
@@ -1030,13 +1024,10 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         # discourage agent from sitting idle too long
        if action == Actions.Neutral.value and self._position == Positions.Neutral:
-            return float(
-                self.rl_config.get("model_reward_parameters", {}).get(
-                    "inaction", -1.0
-                )
-            )
+            idle_duration = self.get_idle_duration()
+            return float(-0.01 * idle_duration**1.05)
 
-        # pnl and duration aware reward for sitting in position
+        # PnL and duration aware reward while sitting in position
         if (
             self._position in (Positions.Short, Positions.Long)
             and action == Actions.Neutral.value
@@ -1306,8 +1297,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         if self._current_tick <= 0:
             return 0.0
         if self._position == Positions.Long:
-            current_price = self.prices.iloc[self._current_tick].get("open")
-            previous_price = self.prices.iloc[self._current_tick - 1].get("open")
+            current_price = self.current_price()
+            previous_price = self.previous_price()
             if (
                 self._position_history[self._current_tick - 1] == Positions.Short
                 or self._position_history[self._current_tick - 1]
@@ -1316,8 +1307,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                 previous_price = self.add_entry_fee(previous_price)
             return np.log(current_price) - np.log(previous_price)
         if self._position == Positions.Short:
-            current_price = self.prices.iloc[self._current_tick].get("open")
-            previous_price = self.prices.iloc[self._current_tick - 1].get("open")
+            current_price = self.current_price()
+            previous_price = self.previous_price()
             if (
                 self._position_history[self._current_tick - 1] == Positions.Long
                 or self._position_history[self._current_tick - 1]
@@ -1327,6 +1318,11 @@ class ReforceXY(BaseReinforcementLearningModel):
             return np.log(previous_price) - np.log(current_price)
         return 0.0
 
+    def update_portfolio_log_returns(self):
+        self.portfolio_log_returns[self._current_tick] = (
+            self.get_most_recent_return()
+        )
+
     def get_most_recent_profit(self) -> float:
         """
         Calculate the tick to tick unrealized profit if in a trade
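
Note: _set_observation_space() now derives the observation shape in one
place for both __init__() and reset_env(). A minimal standalone sketch of
the resulting space, assuming the environment's Box comes from gymnasium
and using illustrative sizes (window_size=5, 10 signal features):

    import numpy as np
    from gymnasium.spaces import Box

    window_size, n_signal_features = 5, 10  # illustrative sizes only
    # STATE_INFO: with add_state_info enabled, three state features are appended
    state_features = ["pnl", "position", "trade_duration"]
    total_features = n_signal_features + len(state_features)

    observation_space = Box(
        low=-np.inf,
        high=np.inf,
        shape=(window_size, total_features),
        dtype=np.float32,
    )
    print(observation_space.shape)  # (5, 13)

With add_state_info disabled, state_features is empty and the shape falls
back to (window_size, n_signal_features).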