]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): proper action masking implementation
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 19 Sep 2025 13:48:00 +0000 (15:48 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 19 Sep 2025 13:48:00 +0000 (15:48 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index e0e486c60b37f226e89cabeef32dccc5f7ac1572..d08cd1b03981c5a5d0833ea384c19a5a52e187ee 100644 (file)
@@ -85,7 +85,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                 ...
                 "max_trade_duration_candles": 96,   // Timeout exit value used with force_actions
                 "force_actions": false,             // Utilize minimal_roi, stoploss, and max_trade_duration_candles as TP/SL/Timeout in the environment
-                "n_envs": 1,                        // Number of DummyVecEnv environments
+                "n_envs": 1,                        // Number of DummyVecEnv or SubProcVecEnv environments
+                "multiprocessing": false,           // Use SubprocVecEnv if n_envs>1 (otherwise DummyVecEnv)
                 "frame_stacking": 0,                // Number of VecFrameStack stacks (set > 1 to use)
                 "lr_schedule": false,               // Enable learning rate linear schedule
                 "cr_schedule": false,               // Enable clip range linear schedule
@@ -117,12 +118,15 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise ValueError(
                 "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
             )
-        self.is_maskable: bool = (
+        self.action_masking: bool = (
             self.model_type == "MaskablePPO"
         )  # Enable action masking
+        self.rl_config["action_masking"] = self.action_masking
+        self.inference_masking: bool = self.rl_config.get("inference_masking", True)
         self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
         self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
         self.n_envs: int = self.rl_config.get("n_envs", 1)
+        self.multiprocessing: bool = self.rl_config.get("multiprocessing", False)
         self.frame_stacking: int = self.rl_config.get("frame_stacking", 0)
         self.max_no_improvement_evals: int = self.rl_config.get(
             "max_no_improvement_evals", 0
@@ -149,6 +153,42 @@ class ReforceXY(BaseReinforcementLearningModel):
         self._model_params_cache: Optional[Dict[str, Any]] = None
         self.unset_unsupported()
 
+    @staticmethod
+    def build_action_mask(
+        position: Positions, force_action: Optional[ForceActions] = None
+    ) -> np.ndarray:
+        action_mask = np.zeros(len(Actions), dtype=bool)
+
+        action_mask[Actions.Neutral.value] = True
+
+        if position == Positions.Neutral:
+            action_mask[Actions.Long_enter.value] = True
+            action_mask[Actions.Short_enter.value] = True
+        elif position == Positions.Long:
+            action_mask[Actions.Long_exit.value] = True
+        elif position == Positions.Short:
+            action_mask[Actions.Short_exit.value] = True
+
+        if force_action is not None and position in (Positions.Long, Positions.Short):
+            forced = np.zeros(len(Actions), dtype=bool)
+            try:
+                if position == Positions.Long:
+                    forced[Actions.Long_exit.value] = True
+                elif position == Positions.Short:
+                    forced[Actions.Short_exit.value] = True
+            except Exception:
+                return action_mask
+            if forced.any():
+                return forced
+            return action_mask
+
+        if not action_mask.any():
+            try:
+                action_mask[Actions.Neutral.value] = True
+            except Exception:
+                action_mask = np.ones_like(action_mask, dtype=bool)
+        return action_mask
+
     def unset_unsupported(self) -> None:
         """
         If user has activated any features that may conflict, this
@@ -157,10 +197,14 @@ class ReforceXY(BaseReinforcementLearningModel):
         if not isinstance(self.n_envs, int) or self.n_envs < 1:
             logger.warning("Invalid n_envs=%s. Forcing n_envs=1", self.n_envs)
             self.n_envs = 1
-        vec_env = self.rl_config.get("vec_env", "dummy")
-        if vec_env == "subproc" and self.plot_new_best:
+        if self.multiprocessing and self.n_envs <= 1:
             logger.warning(
-                "User tried to use plot_new_best with SubprocVecEnv. Deactivating plot_new_best"
+                "User tried to use multiprocessing with n_envs=1. Deactivating multiprocessing"
+            )
+            self.multiprocessing = False
+        if self.multiprocessing and self.plot_new_best:
+            logger.warning(
+                "User tried to use plot_new_best with multiprocessing. Deactivating plot_new_best"
             )
             self.plot_new_best = False
         if not isinstance(self.frame_stacking, int) or self.frame_stacking < 0:
@@ -171,9 +215,10 @@ class ReforceXY(BaseReinforcementLearningModel):
             self.frame_stacking = 0
         if self.frame_stacking == 1:
             logger.warning(
-                "Setting frame_stacking=%s is equivalent to no stacking; use >=2 or 0",
+                "Setting frame_stacking=%s is equivalent to no stacking. Forcing frame_stacking=0",
                 self.frame_stacking,
             )
+            self.frame_stacking = 0
         if self.continual_learning and self.frame_stacking:
             logger.warning(
                 "User tried to use continual_learning with frame_stacking. \
@@ -251,25 +296,25 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
             for i in range(self.n_envs)
         ]
-        vec_env = str(self.rl_config.get("vec_env", "dummy"))
-        if vec_env == "dummy":
-            logger.info("Using DummyVecEnv")
-            train_env = DummyVecEnv(train_fns)
-            eval_env = DummyVecEnv(eval_fns)
-        elif vec_env == "subproc":
-            logger.info("Using SubprocVecEnv")
+
+        logger.info("Multiprocessing: %s", self.multiprocessing)
+        if self.multiprocessing and self.n_envs > 1:
             train_env = SubprocVecEnv(train_fns, start_method="spawn")
             eval_env = SubprocVecEnv(eval_fns, start_method="spawn")
         else:
-            raise ValueError(f"Invalid vec_env: {vec_env}")
+            train_env = DummyVecEnv(train_fns)
+            eval_env = DummyVecEnv(eval_fns)
+
+        train_env = VecMonitor(train_env)
+        eval_env = VecMonitor(eval_env)
 
         if self.frame_stacking:
             logger.info("Frame stacking: %s", self.frame_stacking)
             train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
             eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)
 
-        self.train_env = VecMonitor(train_env)
-        self.eval_env = VecMonitor(eval_env)
+        self.train_env = train_env
+        self.eval_env = eval_env
 
     def get_model_params(self) -> Dict[str, Any]:
         """
@@ -372,9 +417,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 n_steps = model_params.get("n_steps")
                 if isinstance(n_steps, int) and n_steps > 0:
                     return n_steps
-            for s in sorted(PPO_N_STEPS, reverse=True):
-                if s <= train_timesteps:
-                    return s
+            for step in sorted(PPO_N_STEPS, reverse=True):
+                if step <= train_timesteps:
+                    return step
             return PPO_N_STEPS[0]
         return max(1, train_timesteps // max(1, self.n_envs))
 
@@ -413,7 +458,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 deterministic=True,
                 render=False,
                 best_model_save_path=data_path,
-                use_masking=self.is_maskable,
+                use_masking=self.action_masking,
                 callback_on_new_best=rollout_plot_callback,
                 callback_after_eval=no_improvement_callback,
                 verbose=verbose,
@@ -428,7 +473,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 deterministic=True,
                 render=False,
                 best_model_save_path=trial_data_path,
-                use_masking=self.is_maskable,
+                use_masking=self.action_masking,
                 verbose=verbose,
             )
             callbacks.append(self.optuna_callback)
@@ -469,7 +514,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             total_days,
         )
         logger.info("Test: %s steps (%s days)", test_timesteps, test_days)
-        logger.info("Action masking: %s", self.is_maskable)
+        logger.info("Action masking: %s", self.action_masking)
         logger.info("Hyperopt: %s", self.hyperopt)
 
         start_time = time.time()
@@ -551,10 +596,10 @@ class ReforceXY(BaseReinforcementLearningModel):
             if isinstance(position, Positions):
                 return position
             try:
-                f = float(position)
-                if f == float(Positions.Long.value):
+                position = float(position)
+                if position == float(Positions.Long.value):
                     return Positions.Long
-                if f == float(Positions.Short.value):
+                if position == float(Positions.Short.value):
                     return Positions.Short
                 return Positions.Neutral
             except Exception:
@@ -577,8 +622,18 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             return True
 
-        def _action_masks(position: Any) -> list[bool]:
-            return [_is_valid(action.value, position) for action in Actions]
+        simulated_position: Positions = Positions.Neutral
+
+        def _update_simulated_position(action: int, position: Positions) -> Positions:
+            if action == Actions.Long_enter.value and position == Positions.Neutral:
+                return Positions.Long
+            if action == Actions.Short_enter.value and position == Positions.Neutral:
+                return Positions.Short
+            if action == Actions.Long_exit.value and position == Positions.Long:
+                return Positions.Neutral
+            if action == Actions.Short_exit.value and position == Positions.Short:
+                return Positions.Neutral
+            return position
 
         frame_buffer: list[np.ndarray] = []
 
@@ -599,9 +654,6 @@ class ReforceXY(BaseReinforcementLearningModel):
                 observation["position"] = position
                 observation["trade_duration"] = trade_duration
 
-                if self.is_maskable:
-                    action_masks_param = {"action_masks": _action_masks(position)}
-
             np_observation = observation.to_numpy(dtype=np.float32)
 
             fb: list[np.ndarray] = frame_buffer
@@ -621,13 +673,28 @@ class ReforceXY(BaseReinforcementLearningModel):
             else:
                 observations = np_observation.reshape(1, -1)
 
+            if self.action_masking and self.inference_masking:
+                action_masks_param["action_masks"] = ReforceXY.build_action_mask(
+                    simulated_position, None
+                )
+
             action, _ = model.predict(
                 observations, deterministic=True, **action_masks_param
             )
             return int(action)
 
-        actions = dataframe.iloc[:, 0].rolling(window=self.CONV_WIDTH).apply(_predict)
-        return DataFrame({label: actions for label in dk.label_list})
+        predicted_actions: list[int] = []
+        for window_end in range(self.CONV_WIDTH, len(dataframe) + 1):
+            window = dataframe.iloc[window_end - self.CONV_WIDTH : window_end]
+            action = _predict(window)
+            predicted_actions.append(action)
+            simulated_position = _update_simulated_position(action, simulated_position)
+
+        pad = [np.nan] * (self.CONV_WIDTH - 1)
+        actions_list = pad + predicted_actions
+        actions = DataFrame({"action": actions_list}, index=dataframe.index)
+
+        return DataFrame({label: actions["action"] for label in dk.label_list})
 
     def get_storage(self, pair: Optional[str] = None) -> BaseStorage:
         """
@@ -844,8 +911,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise RuntimeError("Environments not set. Cannot run HPO trial")
         if "PPO" in self.model_type:
             params = sample_params_ppo(trial, self.n_envs)
-            n_steps = params.get("n_steps", 0)
-            if n_steps * self.n_envs > total_timesteps:
+            if params.get("n_steps", 0) * self.n_envs > total_timesteps:
                 raise TrialPruned("n_steps * n_envs is greater than total_timesteps")
         elif "QRDQN" in self.model_type:
             params = sample_params_qrdqn(trial)
@@ -855,8 +921,10 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise NotImplementedError
 
         if "DQN" in self.model_type:
-            batch_size = params.get("batch_size")
             gradient_steps = params.get("gradient_steps")
+            if isinstance(gradient_steps, int) and gradient_steps <= 0:
+                raise TrialPruned("gradient_steps is negative or zero")
+            batch_size = params.get("batch_size")
             buffer_size = params.get("buffer_size")
             if (batch_size * gradient_steps) > buffer_size:
                 raise TrialPruned(
@@ -898,16 +966,16 @@ class ReforceXY(BaseReinforcementLearningModel):
             nan_encountered = True
         except ValueError as e:
             if any(x in str(e).lower() for x in ("nan", "inf")):
-                logger.warning("Optuna encountered NaN/Inf (ValueError): %s", e)
+                logger.warning("Optuna encountered NaN/Inf (ValueError): %r", e)
                 nan_encountered = True
             else:
                 raise
         except FloatingPointError as e:
-            logger.warning("Optuna encountered NaN/Inf (FloatingPointError): %s", e)
+            logger.warning("Optuna encountered NaN/Inf (FloatingPointError): %r", e)
             nan_encountered = True
         except RuntimeError as e:
             if any(x in str(e).lower() for x in ("nan", "inf")):
-                logger.warning("Optuna encountered NaN/Inf (RuntimeError): %s", e)
+                logger.warning("Optuna encountered NaN/Inf (RuntimeError): %r", e)
                 nan_encountered = True
             else:
                 raise
@@ -946,6 +1014,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         def __init__(self, **kwargs):
             super().__init__(**kwargs)
             self._set_observation_space()
+            self.action_masking: bool = self.rl_config.get("action_masking", False)
             self.force_actions: bool = self.rl_config.get("force_actions", False)
             self._force_action: Optional[ForceActions] = None
             self.take_profit: float = self.config.get("minimal_roi", {}).get("0", 0.03)
@@ -1046,7 +1115,11 @@ class ReforceXY(BaseReinforcementLearningModel):
                     of weights in NN)
             """
             # first, penalize if the action is not valid
-            if not self._force_action and not self._is_valid(action):
+            if (
+                not self.action_masking
+                and not self._force_action
+                and not self._is_valid(action)
+            ):
                 self.tensorboard_log("invalid", category="actions")
                 return self.rl_config.get("model_reward_parameters", {}).get(
                     "invalid_action", -2.0
@@ -1321,6 +1394,17 @@ class ReforceXY(BaseReinforcementLearningModel):
                 )
             )
 
+        def action_masks(self):
+            try:
+                return ReforceXY.build_action_mask(self._position, self._force_action)
+            except Exception:
+                action_mask = np.zeros(len(Actions), dtype=bool)
+                try:
+                    action_mask[Actions.Neutral.value] = True
+                except Exception:
+                    action_mask = np.ones(len(Actions), dtype=bool)
+                return action_mask
+
         def get_feature_value(
             self,
             name: str,
@@ -1360,24 +1444,22 @@ class ReforceXY(BaseReinforcementLearningModel):
             Calculate the most recent maximum unrealized profit if in a trade
             """
             if self._last_trade_tick is None:
-                return -np.inf
+                return 0.0
             if self._position == Positions.Neutral:
-                return -np.inf
+                return 0.0
             pnl_history = self.history.get("pnl")
             if not pnl_history or len(pnl_history) == 0:
-                return -np.inf
+                return 0.0
 
             pnl_history = np.asarray(pnl_history)
             ticks = self.history.get("tick")
             if not ticks:
-                return -np.inf
+                return 0.0
             ticks = np.asarray(ticks)
             start = np.searchsorted(ticks, self._last_trade_tick, side="left")
-            if start >= ticks.shape[0]:
-                return -np.inf
             trade_pnl_history = pnl_history[start:]
             if trade_pnl_history.size == 0:
-                return -np.inf
+                return 0.0
             return np.max(trade_pnl_history)
 
         def get_most_recent_return(self) -> float:
@@ -1912,11 +1994,13 @@ class InfoMetricsCallback(TensorboardCallback):
             if isinstance(metrics, dict):
                 for metric, value in metrics.items():
                     try:
-                        self.logger.record(f"{category}/{metric}", value)
+                        self.logger.record(f"{category}/{metric}_sum", value)
                     except Exception:
                         try:
                             self.logger.record(
-                                f"{category}/{metric}", value, exclude=("tensorboard",)
+                                f"{category}/{metric}_sum",
+                                value,
+                                exclude=("tensorboard",),
                             )
                         except Exception:
                             pass
@@ -2294,9 +2378,9 @@ def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
             ),
             "learning_rate": trial.suggest_float("learning_rate", 1e-5, 3e-3, log=True),
             "ent_coef": trial.suggest_float("ent_coef", 0.0005, 0.03, log=True),
-            "clip_range": trial.suggest_categorical("clip_range", [0.1, 0.2, 0.3]),
+            "clip_range": trial.suggest_float("clip_range", 0.1, 0.4, step=0.05),
             "n_epochs": trial.suggest_categorical("n_epochs", [1, 2, 3, 4, 5]),
-            "gae_lambda": trial.suggest_float("gae_lambda", 0.9, 0.98, step=0.01),
+            "gae_lambda": trial.suggest_float("gae_lambda", 0.9, 0.99, step=0.01),
             "max_grad_norm": trial.suggest_float("max_grad_norm", 0.3, 1.0, step=0.05),
             "vf_coef": trial.suggest_float("vf_coef", 0.0, 1.0, step=0.05),
             "lr_schedule": trial.suggest_categorical(
@@ -2346,9 +2430,9 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
         raise TrialPruned("learning_starts is greater than or equal to buffer_size")
     return {
         "train_freq": trial.suggest_categorical(
-            "train_freq", [2, 4, 8, 16, 128, 256, 512, 1024]
+            "train_freq", [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
         ),
-        "subsample_steps": trial.suggest_categorical("subsample_steps", [2, 4, 8]),
+        "subsample_steps": trial.suggest_categorical("subsample_steps", [2, 4, 8, 16]),
         "gamma": trial.suggest_categorical(
             "gamma", [0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.997, 0.999, 0.9999]
         ),