]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(reforcexy): idle duration aware reward
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 16 Sep 2025 15:02:45 +0000 (17:02 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 16 Sep 2025 15:02:45 +0000 (17:02 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index bc2d93d07da71505912f50b8946ceece36846ab4..d8b099bea70a8097602348e5d825cec92b4e227f 100644 (file)
@@ -452,19 +452,19 @@ class ReforceXY(BaseReinforcementLearningModel):
         else:
             tensorboard_log_path = None
 
-        if dk.pair not in self.dd.model_dictionary or not self.continual_learning:
+        model = self.get_init_model(dk.pair)
+        if model is not None:
+            logger.info(
+                "Continual training activated: starting training from previously trained model state"
+            )
+            model.set_env(self.train_env)
+        else:
             model = self.MODELCLASS(
                 self.policy_type,
                 self.train_env,
                 tensorboard_log=tensorboard_log_path,
                 **model_params,
             )
-        else:
-            logger.info(
-                "Continual training activated: starting training from previously trained agent"
-            )
-            model = self.dd.model_dictionary[dk.pair]
-            model.set_env(self.train_env)
 
         callbacks = self.get_callbacks(
             max(1, train_timesteps // self.n_envs), str(dk.data_path)
@@ -481,35 +481,23 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.dd.update_metric_tracker("fit_time", time_spent, dk.pair)
 
         model_filename = dk.model_filename if dk.model_filename else "best"
-        model_path = Path(dk.data_path / f"{model_filename}_model.zip")
-        if model_path.is_file():
-            logger.info(f"Callback found a best model: {model_path}")
+        model_filepath = Path(dk.data_path / f"{model_filename}_model.zip")
+        if model_filepath.is_file():
+            logger.info("Found best model at %s", model_filepath)
             try:
                 best_model = self.MODELCLASS.load(
                     dk.data_path / f"{model_filename}_model"
                 )
                 return best_model
             except Exception as e:
-                logger.error(f"Error loading best model: {repr(e)}", exc_info=True)
+                logger.error(f"Error at loading best model: {repr(e)}", exc_info=True)
 
-        logger.info("Couldn't find best model, using final model instead")
+        logger.info(
+            "Could not find best model at %s, using final model instead", model_filepath
+        )
 
         return model
 
-    def get_state_info(self, pair: str) -> Tuple[float, float, int]:
-        """
-        State info during dry/live (not backtesting) which is fed back
-        into the model.
-        :param pair: str = COIN/STAKE to get the environment information for
-        :return:
-        :market_side: float = representing short, long, or neutral for pair
-        :current_profit: float = unrealized profit of the current trade
-        :trade_duration: int = the number of candles that the trade has been open for
-        """
-        # STATE_INFO
-        position, pnl, trade_duration = super().get_state_info(pair)
-        return position, pnl, int(trade_duration)
-
     def rl_model_predict(
         self, dataframe: DataFrame, dk: FreqaiDataKitchen, model: Any
     ) -> DataFrame:
@@ -600,8 +588,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             return int(action)
 
         actions = dataframe.iloc[:, 0].rolling(window=self.CONV_WIDTH).apply(_predict)
-        output = DataFrame({label: actions for label in dk.label_list})
-        return output
+        return DataFrame({label: actions for label in dk.label_list})
 
     def get_storage(self, pair: Optional[str] = None) -> BaseStorage:
         """
@@ -757,7 +744,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         log_msg: str = (
             f"{pair}: saving best params to {best_trial_params_path} JSON file"
             if pair
-            else f"saving best params to {best_trial_params_path} JSON file"
+            else f"Saving best params to {best_trial_params_path} JSON file"
         )
         logger.info(log_msg)
         try:
@@ -787,7 +774,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         log_msg: str = (
             f"{pair}: loading best params from {best_trial_params_path} JSON file"
             if pair
-            else f"loading best params from {best_trial_params_path} JSON file"
+            else f"Loading best params from {best_trial_params_path} JSON file"
         )
         if best_trial_params_path.is_file():
             logger.info(log_msg)
@@ -807,7 +794,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         Defines a single trial for hyperparameter optimization using Optuna
         """
         if self.train_env is None or self.eval_env is None:
-            raise RuntimeError("Environments not set. Cannot run HPO model training")
+            raise RuntimeError("Environments not set. Cannot run HPO trial")
         if "PPO" in self.model_type:
             params = sample_params_ppo(trial, self.n_envs)
             if params.get("n_steps", 0) > total_timesteps:
@@ -889,6 +876,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         def __init__(self, **kwargs):
             super().__init__(**kwargs)
+            self._set_observation_space()
             self.force_actions: bool = self.rl_config.get("force_actions", False)
             self._force_action: Optional[ForceActions] = None
             self.take_profit: float = self.config.get("minimal_roi", {}).get("0", 0.03)
@@ -907,6 +895,24 @@ class ReforceXY(BaseReinforcementLearningModel):
                     self.observation_space,
                 )
 
+        def _set_observation_space(self) -> None:
+            """
+            Set the observation space
+            """
+            signal_features = self.signal_features.shape[1]
+            if self.add_state_info:
+                # STATE_INFO
+                self.state_features = ["pnl", "position", "trade_duration"]
+                self.total_features = signal_features + len(self.state_features)
+            else:
+                self.state_features = []
+                self.total_features = signal_features
+
+            self.shape = (self.window_size, self.total_features)
+            self.observation_space = Box(
+                low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
+            )
+
         def reset_env(
             self,
             df: DataFrame,
@@ -919,19 +925,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             Resets the environment when the agent fails
             """
             super().reset_env(df, prices, window_size, reward_kwargs, starting_point)
-            base_features = self.signal_features.shape[1]
-            if self.add_state_info:
-                # STATE_INFO
-                self.state_features = ["pnl", "position", "trade_duration"]
-                self.total_features = base_features + len(self.state_features)
-            else:
-                self.state_features = []
-                self.total_features = base_features
-
-            self.shape = (window_size, self.total_features)
-            self.observation_space = Box(
-                low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
-            )
+            self._set_observation_space()
 
         def reset(
             self, seed=None, **kwargs
@@ -1030,13 +1024,10 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             # discourage agent from sitting idle too long
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
-                return float(
-                    self.rl_config.get("model_reward_parameters", {}).get(
-                        "inaction", -1.0
-                    )
-                )
+                idle_duration = self.get_idle_duration()
+                return float(-0.01 * idle_duration**1.05)
 
-            # pnl and duration aware reward for sitting in position
+            # pnl and duration aware agent reward while sitting in position
             if (
                 self._position in (Positions.Short, Positions.Long)
                 and action == Actions.Neutral.value
@@ -1306,8 +1297,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             if self._current_tick <= 0:
                 return 0.0
             if self._position == Positions.Long:
-                current_price = self.prices.iloc[self._current_tick].get("open")
-                previous_price = self.prices.iloc[self._current_tick - 1].get("open")
+                current_price = self.current_price()
+                previous_price = self.previous_price()
                 if (
                     self._position_history[self._current_tick - 1] == Positions.Short
                     or self._position_history[self._current_tick - 1]
@@ -1316,8 +1307,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                     previous_price = self.add_entry_fee(previous_price)
                 return np.log(current_price) - np.log(previous_price)
             if self._position == Positions.Short:
-                current_price = self.prices.iloc[self._current_tick].get("open")
-                previous_price = self.prices.iloc[self._current_tick - 1].get("open")
+                current_price = self.current_price()
+                previous_price = self.previous_price()
                 if (
                     self._position_history[self._current_tick - 1] == Positions.Long
                     or self._position_history[self._current_tick - 1]
@@ -1327,6 +1318,11 @@ class ReforceXY(BaseReinforcementLearningModel):
                 return np.log(previous_price) - np.log(current_price)
             return 0.0
 
+        def update_portfolio_log_returns(self):
+            self.portfolio_log_returns[self._current_tick] = (
+                self.get_most_recent_return()
+            )
+
         def get_most_recent_profit(self) -> float:
             """
             Calculate the tick to tick unrealized profit if in a trade