From: Jérôme Benoit
Date: Wed, 19 Feb 2025 00:46:45 +0000 (+0100)
Subject: fix(reforcexy): uniformize rewarding behavior
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=82ed0abb456e01b4abbd8770662e8820f26a16e7;p=freqai-strategies.git

fix(reforcexy): uniformize rewarding behavior

Signed-off-by: Jérôme Benoit
---

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 0da09cb..0d9a576 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -402,22 +402,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param model: Any = the trained model used to inference the features.
         """
 
-        def _is_valid(action: int, position: float) -> bool:
-            return not (
-                (
-                    action in (Actions.Short_enter.value, Actions.Long_enter.value)
-                    and position != Positions.Neutral.value
-                )
-                or (
-                    action == Actions.Long_exit.value
-                    and position != Positions.Long.value
-                )
-                or (
-                    action == Actions.Short_exit.value
-                    and position != Positions.Short.value
-                )
-            )
-
         def _action_masks(position: float):
             return [_is_valid(action.value, position) for action in Actions]
 
@@ -451,12 +435,10 @@ class ReforceXY(BaseReinforcementLearningModel):
         output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
         return output
 
-    def study(self, train_df, total_timesteps: int, dk: FreqaiDataKitchen) -> Dict:
+    def get_storage(self, dk: FreqaiDataKitchen):
         """
-        Runs hyperparameter optimization using Optuna and
-        returns the best hyperparameters found
+        Get the storage for Optuna
         """
-        study_name = str(dk.pair)
         storage_dir = str(dk.full_path)
         storage_backend = self.rl_config_optuna.get("storage", "sqlite")
         if storage_backend == "sqlite":
@@ -465,6 +447,15 @@ class ReforceXY(BaseReinforcementLearningModel):
             storage = JournalStorage(
                 JournalFileBackend(f"{storage_dir}/optuna-{dk.pair.split('/')[0]}.log")
             )
+        return storage
+
+    def study(self, train_df, total_timesteps: int, dk: FreqaiDataKitchen) -> Dict:
+        """
+        Runs hyperparameter optimization using Optuna and
+        returns the best hyperparameters found
+        """
+        study_name = str(dk.pair)
+        storage = self.get_storage(dk)
         study: Study = create_study(
             study_name=study_name,
             sampler=TPESampler(
@@ -662,14 +653,34 @@ class ReforceXY(BaseReinforcementLearningModel):
             float = the reward to give to the agent for current step
             (used for optimization of weights in NN)
         """
+        # first penalize if the action is not valid
+        if (
+            self.force_actions
+            and self._force_action is not None
+            and self._force_action
+            not in (
+                ForceActions.Take_profit,
+                ForceActions.Stop_loss,
+                ForceActions.Timeout,
+            )
+        ) or not self._is_valid(action):
+            return -2
+
         pnl = self.get_unrealized_profit()
         # mrr = self.get_most_recent_return()
         # mrp = self.get_most_recent_profit()
 
         factor = 100.0
 
+        max_trade_duration = self.rl_config.get("max_trade_duration_candles", 300)
+        trade_duration = self.get_trade_duration()
+        if trade_duration <= max_trade_duration:
+            factor *= 1.5
+        elif trade_duration > max_trade_duration:
+            factor *= 0.5
+
         # Force exits
-        if self._force_action in (
+        if self.force_actions and self._force_action in (
             ForceActions.Take_profit,
             ForceActions.Stop_loss,
             ForceActions.Timeout,
@@ -680,10 +691,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
             return pnl * factor
 
-        # first, penalize if the action is not valid
-        if not self._is_valid(action):
-            return -2
-
         # # you can use feature values from dataframe
         # rsi_now = self.get_feature_value(
         #     name="%-rsi",
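In the reward hunks above, the invalid-action penalty now runs before any reward is computed, and the trade-duration scaling of factor is hoisted above the force-exit branch, so every PnL-based return is scaled the same way (the next hunk applies the same factor to the non-profit penalty). The sketch below distills that ordering; it is illustrative only, with the environment lookups (self._is_valid, self.get_unrealized_profit, self.get_trade_duration, force-action state) replaced by hypothetical plain arguments:

    def calculate_reward_sketch(
        pnl: float,
        trade_duration: int,
        action_is_valid: bool,
        force_exit: bool,
        non_profit_steps: int = 0,
        max_trade_duration: int = 300,  # rl_config "max_trade_duration_candles" default
    ) -> float:
        # 1. Invalid actions are penalized before any reward is computed.
        if not action_is_valid:
            return -2.0

        # 2. Duration-aware scaling: boosted within the duration budget,
        #    halved once the trade overstays it.
        factor = 100.0
        factor *= 1.5 if trade_duration <= max_trade_duration else 0.5

        # 3. Force exits (take-profit / stop-loss / timeout) realize the
        #    unrealized PnL under the duration-scaled factor.
        if force_exit:
            return pnl * factor

        # 4. Time-aggressive (quadratic), loss-magnitude-aware penalty while
        #    the trade sits at a non-profit, under the same factor.
        if non_profit_steps > 0:
            return factor * (pnl - 0.1 * non_profit_steps**2 * abs(pnl))

        return 0.0  # remaining branches (idling, voluntary exits) elided

Applying one duration-aware factor across the force-exit and penalty branches is what makes the rewarding behavior uniform, per the commit subject.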
@@ -715,19 +722,12 @@ class ReforceXY(BaseReinforcementLearningModel):
                 self._non_profit_steps += 1
             else:
                 self._non_profit_steps = 0
-            if self._non_profit_steps > 0:
-                return pnl - (
-                    0.1 * (self._non_profit_steps**2) * max(0, pnl)
-                )  # time aggressive (quadratic) and loss magnitude aware penalty
+            if self._non_profit_steps > 0:
+                return factor * (
+                    pnl - (0.1 * (self._non_profit_steps**2) * abs(pnl))
+                )  # time aggressive (quadratic) and loss magnitude aware penalty
 
         # discourage sitting in position
-        max_trade_duration = self.rl_config.get("max_trade_duration_candles", 300)
-        trade_duration = self.get_trade_duration()
-        if trade_duration <= max_trade_duration:
-            factor *= 1.5
-        elif trade_duration > max_trade_duration:
-            factor *= 0.5
-
         if (
             self._position in (Positions.Short, Positions.Long)
             and action == Actions.Neutral.value
@@ -781,7 +781,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         else:
             return features_window.to_numpy(dtype=np.float32)
 
-    def _get_force_action(self):
+    def _get_force_action(self) -> Optional[ForceActions]:
         if not self.force_actions or self._position == Positions.Neutral:
             return None
 
@@ -813,6 +813,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         self._position = Positions.Neutral
         self._last_closed_trade_tick = self._last_trade_tick
         self._last_trade_tick = None
+        self._non_profit_steps = 0
 
     def execute_trade(self, action: int) -> None:
         """
@@ -913,25 +914,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
         )
 
-    def _is_valid(self, action: int) -> bool:
-        """
-        Determine if the action is valid for the step
-        """
-        return not (
-            (
-                action in (Actions.Short_enter.value, Actions.Long_enter.value)
-                and self._position != Positions.Neutral
-            )
-            or (
-                action == Actions.Long_exit.value
-                and self._position != Positions.Long
-            )
-            or (
-                action == Actions.Short_exit.value
-                and self._position != Positions.Short
-            )
-        )
-
     def get_feature_value(
         self,
         name: str,
@@ -1392,3 +1374,20 @@ def sample_params_qrdqn(trial: Trial) -> Dict[str, Any]:
     n_quantiles = trial.suggest_int("n_quantiles", 5, 200)
     hyperparams["policy_kwargs"].update({"n_quantiles": n_quantiles})
     return hyperparams
+
+
+def _is_valid(action: int, position: float) -> bool:
+    """
+    Determine if the action is valid for the step
+    """
+    # Agent should only try to exit if it is in position
+    if action in (Actions.Short_exit.value, Actions.Long_exit.value):
+        if position not in (Positions.Short, Positions.Long):
+            return False
+
+    # Agent should only try to enter if it is not in position
+    if action in (Actions.Short_enter.value, Actions.Long_enter.value):
+        if position != Positions.Neutral:
+            return False
+
+    return True
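For completeness, a standalone sketch of how the relocated module-level _is_valid feeds the _action_masks helper from the first hunk. The Actions and Positions enums here are minimal stand-ins mirroring the names the patch relies on, not the repository's own definitions; note that the helper's membership tests compare against Positions members, so the sketch passes enum members rather than raw floats:

    from enum import Enum


    class Actions(Enum):
        # Stand-in for the project's action space (hypothetical values)
        Neutral = 0
        Long_enter = 1
        Long_exit = 2
        Short_enter = 3
        Short_exit = 4


    class Positions(Enum):
        # Stand-in for the project's position encoding (hypothetical values)
        Short = 0
        Long = 1
        Neutral = 0.5


    def _is_valid(action: int, position) -> bool:
        # Mirrors the helper added at the end of the patch
        if action in (Actions.Short_exit.value, Actions.Long_exit.value):
            if position not in (Positions.Short, Positions.Long):
                return False
        if action in (Actions.Short_enter.value, Actions.Long_enter.value):
            if position != Positions.Neutral:
                return False
        return True


    def _action_masks(position):
        return [_is_valid(action.value, position) for action in Actions]


    for position in (Positions.Neutral, Positions.Long, Positions.Short):
        # e.g. Positions.Neutral -> [True, True, False, True, False]
        print(position, _action_masks(position))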