From 413c1d8053e365fc8ff8c3c31022b72e487af027 Mon Sep 17 00:00:00 2001
From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?=
Date: Thu, 27 Feb 2025 15:58:33 +0100
Subject: [PATCH] fix(reforcexy): ensure hyperopt params are merged with
 user-defined ones
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jérôme Benoit
---
 ReforceXY/user_data/freqaimodels/ReforceXY.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 11c19d6..ab0594b 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -479,7 +479,7 @@ class ReforceXY(BaseReinforcementLearningModel):
     ) -> Dict:
         """
         Runs hyperparameter optimization using Optuna and
-        returns the best hyperparameters found
+        returns the best hyperparameters found merged with the user-defined parameters
         """
         _, identifier = str(self.full_path).rsplit("/", 1)
         if self.rl_config_optuna.get("per_pair", False):
@@ -536,7 +536,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             dk.pair if self.rl_config_optuna.get("per_pair", False) else None,
         )

-        return study.best_trial.params
+        return {**self.model_training_parameters, **study.best_trial.params}

     def save_best_params(self, best_params: Dict, pair: str | None = None) -> None:
         """
@@ -794,7 +794,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 action in (Actions.Long_enter.value, Actions.Short_enter.value)
                 and self._position == Positions.Neutral
             ):
-                return 25.0
+                return self.rl_config.get("model_reward_parameters", {}).get(
+                    "enter_action", 25.0
+                )

             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
@@ -809,9 +811,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 self._position in (Positions.Short, Positions.Long)
                 and action == Actions.Neutral.value
             ):
-                peak_pnl = max(self.get_most_recent_peak_pnl(), pnl)
-                if peak_pnl > 0:
-                    drawdown_penalty = 0.01 * factor * (peak_pnl - pnl)
+                max_pnl = max(self.get_most_recent_max_pnl(), pnl)
+                if max_pnl > 0:
+                    drawdown_penalty = 0.01 * factor * (max_pnl - pnl)
                 else:
                     drawdown_penalty = 0.0
                 lambda1 = 0.05
@@ -1042,7 +1044,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             return self._current_tick - self._start_tick
         return self._current_tick - self._last_closed_trade_tick

-    def get_most_recent_peak_pnl(self) -> float:
+    def get_most_recent_max_pnl(self) -> float:
         return (
             np.max(self.history.get("pnl")) if self.history.get("pnl") else -np.inf
         )
--
2.43.0
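
Note (illustrative, not part of the patch): the merged return value relies on
Python dict-merge precedence, where later keys win. Optuna's best-trial values
therefore override the user-defined training parameters, while user-defined
keys that hyperopt does not touch are preserved. A minimal sketch with made-up
parameter names standing in for self.model_training_parameters and
study.best_trial.params:

    # Made-up stand-ins for the user-defined training parameters and the
    # hyperparameters found by the Optuna best trial.
    user_params = {"learning_rate": 1e-4, "gamma": 0.99, "verbose": 1}
    best_trial_params = {"learning_rate": 3e-4, "gamma": 0.95}

    # Later keys win: hyperopt results override user-defined values,
    # while untouched user-defined keys ("verbose") are kept.
    merged = {**user_params, **best_trial_params}
    assert merged == {"learning_rate": 3e-4, "gamma": 0.95, "verbose": 1}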
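
Note (illustrative, not part of the patch): the entry-reward change reads the
bonus through chained .get() lookups with a default, so a config that lacks a
"model_reward_parameters" section, or lacks an "enter_action" key, falls back
to the previously hard-coded 25.0. A minimal sketch with a made-up config:

    # Key present: the configured value is used.
    rl_config = {"model_reward_parameters": {"enter_action": 10.0}}
    bonus = rl_config.get("model_reward_parameters", {}).get("enter_action", 25.0)
    assert bonus == 10.0

    # Section absent: the hard-coded default applies.
    assert {}.get("model_reward_parameters", {}).get("enter_action", 25.0) == 25.0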