From 8fcfb8a7aab381a04cf340633daab82c4236596f Mon Sep 17 00:00:00 2001
From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?=
Date: Thu, 18 Sep 2025 22:09:48 +0200
Subject: [PATCH] feat(reforcexy): subprocvec support
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Run training and evaluation environments either in-process
(DummyVecEnv) or in dedicated worker processes (SubprocVecEnv),
selected via the new "vec_env" rl_config option. Derive the
evaluation frequency from the PPO rollout size through a shared
get_eval_freq() helper, record trade history with the pre-step price
and PnL, align environment history on ticks for plotting, and harden
the Optuna evaluation callback against invalid scores and reporting
failures.
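
Enabling it is a config switch. A minimal sketch of the relevant
rl_config keys (the exact nesting inside the freqai config section is
assumed here, and the env count is only illustrative):

    "rl_config": {
        "n_envs": 4,
        "vec_env": "subproc"
    }

"dummy" remains the default; any other value raises ValueError, and
plot_new_best is force-disabled when "subproc" is selected.
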
Signed-off-by: Jérôme Benoit
---
 ReforceXY/user_data/freqaimodels/ReforceXY.py | 446 ++++++++++++------
 1 file changed, 302 insertions(+), 144 deletions(-)

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 1f74573..f98cade 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -49,7 +49,12 @@ from stable_baselines3.common.callbacks import (
 from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.logger import Figure, HParam
 from stable_baselines3.common.utils import set_random_seed
-from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecMonitor
+from stable_baselines3.common.vec_env import (
+    DummyVecEnv,
+    SubprocVecEnv,
+    VecFrameStack,
+    VecMonitor,
+)
 
 matplotlib.use("Agg")
 warnings.filterwarnings("ignore", category=UserWarning)
@@ -152,6 +157,12 @@ class ReforceXY(BaseReinforcementLearningModel):
         if not isinstance(self.n_envs, int) or self.n_envs < 1:
             logger.warning("Invalid n_envs=%s. Forcing n_envs=1", self.n_envs)
             self.n_envs = 1
+        vec_env = self.rl_config.get("vec_env", "dummy")
+        if vec_env == "subproc" and self.plot_new_best:
+            logger.warning(
+                "plot_new_best is not supported with SubprocVecEnv. Deactivating plot_new_best"
+            )
+            self.plot_new_best = False
         if not isinstance(self.frame_stacking, int) or self.frame_stacking < 0:
             logger.warning(
                 "Invalid frame_stacking=%s. Forcing frame_stacking=0",
@@ -216,34 +227,41 @@ class ReforceXY(BaseReinforcementLearningModel):
             _eval_env_check.close()
 
         logger.info("Populating environments: %s", self.n_envs)
-        train_env = DummyVecEnv(
-            [
-                make_env(
-                    self.MyRLEnv,
-                    f"train_env{i}",
-                    i,
-                    seed,
-                    train_df,
-                    prices_train,
-                    env_info=env_dict,
-                )
-                for i in range(self.n_envs)
-            ]
-        )
-        eval_env = DummyVecEnv(
-            [
-                make_env(
-                    self.MyRLEnv,
-                    f"eval_env{i}",
-                    i,
-                    seed + 10_000,
-                    test_df,
-                    prices_test,
-                    env_info=env_dict,
-                )
-                for i in range(self.n_envs)
-            ]
-        )
+        train_fns = [
+            make_env(
+                self.MyRLEnv,
+                f"train_env{i}",
+                i,
+                seed,
+                train_df,
+                prices_train,
+                env_info=env_dict,
+            )
+            for i in range(self.n_envs)
+        ]
+        eval_fns = [
+            make_env(
+                self.MyRLEnv,
+                f"eval_env{i}",
+                i,
+                seed + 10_000,
+                test_df,
+                prices_test,
+                env_info=env_dict,
+            )
+            for i in range(self.n_envs)
+        ]
+        vec_env = str(self.rl_config.get("vec_env", "dummy"))
+        if vec_env == "dummy":
+            logger.info("Using DummyVecEnv")
+            train_env = DummyVecEnv(train_fns)
+            eval_env = DummyVecEnv(eval_fns)
+        elif vec_env == "subproc":
+            logger.info("Using SubprocVecEnv")
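+            # "spawn" starts each worker from a fresh interpreter: slower than
+            # "fork", but it avoids inheriting thread and lock state from
+            # libraries such as PyTorch in the parent process.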
+            train_env = SubprocVecEnv(train_fns, start_method="spawn")
+            eval_env = SubprocVecEnv(eval_fns, start_method="spawn")
+        else:
+            raise ValueError(f"Invalid vec_env: {vec_env}")
 
         if self.frame_stacking:
             logger.info("Frame stacking: %s", self.frame_stacking)
@@ -346,6 +364,20 @@ class ReforceXY(BaseReinforcementLearningModel):
         self._model_params_cache = model_params
         return copy.deepcopy(self._model_params_cache)
 
+    def get_eval_freq(
+        self, train_timesteps: int, model_params: Optional[Dict[str, Any]] = None
+    ) -> int:
+        """
+        Compute the evaluation frequency, aligned on the PPO rollout size
+        when applicable
+        """
+        if "PPO" in self.model_type:
+            if model_params:
+                n_steps = model_params.get("n_steps")
+                if isinstance(n_steps, int) and n_steps > 0:
+                    return n_steps
+            for s in sorted(PPO_N_STEPS, reverse=True):
+                if s <= train_timesteps:
+                    return s
+            return PPO_N_STEPS[0]
+        return max(1, train_timesteps // max(1, self.n_envs))
+
     def get_callbacks(
         self, eval_freq: int, data_path: str, trial: Optional[Trial] = None
     ) -> list[BaseCallback]:
@@ -388,7 +420,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
             callbacks.append(self.eval_callback)
         else:
-            trial_data_path = f"{data_path}/trial_{trial.number}"
+            trial_data_path = f"{data_path}/hyperopt/trial_{trial.number}"
             self.optuna_callback = MaskableTrialEvalCallback(
                 self.eval_env,
                 trial,
@@ -474,9 +506,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             **model_params,
         )
 
-        callbacks = self.get_callbacks(
-            max(1, train_timesteps // self.n_envs), str(dk.data_path)
-        )
+        eval_freq = self.get_eval_freq(train_timesteps, model_params)
+        callbacks = self.get_callbacks(eval_freq, str(dk.data_path))
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         finally:
@@ -648,8 +679,15 @@ class ReforceXY(BaseReinforcementLearningModel):
             if self.rl_config_optuna.get("per_pair", False)
             else self.get_storage()
         )
-        eval_freq = max(1, len(train_df) // self.n_envs)
-        max_resource = max(1, (total_timesteps + eval_freq - 1) // eval_freq)
+        if "PPO" in self.model_type:
+            resource_eval_freq = min(PPO_N_STEPS)
+        else:
+            resource_eval_freq = self.get_eval_freq(len(train_df))
+        max_resource = max(
+            1,
+            (total_timesteps // self.n_envs + resource_eval_freq - 1)
+            // resource_eval_freq,
+        )
         min_resource = min(3, max_resource)
         study: Study = create_study(
             study_name=study_name,
@@ -681,6 +719,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 ),
                 gc_after_trial=True,
                 show_progress_bar=self.rl_config.get("progress_bar", False),
+                # SB3 is not fully thread-safe
                 n_jobs=1,
             )
         except KeyboardInterrupt:
@@ -805,8 +844,9 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise RuntimeError("Environments not set. Cannot run HPO trial")
         if "PPO" in self.model_type:
             params = sample_params_ppo(trial, self.n_envs)
-            if params.get("n_steps", 0) > total_timesteps:
-                raise TrialPruned("n_steps is greater than total_timesteps")
+            n_steps = params.get("n_steps", 0)
+            if n_steps * self.n_envs > total_timesteps:
+                raise TrialPruned("n_steps * n_envs is greater than total_timesteps")
         elif "QRDQN" in self.model_type:
             params = sample_params_qrdqn(trial)
         elif "DQN" in self.model_type:
@@ -840,9 +880,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             **params,
         )
 
-        callbacks = self.get_callbacks(
-            max(1, len(train_df) // self.n_envs), str(dk.data_path), trial
-        )
+        eval_freq = self.get_eval_freq(len(train_df), params)
+        callbacks = self.get_callbacks(eval_freq, str(dk.data_path), trial)
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         except AssertionError:
@@ -1161,18 +1200,17 @@ class ReforceXY(BaseReinforcementLearningModel):
                 self._last_closed_trade_tick = self._current_tick
                 self._last_trade_tick = None
 
-    def execute_trade(self, action: int) -> None:
+    def execute_trade(self, action: int) -> Optional[str]:
         """
         Execute trade based on the given action
         """
         # Force exit trade
         if self._force_action:
-            self._exit_trade()  # Exit trade due to force action
-            self.append_trade_history(f"{self._force_action.name}")
+            self._exit_trade()
             self.tensorboard_log(
                 f"{self._force_action.name}", category="actions/force"
             )
-            return None
+            return f"{self._force_action.name}"
 
         if not self.is_tradesignal(action):
             return None
@@ -1180,12 +1218,14 @@ class ReforceXY(BaseReinforcementLearningModel):
         # Enter trade based on action
         if action in (Actions.Long_enter.value, Actions.Short_enter.value):
             self._enter_trade(action)
-            self.append_trade_history(f"{self._position.name}_enter")
+            return f"{self._position.name}_enter"
 
         # Exit trade based on action
         if action in (Actions.Long_exit.value, Actions.Short_exit.value):
             self._exit_trade()
-            self.append_trade_history(f"{self._last_closed_position.name}_exit")
+            return f"{self._last_closed_position.name}_exit"
+
+        return None
 
     def step(
         self, action: int
@@ -1195,13 +1235,17 @@ class ReforceXY(BaseReinforcementLearningModel):
         """
         self._current_tick += 1
         self._update_unrealized_total_profit()
-        pnl = self.get_unrealized_profit()
+        previous_pnl = self.get_unrealized_profit()
         self._update_portfolio_log_returns()
         self._force_action = self._get_force_action()
         reward = self.calculate_reward(action)
         self.total_reward += reward
         self.tensorboard_log(Actions._member_names_[action], category="actions")
-        self.execute_trade(action)
+        trade_type = self.execute_trade(action)
+        if trade_type is not None:
+            self.append_trade_history(
+                trade_type, self.current_price(), previous_pnl
+            )
         self._position_history.append(self._position)
         info = {
             "tick": self._current_tick,
@@ -1210,7 +1254,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             "force_action": (
                 self._force_action.name if self._force_action else None
             ),
-            "pnl": round(pnl, 5),
+            "previous_pnl": round(previous_pnl, 5),
+            "pnl": round(self.get_unrealized_profit(), 5),
             "reward": round(reward, 5),
             "total_reward": round(self.total_reward, 5),
             "total_profit": round(self._total_profit, 5),
@@ -1227,13 +1272,15 @@ class ReforceXY(BaseReinforcementLearningModel):
             info,
         )
 
-    def append_trade_history(self, trade_type: str) -> None:
+    def append_trade_history(
+        self, trade_type: str, price: float, profit: float
+    ) -> None:
         self.trade_history.append(
             {
                 "tick": self._current_tick,
                 "type": trade_type.lower(),
-                "price": self.current_price(),
-                "profit": self.get_unrealized_profit(),
+                "price": price,
+                "profit": profit,
             }
         )
 
@@ -1361,31 +1408,55 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def get_env_history(self) -> DataFrame:
         """
-        Get environment data from the first to the last trade
+        Get environment data aligned on ticks, including optional trade events
         """
-        if not self.history or not self.trade_history:
-            logger.warning("history or trade_history is empty")
+        if not self.history:
+            logger.warning("history is empty")
             return DataFrame()
 
         _history_df = DataFrame.from_dict(self.history)
-        _trade_history_df = DataFrame.from_dict(self.trade_history)
-
-        if (
-            "tick" not in _history_df.columns
-            or "tick" not in _trade_history_df.columns
-        ):
-            logger.warning("'tick' column is missing from history or trade_history")
+        if "tick" not in _history_df.columns:
+            logger.warning("'tick' column is missing from history")
             return DataFrame()
 
-        _rollout_history = merge(
-            _history_df, _trade_history_df, on="tick", how="left"
-        ).ffill()
-        _price_history = (
-            self.prices.iloc[_rollout_history.tick].copy().reset_index()
-        )
-        history = merge(
-            _rollout_history, _price_history, left_index=True, right_index=True
-        )
+        if self.trade_history:
+            _trade_history_df = DataFrame.from_dict(self.trade_history)
+            if "tick" in _trade_history_df.columns:
+                _rollout_history = merge(
+                    _history_df, _trade_history_df, on="tick", how="left"
+                )
+            else:
+                _rollout_history = _history_df.copy()
+        else:
+            _rollout_history = _history_df.copy()
+
+        try:
+            history = merge(
+                _rollout_history,
+                self.prices,
+                left_on="tick",
+                right_index=True,
+                how="left",
+            )
+        except Exception:
+            try:
+                _price_history = (
+                    self.prices.iloc[_rollout_history.tick]
+                    .copy()
+                    .reset_index(drop=True)
+                )
+                history = merge(
+                    _rollout_history,
+                    _price_history,
+                    left_index=True,
+                    right_index=True,
+                )
+            except Exception as e:
+                logger.error(
+                    f"Failed to merge history with prices: {repr(e)}",
+                    exc_info=True,
                 )
+                return DataFrame()
         return history
 
     def get_env_plot(self) -> plt.Figure:
@@ -1396,79 +1467,140 @@ class ReforceXY(BaseReinforcementLearningModel):
         def transform_y_offset(ax, offset):
             return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
 
-        def plot_markers(ax, ticks, marker, color, size, offset):
+        def plot_markers(ax, xs, ys, marker, color, size, offset):
             ax.plot(
-                ticks,
+                xs,
+                ys,
                 marker=marker,
                 color=color,
                 markersize=size,
                 fillstyle="full",
                 transform=transform_y_offset(ax, offset),
                 linestyle="none",
+                zorder=3,
             )
 
-        plt.style.use("dark_background")
-        fig, axs = plt.subplots(
-            nrows=5,
-            ncols=1,
-            figsize=(14, 8),
-            height_ratios=[5, 1, 1, 1, 1],
-            sharex=True,
-        )
+        with plt.style.context("dark_background"):
+            fig, axs = plt.subplots(
+                nrows=5,
+                ncols=1,
+                figsize=(14, 8),
+                height_ratios=[5, 1, 1, 1, 1],
+                sharex=True,
+            )
 
-        if len(self.trade_history) == 0:
-            return fig
-
-        history = self.get_env_history()
-        if len(history) == 0:
-            return fig
-
-        history_price = history.get("price")
-        if history_price is None or len(history_price) == 0:
-            return fig
-        history_type = history.get("type")
-        if history_type is None or len(history_type) == 0:
-            return fig
-        history_open = history.get("open")
-        if history_open is None or len(history_open) == 0:
-            return fig
-
-        enter_long_prices = history.loc[history_type == "long_enter", "price"]
-        enter_short_prices = history.loc[history_type == "short_enter", "price"]
-        exit_long_prices = history.loc[history_type == "long_exit", "price"]
-        exit_short_prices = history.loc[history_type == "short_exit", "price"]
-        take_profit_prices = history.loc[history_type == "take_profit", "price"]
-        stop_loss_prices = history.loc[history_type == "stop_loss", "price"]
-        timeout_prices = history.loc[history_type == "timeout", "price"]
-
-        axs[0].plot(history_open, linewidth=1, color="orchid")
-
-        plot_markers(axs[0], enter_long_prices, "^", "forestgreen", 5, -0.1)
-        plot_markers(axs[0], enter_short_prices, "v", "firebrick", 5, 0.1)
-        plot_markers(axs[0], exit_long_prices, ".", "cornflowerblue", 4, 0.1)
-        plot_markers(axs[0], exit_short_prices, ".", "thistle", 4, -0.1)
-        plot_markers(axs[0], take_profit_prices, "*", "lime", 8, 0.1)
-        plot_markers(axs[0], stop_loss_prices, "x", "red", 8, -0.1)
-        plot_markers(axs[0], timeout_prices, "1", "yellow", 8, 0.0)
-
-        axs[1].set_ylabel("pnl")
-        axs[1].plot(history.get("pnl"), linewidth=1, color="gray")
-        axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
-        axs[1].axhline(y=self.take_profit, label="tp", alpha=0.33, color="green")
-        axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
-
-        axs[2].set_ylabel("reward")
-        axs[2].plot(history.get("reward"), linewidth=1, color="gray")
-        axs[2].axhline(y=0, label="0", alpha=0.33)
-
-        axs[3].set_ylabel("total_profit")
-        axs[3].plot(history.get("total_profit"), linewidth=1, color="gray")
-        axs[3].axhline(y=1, label="0", alpha=0.33)
-
-        axs[4].set_ylabel("total_reward")
-        axs[4].plot(history.get("total_reward"), linewidth=1, color="gray")
-        axs[4].axhline(y=0, label="0", alpha=0.33)
-        axs[4].set_xlabel("tick")
+            history = self.get_env_history()
+            if len(history) == 0:
+                return fig
+
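+            # Cap the plot at the most recent plot_window ticks (values <= 0
+            # disable the cap) so long rollouts stay readable.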
+            plot_window = int(self.rl_config.get("plot_window", 2000))
+            if plot_window > 0 and len(history) > plot_window:
+                history = history.iloc[-plot_window:]
+
+            ticks = history.get("tick")
+            history_open = history.get("open")
+            if (
+                ticks is None
+                or len(ticks) == 0
+                or history_open is None
+                or len(history_open) == 0
+            ):
+                return fig
+
+            axs[0].plot(ticks, history_open, linewidth=1, color="orchid", zorder=1)
+
+            history_type = history.get("type")
+            history_price = history.get("price")
+            if history_type is not None and history_price is not None:
+                enter_long_mask = history_type == "long_enter"
+                enter_short_mask = history_type == "short_enter"
+                exit_long_mask = history_type == "long_exit"
+                exit_short_mask = history_type == "short_exit"
+                take_profit_mask = history_type == "take_profit"
+                stop_loss_mask = history_type == "stop_loss"
+                timeout_mask = history_type == "timeout"
+
+                enter_long_x = ticks[enter_long_mask]
+                enter_short_x = ticks[enter_short_mask]
+                exit_long_x = ticks[exit_long_mask]
+                exit_short_x = ticks[exit_short_mask]
+                take_profit_x = ticks[take_profit_mask]
+                stop_loss_x = ticks[stop_loss_mask]
+                timeout_x = ticks[timeout_mask]
+
+                enter_long_y = history.loc[enter_long_mask, "price"]
+                enter_short_y = history.loc[enter_short_mask, "price"]
+                exit_long_y = history.loc[exit_long_mask, "price"]
+                exit_short_y = history.loc[exit_short_mask, "price"]
+                take_profit_y = history.loc[take_profit_mask, "price"]
+                stop_loss_y = history.loc[stop_loss_mask, "price"]
+                timeout_y = history.loc[timeout_mask, "price"]
+
+                plot_markers(
+                    axs[0], enter_long_x, enter_long_y, "^", "forestgreen", 5, -0.1
+                )
+                plot_markers(
+                    axs[0], enter_short_x, enter_short_y, "v", "firebrick", 5, 0.1
+                )
+                plot_markers(
+                    axs[0], exit_long_x, exit_long_y, ".", "cornflowerblue", 4, 0.1
+                )
+                plot_markers(
+                    axs[0], exit_short_x, exit_short_y, ".", "thistle", 4, -0.1
+                )
+                plot_markers(
+                    axs[0], take_profit_x, take_profit_y, "*", "lime", 8, 0.1
+                )
+                plot_markers(axs[0], stop_loss_x, stop_loss_y, "x", "red", 8, -0.1)
+                plot_markers(axs[0], timeout_x, timeout_y, "1", "yellow", 8, 0.0)
+
+            axs[1].set_ylabel("pnl")
+            pnl_series = history.get("pnl")
+            if pnl_series is not None and len(pnl_series) > 0:
+                axs[1].plot(
+                    ticks,
+                    pnl_series,
+                    linewidth=1,
+                    color="gray",
+                    label="pnl",
+                )
+            previous_pnl_series = history.get("previous_pnl")
+            if previous_pnl_series is not None and len(previous_pnl_series) > 0:
+                axs[1].plot(
+                    ticks,
+                    previous_pnl_series,
+                    linewidth=1,
+                    color="deepskyblue",
+                    label="previous_pnl",
+                )
+            if (pnl_series is not None and len(pnl_series) > 0) or (
+                previous_pnl_series is not None and len(previous_pnl_series) > 0
+            ):
+                axs[1].legend(loc="upper left", fontsize=8)
+            axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
+            axs[1].axhline(
+                y=self.take_profit, label="tp", alpha=0.33, color="green"
+            )
+            axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
+
+            axs[2].set_ylabel("reward")
+            reward_series = history.get("reward")
+            if reward_series is not None and len(reward_series) > 0:
+                axs[2].plot(ticks, reward_series, linewidth=1, color="gray")
+            axs[2].axhline(y=0, label="0", alpha=0.33)
+
+            axs[3].set_ylabel("total_profit")
+            total_profit_series = history.get("total_profit")
+            if total_profit_series is not None and len(total_profit_series) > 0:
+                axs[3].plot(ticks, total_profit_series, linewidth=1, color="gray")
+            axs[3].axhline(y=1, label="1", alpha=0.33)
+
+            axs[4].set_ylabel("total_reward")
+            total_reward_series = history.get("total_reward")
+            if total_reward_series is not None and len(total_reward_series) > 0:
+                axs[4].plot(ticks, total_reward_series, linewidth=1, color="gray")
+            axs[4].axhline(y=0, label="0", alpha=0.33)
+            axs[4].set_xlabel("tick")
 
         _borders = ["top", "right", "bottom", "left"]
         for _ax in axs:
@@ -1836,18 +1968,41 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
         if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
             self.eval_idx += 1
             last_mean_reward = getattr(self, "last_mean_reward", np.nan)
-            if not isinstance(last_mean_reward, (int, float)) or not np.isfinite(
-                float(last_mean_reward)
-            ):
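+            # Coerce defensively: last_mean_reward may be missing (NaN default
+            # above) or non-numeric; treat either case as a failed evaluation.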
+            try:
+                last_mean_reward = float(last_mean_reward)
+            except Exception as e:
+                logger.warning(
+                    "Optuna: invalid last_mean_reward at eval %s: %r", self.eval_idx, e
+                )
                 self.is_pruned = True
                 return False
-            if hasattr(self.trial, "report"):
-                try:
-                    self.trial.report(last_mean_reward, self.eval_idx)
-                except Exception:
+            if not np.isfinite(last_mean_reward):
+                logger.warning(
+                    "Optuna: non-finite last_mean_reward at eval %s", self.eval_idx
+                )
+                self.is_pruned = True
+                return False
+            try:
+                self.trial.report(last_mean_reward, self.eval_idx)
+            except Exception as e:
+                logger.warning(
+                    "Optuna: trial.report failed at eval %s: %r", self.eval_idx, e
+                )
+                self.is_pruned = True
+                return False
+            try:
+                if self.trial.should_prune():
+                    logger.info(
+                        "Optuna: trial pruned at eval %s (score=%.5f)",
+                        self.eval_idx,
+                        last_mean_reward,
+                    )
                     self.is_pruned = True
                     return False
-            if hasattr(self.trial, "should_prune") and self.trial.should_prune():
+            except Exception as e:
+                logger.warning(
+                    "Optuna: should_prune failed at eval %s: %r", self.eval_idx, e
+                )
                 self.is_pruned = True
                 return False
 
@@ -1900,7 +2055,7 @@ def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     return dst_copy
 
 
-@lru_cache(maxsize=64)
+@lru_cache(maxsize=128)
 def linear_schedule(initial_value: float) -> Callable[[float], float]:
     def func(progress_remaining: float) -> float:
         return progress_remaining * initial_value
@@ -2084,11 +2239,14 @@ def convert_optuna_params_to_model_params(
     return model_params
 
 
+PPO_N_STEPS: tuple[int, ...] = (512, 1024, 2048, 4096)
+
+
 def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
     """
     Sampler for PPO hyperparams
     """
-    n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
+    n_steps = trial.suggest_categorical("n_steps", list(PPO_N_STEPS))
     batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
     if batch_size > n_steps:
         raise TrialPruned("batch_size must be less than or equal to n_steps")
-- 
2.43.0