Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): ensure hyperopt params are merged with user-defined ones

author     Jérôme Benoit <jerome.benoit@sap.com>  Thu, 27 Feb 2025 14:58:33 +0000 (15:58 +0100)
committer  Jérôme Benoit <jerome.benoit@sap.com>  Thu, 27 Feb 2025 14:58:33 +0000 (15:58 +0100)

Signed-off-by: Jérôme Benoit <jerome.benoit@sap.com>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 11c19d6f159698a58adc30f3d754599aa05d1987..ab0594b73a38c08d0e484c4c0e2224ab61a6dbd5 100644 (file)
@@ -479,7 +479,7 @@ class ReforceXY(BaseReinforcementLearningModel):
     ) -> Dict:
         """
         Runs hyperparameter optimization using Optuna and
-        returns the best hyperparameters found
+        returns the best hyperparameters found, merged with the user-defined parameters
         """
         _, identifier = str(self.full_path).rsplit("/", 1)
         if self.rl_config_optuna.get("per_pair", False):
@@ -536,7 +536,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             dk.pair if self.rl_config_optuna.get("per_pair", False) else None,
         )
 
-        return study.best_trial.params
+        return {**self.model_training_parameters, **study.best_trial.params}
 
     def save_best_params(self, best_params: Dict, pair: str | None = None) -> None:
         """
@@ -794,7 +794,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 action in (Actions.Long_enter.value, Actions.Short_enter.value)
                 and self._position == Positions.Neutral
             ):
-                return 25.0
+                return self.rl_config.get("model_reward_parameters", {}).get(
+                    "enter_action", 25.0
+                )
 
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
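Note: the hard-coded 25.0 entry reward above is now read from the model_reward_parameters section of rl_config, falling back to 25.0 when the key is absent. A hedged sketch of the lookup behaviour (the dict shape mirrors the diff; the 10.0 override is an illustrative value):

# User override for the enter-action reward; omit the key to fall back to 25.0.
rl_config = {"model_reward_parameters": {"enter_action": 10.0}}

enter_reward = rl_config.get("model_reward_parameters", {}).get("enter_action", 25.0)
print(enter_reward)  # 10.0 here; 25.0 if "enter_action" is not configured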
@@ -809,9 +811,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 self._position in (Positions.Short, Positions.Long)
                 and action == Actions.Neutral.value
             ):
-                peak_pnl = max(self.get_most_recent_peak_pnl(), pnl)
-                if peak_pnl > 0:
-                    drawdown_penalty = 0.01 * factor * (peak_pnl - pnl)
+                max_pnl = max(self.get_most_recent_max_pnl(), pnl)
+                if max_pnl > 0:
+                    drawdown_penalty = 0.01 * factor * (max_pnl - pnl)
                 else:
                     drawdown_penalty = 0.0
                 lambda1 = 0.05
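Note: while holding a position and choosing the neutral action, the agent is penalised by 0.01 * factor times the gap between the most recent maximum PnL and the current PnL, with no penalty when the maximum PnL is not positive. A standalone sketch of that computation, with illustrative pnl and factor values:

# Drawdown penalty as computed above; the 0.01 scaling and the > 0 guard follow the diff.
def drawdown_penalty(pnl: float, most_recent_max_pnl: float, factor: float) -> float:
    max_pnl = max(most_recent_max_pnl, pnl)
    if max_pnl > 0:
        return 0.01 * factor * (max_pnl - pnl)
    return 0.0

print(drawdown_penalty(pnl=0.01, most_recent_max_pnl=0.03, factor=100.0))  # 0.02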
@@ -1042,7 +1044,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 return self._current_tick - self._start_tick
             return self._current_tick - self._last_closed_trade_tick
 
-        def get_most_recent_peak_pnl(self) -> float:
+        def get_most_recent_max_pnl(self) -> float:
             return (
                 np.max(self.history.get("pnl")) if self.history.get("pnl") else -np.inf
             )
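Note: the renamed helper returns the largest PnL recorded in the environment history, or -inf when no PnL has been recorded yet, so the max_pnl > 0 guard above stays inactive on an empty history. A short sketch of that lookup (the history dict shape is assumed from the diff):

import numpy as np

history = {"pnl": [0.00, 0.01, 0.03, 0.02]}
print(np.max(history.get("pnl")) if history.get("pnl") else -np.inf)  # 0.03

empty_history = {}
print(np.max(empty_history.get("pnl")) if empty_history.get("pnl") else -np.inf)  # -inf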