]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
feat(reforcexy): subprocvec support
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 18 Sep 2025 20:09:48 +0000 (22:09 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 18 Sep 2025 20:09:48 +0000 (22:09 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 1f745730645bbe22d43f1aac237abf2b7e9ff152..f98cade43d0e966cb4452b1aaacae4ce73a3ba77 100644 (file)
@@ -49,7 +49,12 @@ from stable_baselines3.common.callbacks import (
 from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.logger import Figure, HParam
 from stable_baselines3.common.utils import set_random_seed
-from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecMonitor
+from stable_baselines3.common.vec_env import (
+    DummyVecEnv,
+    SubprocVecEnv,
+    VecFrameStack,
+    VecMonitor,
+)
 
 matplotlib.use("Agg")
 warnings.filterwarnings("ignore", category=UserWarning)
@@ -152,6 +157,12 @@ class ReforceXY(BaseReinforcementLearningModel):
         if not isinstance(self.n_envs, int) or self.n_envs < 1:
             logger.warning("Invalid n_envs=%s. Forcing n_envs=1", self.n_envs)
             self.n_envs = 1
+        vec_env = self.rl_config.get("vec_env", "dummy")
+        if vec_env == "subproc" and self.plot_new_best:
+            logger.warning(
+                "User tried to use plot_new_best with SubprocVecEnv. Deactivating plot_new_best"
+            )
+            self.plot_new_best = False
         if not isinstance(self.frame_stacking, int) or self.frame_stacking < 0:
             logger.warning(
                 "Invalid frame_stacking=%s. Forcing frame_stacking=0",
@@ -216,34 +227,41 @@ class ReforceXY(BaseReinforcementLearningModel):
                 _eval_env_check.close()
 
         logger.info("Populating environments: %s", self.n_envs)
-        train_env = DummyVecEnv(
-            [
-                make_env(
-                    self.MyRLEnv,
-                    f"train_env{i}",
-                    i,
-                    seed,
-                    train_df,
-                    prices_train,
-                    env_info=env_dict,
-                )
-                for i in range(self.n_envs)
-            ]
-        )
-        eval_env = DummyVecEnv(
-            [
-                make_env(
-                    self.MyRLEnv,
-                    f"eval_env{i}",
-                    i,
-                    seed + 10_000,
-                    test_df,
-                    prices_test,
-                    env_info=env_dict,
-                )
-                for i in range(self.n_envs)
-            ]
-        )
+        train_fns = [
+            make_env(
+                self.MyRLEnv,
+                f"train_env{i}",
+                i,
+                seed,
+                train_df,
+                prices_train,
+                env_info=env_dict,
+            )
+            for i in range(self.n_envs)
+        ]
+        eval_fns = [
+            make_env(
+                self.MyRLEnv,
+                f"eval_env{i}",
+                i,
+                seed + 10_000,
+                test_df,
+                prices_test,
+                env_info=env_dict,
+            )
+            for i in range(self.n_envs)
+        ]
+        vec_env = str(self.rl_config.get("vec_env", "dummy"))
+        if vec_env == "dummy":
+            logger.info("Using DummyVecEnv")
+            train_env = DummyVecEnv(train_fns)
+            eval_env = DummyVecEnv(eval_fns)
+        elif vec_env == "subproc":
+            logger.info("Using SubprocVecEnv")
+            train_env = SubprocVecEnv(train_fns, start_method="spawn")
+            eval_env = SubprocVecEnv(eval_fns, start_method="spawn")
+        else:
+            raise ValueError(f"Invalid vec_env: {vec_env}")
 
         if self.frame_stacking:
             logger.info("Frame stacking: %s", self.frame_stacking)
@@ -346,6 +364,20 @@ class ReforceXY(BaseReinforcementLearningModel):
         self._model_params_cache = model_params
         return copy.deepcopy(self._model_params_cache)
 
+    def get_eval_freq(
+        self, train_timesteps: int, model_params: Optional[Dict[str, Any]] = None
+    ) -> int:
+        if "PPO" in self.model_type:
+            if model_params:
+                n_steps = model_params.get("n_steps")
+                if isinstance(n_steps, int) and n_steps > 0:
+                    return n_steps
+            for s in sorted(PPO_N_STEPS, reverse=True):
+                if s <= train_timesteps:
+                    return s
+            return PPO_N_STEPS[0]
+        return max(1, train_timesteps // max(1, self.n_envs))
+
     def get_callbacks(
         self, eval_freq: int, data_path: str, trial: Optional[Trial] = None
     ) -> list[BaseCallback]:
@@ -388,7 +420,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
             callbacks.append(self.eval_callback)
         else:
-            trial_data_path = f"{data_path}/trial_{trial.number}"
+            trial_data_path = f"{data_path}/hyperopt/trial_{trial.number}"
             self.optuna_callback = MaskableTrialEvalCallback(
                 self.eval_env,
                 trial,
@@ -474,9 +506,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                 **model_params,
             )
 
-        callbacks = self.get_callbacks(
-            max(1, train_timesteps // self.n_envs), str(dk.data_path)
-        )
+        eval_freq = self.get_eval_freq(train_timesteps, model_params)
+        callbacks = self.get_callbacks(eval_freq, str(dk.data_path))
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         finally:
@@ -648,8 +679,15 @@ class ReforceXY(BaseReinforcementLearningModel):
             if self.rl_config_optuna.get("per_pair", False)
             else self.get_storage()
         )
-        eval_freq = max(1, len(train_df) // self.n_envs)
-        max_resource = max(1, (total_timesteps + eval_freq - 1) // eval_freq)
+        if "PPO" in self.model_type:
+            resource_eval_freq = min(PPO_N_STEPS)
+        else:
+            resource_eval_freq = self.get_eval_freq(len(train_df))
+        max_resource = max(
+            1,
+            (total_timesteps // self.n_envs + resource_eval_freq - 1)
+            // resource_eval_freq,
+        )
         min_resource = min(3, max_resource)
         study: Study = create_study(
             study_name=study_name,
@@ -681,6 +719,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 ),
                 gc_after_trial=True,
                 show_progress_bar=self.rl_config.get("progress_bar", False),
+                # SB3 is not fully thread safe
                 n_jobs=1,
             )
         except KeyboardInterrupt:
@@ -805,8 +844,9 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise RuntimeError("Environments not set. Cannot run HPO trial")
         if "PPO" in self.model_type:
             params = sample_params_ppo(trial, self.n_envs)
-            if params.get("n_steps", 0) > total_timesteps:
-                raise TrialPruned("n_steps is greater than total_timesteps")
+            n_steps = params.get("n_steps", 0)
+            if n_steps * self.n_envs > total_timesteps:
+                raise TrialPruned("n_steps * n_envs is greater than total_timesteps")
         elif "QRDQN" in self.model_type:
             params = sample_params_qrdqn(trial)
         elif "DQN" in self.model_type:
@@ -840,9 +880,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             **params,
         )
 
-        callbacks = self.get_callbacks(
-            max(1, len(train_df) // self.n_envs), str(dk.data_path), trial
-        )
+        eval_freq = self.get_eval_freq(len(train_df), params)
+        callbacks = self.get_callbacks(eval_freq, str(dk.data_path), trial)
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         except AssertionError:
@@ -1161,18 +1200,17 @@ class ReforceXY(BaseReinforcementLearningModel):
             self._last_closed_trade_tick = self._current_tick
             self._last_trade_tick = None
 
-        def execute_trade(self, action: int) -> None:
+        def execute_trade(self, action: int) -> Optional[str]:
             """
             Execute trade based on the given action
             """
             # Force exit trade
             if self._force_action:
-                self._exit_trade()  # Exit trade due to force action
-                self.append_trade_history(f"{self._force_action.name}")
+                self._exit_trade()
                 self.tensorboard_log(
                     f"{self._force_action.name}", category="actions/force"
                 )
-                return None
+                return f"{self._force_action.name}"
 
             if not self.is_tradesignal(action):
                 return None
@@ -1180,12 +1218,14 @@ class ReforceXY(BaseReinforcementLearningModel):
             # Enter trade based on action
             if action in (Actions.Long_enter.value, Actions.Short_enter.value):
                 self._enter_trade(action)
-                self.append_trade_history(f"{self._position.name}_enter")
+                return f"{self._position.name}_enter"
 
             # Exit trade based on action
             if action in (Actions.Long_exit.value, Actions.Short_exit.value):
                 self._exit_trade()
-                self.append_trade_history(f"{self._last_closed_position.name}_exit")
+                return f"{self._last_closed_position.name}_exit"
+
+            return None
 
         def step(
             self, action: int
@@ -1195,13 +1235,17 @@ class ReforceXY(BaseReinforcementLearningModel):
             """
             self._current_tick += 1
             self._update_unrealized_total_profit()
-            pnl = self.get_unrealized_profit()
+            previous_pnl = self.get_unrealized_profit()
             self._update_portfolio_log_returns()
             self._force_action = self._get_force_action()
             reward = self.calculate_reward(action)
             self.total_reward += reward
             self.tensorboard_log(Actions._member_names_[action], category="actions")
-            self.execute_trade(action)
+            trade_type = self.execute_trade(action)
+            if trade_type is not None:
+                self.append_trade_history(
+                    trade_type, self.current_price(), previous_pnl
+                )
             self._position_history.append(self._position)
             info = {
                 "tick": self._current_tick,
@@ -1210,7 +1254,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                 "force_action": (
                     self._force_action.name if self._force_action else None
                 ),
-                "pnl": round(pnl, 5),
+                "previous_pnl": round(previous_pnl, 5),
+                "pnl": round(self.get_unrealized_profit(), 5),
                 "reward": round(reward, 5),
                 "total_reward": round(self.total_reward, 5),
                 "total_profit": round(self._total_profit, 5),
@@ -1227,13 +1272,15 @@ class ReforceXY(BaseReinforcementLearningModel):
                 info,
             )
 
-        def append_trade_history(self, trade_type: str) -> None:
+        def append_trade_history(
+            self, trade_type: str, price: float, profit: float
+        ) -> None:
             self.trade_history.append(
                 {
                     "tick": self._current_tick,
                     "type": trade_type.lower(),
-                    "price": self.current_price(),
-                    "profit": self.get_unrealized_profit(),
+                    "price": price,
+                    "profit": profit,
                 }
             )
 
@@ -1361,31 +1408,55 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         def get_env_history(self) -> DataFrame:
             """
-            Get environment data from the first to the last trade
+            Get environment data aligned on ticks, including optional trade events
             """
-            if not self.history or not self.trade_history:
-                logger.warning("history or trade_history is empty")
+            if not self.history:
+                logger.warning("history is empty")
                 return DataFrame()
 
             _history_df = DataFrame.from_dict(self.history)
-            _trade_history_df = DataFrame.from_dict(self.trade_history)
-
-            if (
-                "tick" not in _history_df.columns
-                or "tick" not in _trade_history_df.columns
-            ):
-                logger.warning("'tick' column is missing from history or trade_history")
+            if "tick" not in _history_df.columns:
+                logger.warning("'tick' column is missing from history")
                 return DataFrame()
 
-            _rollout_history = merge(
-                _history_df, _trade_history_df, on="tick", how="left"
-            ).ffill()
-            _price_history = (
-                self.prices.iloc[_rollout_history.tick].copy().reset_index()
-            )
-            history = merge(
-                _rollout_history, _price_history, left_index=True, right_index=True
-            )
+            if self.trade_history:
+                _trade_history_df = DataFrame.from_dict(self.trade_history)
+                if "tick" in _trade_history_df.columns:
+                    _rollout_history = merge(
+                        _history_df, _trade_history_df, on="tick", how="left"
+                    )
+                else:
+                    _rollout_history = _history_df.copy()
+            else:
+                _rollout_history = _history_df.copy()
+
+            try:
+                history = merge(
+                    _rollout_history,
+                    self.prices,
+                    left_on="tick",
+                    right_index=True,
+                    how="left",
+                )
+            except Exception:
+                try:
+                    _price_history = (
+                        self.prices.iloc[_rollout_history.tick]
+                        .copy()
+                        .reset_index(drop=True)
+                    )
+                    history = merge(
+                        _rollout_history,
+                        _price_history,
+                        left_index=True,
+                        right_index=True,
+                    )
+                except Exception as e:
+                    logger.error(
+                        f"Failed to merge history with prices: {repr(e)}",
+                        exc_info=True,
+                    )
+                    return DataFrame()
             return history
 
         def get_env_plot(self) -> plt.Figure:
@@ -1396,79 +1467,140 @@ class ReforceXY(BaseReinforcementLearningModel):
             def transform_y_offset(ax, offset):
                 return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
 
-            def plot_markers(ax, ticks, marker, color, size, offset):
+            def plot_markers(ax, xs, ys, marker, color, size, offset):
                 ax.plot(
-                    ticks,
+                    xs,
+                    ys,
                     marker=marker,
                     color=color,
                     markersize=size,
                     fillstyle="full",
                     transform=transform_y_offset(ax, offset),
                     linestyle="none",
+                    zorder=3,
                 )
 
-            plt.style.use("dark_background")
-            fig, axs = plt.subplots(
-                nrows=5,
-                ncols=1,
-                figsize=(14, 8),
-                height_ratios=[5, 1, 1, 1, 1],
-                sharex=True,
-            )
+            with plt.style.context("dark_background"):
+                fig, axs = plt.subplots(
+                    nrows=5,
+                    ncols=1,
+                    figsize=(14, 8),
+                    height_ratios=[5, 1, 1, 1, 1],
+                    sharex=True,
+                )
 
-            if len(self.trade_history) == 0:
-                return fig
-
-            history = self.get_env_history()
-            if len(history) == 0:
-                return fig
-
-            history_price = history.get("price")
-            if history_price is None or len(history_price) == 0:
-                return fig
-            history_type = history.get("type")
-            if history_type is None or len(history_type) == 0:
-                return fig
-            history_open = history.get("open")
-            if history_open is None or len(history_open) == 0:
-                return fig
-
-            enter_long_prices = history.loc[history_type == "long_enter", "price"]
-            enter_short_prices = history.loc[history_type == "short_enter", "price"]
-            exit_long_prices = history.loc[history_type == "long_exit", "price"]
-            exit_short_prices = history.loc[history_type == "short_exit", "price"]
-            take_profit_prices = history.loc[history_type == "take_profit", "price"]
-            stop_loss_prices = history.loc[history_type == "stop_loss", "price"]
-            timeout_prices = history.loc[history_type == "timeout", "price"]
-
-            axs[0].plot(history_open, linewidth=1, color="orchid")
-
-            plot_markers(axs[0], enter_long_prices, "^", "forestgreen", 5, -0.1)
-            plot_markers(axs[0], enter_short_prices, "v", "firebrick", 5, 0.1)
-            plot_markers(axs[0], exit_long_prices, ".", "cornflowerblue", 4, 0.1)
-            plot_markers(axs[0], exit_short_prices, ".", "thistle", 4, -0.1)
-            plot_markers(axs[0], take_profit_prices, "*", "lime", 8, 0.1)
-            plot_markers(axs[0], stop_loss_prices, "x", "red", 8, -0.1)
-            plot_markers(axs[0], timeout_prices, "1", "yellow", 8, 0.0)
-
-            axs[1].set_ylabel("pnl")
-            axs[1].plot(history.get("pnl"), linewidth=1, color="gray")
-            axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
-            axs[1].axhline(y=self.take_profit, label="tp", alpha=0.33, color="green")
-            axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
-
-            axs[2].set_ylabel("reward")
-            axs[2].plot(history.get("reward"), linewidth=1, color="gray")
-            axs[2].axhline(y=0, label="0", alpha=0.33)
-
-            axs[3].set_ylabel("total_profit")
-            axs[3].plot(history.get("total_profit"), linewidth=1, color="gray")
-            axs[3].axhline(y=1, label="0", alpha=0.33)
-
-            axs[4].set_ylabel("total_reward")
-            axs[4].plot(history.get("total_reward"), linewidth=1, color="gray")
-            axs[4].axhline(y=0, label="0", alpha=0.33)
-            axs[4].set_xlabel("tick")
+                history = self.get_env_history()
+                if len(history) == 0:
+                    return fig
+
+                plot_window = int(self.rl_config.get("plot_window", 2000))
+                if plot_window > 0 and len(history) > plot_window:
+                    history = history.iloc[-plot_window:]
+
+                ticks = history.get("tick")
+                history_open = history.get("open")
+                if (
+                    ticks is None
+                    or len(ticks) == 0
+                    or history_open is None
+                    or len(history_open) == 0
+                ):
+                    return fig
+
+                axs[0].plot(ticks, history_open, linewidth=1, color="orchid", zorder=1)
+
+                history_type = history.get("type")
+                history_price = history.get("price")
+                if history_type is not None and history_price is not None:
+                    enter_long_mask = history_type == "long_enter"
+                    enter_short_mask = history_type == "short_enter"
+                    exit_long_mask = history_type == "long_exit"
+                    exit_short_mask = history_type == "short_exit"
+                    take_profit_mask = history_type == "take_profit"
+                    stop_loss_mask = history_type == "stop_loss"
+                    timeout_mask = history_type == "timeout"
+
+                    enter_long_x = ticks[enter_long_mask]
+                    enter_short_x = ticks[enter_short_mask]
+                    exit_long_x = ticks[exit_long_mask]
+                    exit_short_x = ticks[exit_short_mask]
+                    take_profit_x = ticks[take_profit_mask]
+                    stop_loss_x = ticks[stop_loss_mask]
+                    timeout_x = ticks[timeout_mask]
+
+                    enter_long_y = history.loc[enter_long_mask, "price"]
+                    enter_short_y = history.loc[enter_short_mask, "price"]
+                    exit_long_y = history.loc[exit_long_mask, "price"]
+                    exit_short_y = history.loc[exit_short_mask, "price"]
+                    take_profit_y = history.loc[take_profit_mask, "price"]
+                    stop_loss_y = history.loc[stop_loss_mask, "price"]
+                    timeout_y = history.loc[timeout_mask, "price"]
+
+                    plot_markers(
+                        axs[0], enter_long_x, enter_long_y, "^", "forestgreen", 5, -0.1
+                    )
+                    plot_markers(
+                        axs[0], enter_short_x, enter_short_y, "v", "firebrick", 5, 0.1
+                    )
+                    plot_markers(
+                        axs[0], exit_long_x, exit_long_y, ".", "cornflowerblue", 4, 0.1
+                    )
+                    plot_markers(
+                        axs[0], exit_short_x, exit_short_y, ".", "thistle", 4, -0.1
+                    )
+                    plot_markers(
+                        axs[0], take_profit_x, take_profit_y, "*", "lime", 8, 0.1
+                    )
+                    plot_markers(axs[0], stop_loss_x, stop_loss_y, "x", "red", 8, -0.1)
+                    plot_markers(axs[0], timeout_x, timeout_y, "1", "yellow", 8, 0.0)
+
+                axs[1].set_ylabel("pnl")
+                pnl_series = history.get("pnl")
+                if pnl_series is not None and len(pnl_series) > 0:
+                    axs[1].plot(
+                        ticks,
+                        pnl_series,
+                        linewidth=1,
+                        color="gray",
+                        label="pnl",
+                    )
+                previous_pnl_series = history.get("previous_pnl")
+                if previous_pnl_series is not None and len(previous_pnl_series) > 0:
+                    axs[1].plot(
+                        ticks,
+                        previous_pnl_series,
+                        linewidth=1,
+                        color="deepskyblue",
+                        label="previous_pnl",
+                    )
+                if (pnl_series is not None and len(pnl_series) > 0) or (
+                    previous_pnl_series is not None and len(previous_pnl_series) > 0
+                ):
+                    axs[1].legend(loc="upper left", fontsize=8)
+                axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
+                axs[1].axhline(
+                    y=self.take_profit, label="tp", alpha=0.33, color="green"
+                )
+                axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
+
+                axs[2].set_ylabel("reward")
+                reward_series = history.get("reward")
+                if reward_series is not None and len(reward_series) > 0:
+                    axs[2].plot(ticks, reward_series, linewidth=1, color="gray")
+                axs[2].axhline(y=0, label="0", alpha=0.33)
+
+                axs[3].set_ylabel("total_profit")
+                total_profit_series = history.get("total_profit")
+                if total_profit_series is not None and len(total_profit_series) > 0:
+                    axs[3].plot(ticks, total_profit_series, linewidth=1, color="gray")
+                axs[3].axhline(y=1, label="1", alpha=0.33)
+
+                axs[4].set_ylabel("total_reward")
+                total_reward_series = history.get("total_reward")
+                if total_reward_series is not None and len(total_reward_series) > 0:
+                    axs[4].plot(ticks, total_reward_series, linewidth=1, color="gray")
+                axs[4].axhline(y=0, label="0", alpha=0.33)
+                axs[4].set_xlabel("tick")
 
             _borders = ["top", "right", "bottom", "left"]
             for _ax in axs:
@@ -1836,18 +1968,41 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
         if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
             self.eval_idx += 1
             last_mean_reward = getattr(self, "last_mean_reward", np.nan)
-            if not isinstance(last_mean_reward, (int, float)) or not np.isfinite(
-                float(last_mean_reward)
-            ):
+            try:
+                last_mean_reward = float(last_mean_reward)
+            except Exception as e:
+                logger.warning(
+                    "Optuna: invalid last_mean_reward at eval %s: %r", self.eval_idx, e
+                )
                 self.is_pruned = True
                 return False
-            if hasattr(self.trial, "report"):
-                try:
-                    self.trial.report(last_mean_reward, self.eval_idx)
-                except Exception:
+            if not np.isfinite(last_mean_reward):
+                logger.warning(
+                    "Optuna: non-finite last_mean_reward at eval %s", self.eval_idx
+                )
+                self.is_pruned = True
+                return False
+            try:
+                self.trial.report(last_mean_reward, self.eval_idx)
+            except Exception as e:
+                logger.warning(
+                    "Optuna: trial.report failed at eval %s: %r", self.eval_idx, e
+                )
+                self.is_pruned = True
+                return False
+            try:
+                if self.trial.should_prune():
+                    logger.info(
+                        "Optuna: trial pruned at eval %s (score=%.5f)",
+                        self.eval_idx,
+                        last_mean_reward,
+                    )
                     self.is_pruned = True
                     return False
-            if hasattr(self.trial, "should_prune") and self.trial.should_prune():
+            except Exception as e:
+                logger.warning(
+                    "Optuna: should_prune failed at eval %s: %r", self.eval_idx, e
+                )
                 self.is_pruned = True
                 return False
 
@@ -1900,7 +2055,7 @@ def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     return dst_copy
 
 
-@lru_cache(maxsize=64)
+@lru_cache(maxsize=128)
 def linear_schedule(initial_value: float) -> Callable[[float], float]:
     def func(progress_remaining: float) -> float:
         return progress_remaining * initial_value
@@ -2084,11 +2239,14 @@ def convert_optuna_params_to_model_params(
     return model_params
 
 
+PPO_N_STEPS: tuple[int, ...] = (512, 1024, 2048, 4096)
+
+
 def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
     """
     Sampler for PPO hyperparams
     """
-    n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
+    n_steps = trial.suggest_categorical("n_steps", list(PPO_N_STEPS))
     batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
     if batch_size > n_steps:
         raise TrialPruned("batch_size must be less than or equal to n_steps")