from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.logger import Figure, HParam
from stable_baselines3.common.utils import set_random_seed
-from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecMonitor
+from stable_baselines3.common.vec_env import (
+ DummyVecEnv,
+ SubprocVecEnv,
+ VecFrameStack,
+ VecMonitor,
+)
matplotlib.use("Agg")
warnings.filterwarnings("ignore", category=UserWarning)
if not isinstance(self.n_envs, int) or self.n_envs < 1:
logger.warning("Invalid n_envs=%s. Forcing n_envs=1", self.n_envs)
self.n_envs = 1
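+ # plot_new_best relies on direct, in-process access to the environment,
+ # which SubprocVecEnv workers do not provide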
+ vec_env = str(self.rl_config.get("vec_env", "dummy"))
+ if vec_env == "subproc" and self.plot_new_best:
+ logger.warning(
+ "plot_new_best is not supported with SubprocVecEnv. Deactivating plot_new_best"
+ )
+ self.plot_new_best = False
if not isinstance(self.frame_stacking, int) or self.frame_stacking < 0:
logger.warning(
"Invalid frame_stacking=%s. Forcing frame_stacking=0",
_eval_env_check.close()
logger.info("Populating environments: %s", self.n_envs)
- train_env = DummyVecEnv(
- [
- make_env(
- self.MyRLEnv,
- f"train_env{i}",
- i,
- seed,
- train_df,
- prices_train,
- env_info=env_dict,
- )
- for i in range(self.n_envs)
- ]
- )
- eval_env = DummyVecEnv(
- [
- make_env(
- self.MyRLEnv,
- f"eval_env{i}",
- i,
- seed + 10_000,
- test_df,
- prices_test,
- env_info=env_dict,
- )
- for i in range(self.n_envs)
- ]
- )
+ train_fns = [
+ make_env(
+ self.MyRLEnv,
+ f"train_env{i}",
+ i,
+ seed,
+ train_df,
+ prices_train,
+ env_info=env_dict,
+ )
+ for i in range(self.n_envs)
+ ]
+ eval_fns = [
+ make_env(
+ self.MyRLEnv,
+ f"eval_env{i}",
+ i,
+ seed + 10_000,
+ test_df,
+ prices_test,
+ env_info=env_dict,
+ )
+ for i in range(self.n_envs)
+ ]
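+ # Config example (assumed layout): "rl_config": {"vec_env": "subproc"}
+ # selects SubprocVecEnv; the default "dummy" keeps all envs in-process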
+ vec_env = str(self.rl_config.get("vec_env", "dummy"))
+ if vec_env == "dummy":
+ logger.info("Using DummyVecEnv")
+ train_env = DummyVecEnv(train_fns)
+ eval_env = DummyVecEnv(eval_fns)
+ elif vec_env == "subproc":
+ logger.info("Using SubprocVecEnv")
+ train_env = SubprocVecEnv(train_fns, start_method="spawn")
+ eval_env = SubprocVecEnv(eval_fns, start_method="spawn")
+ else:
+ raise ValueError(f"Invalid vec_env: {vec_env}")
if self.frame_stacking:
logger.info("Frame stacking: %s", self.frame_stacking)
self._model_params_cache = model_params
return copy.deepcopy(self._model_params_cache)
+ def get_eval_freq(
+ self, train_timesteps: int, model_params: Optional[Dict[str, Any]] = None
+ ) -> int:
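+ """Return the evaluation frequency for callbacks, in per-environment steps."""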
+ if "PPO" in self.model_type:
+ if model_params:
+ n_steps = model_params.get("n_steps")
+ if isinstance(n_steps, int) and n_steps > 0:
+ return n_steps
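+ # Otherwise pick the largest preset rollout length that still fits within train_timesteps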
+ for s in sorted(PPO_N_STEPS, reverse=True):
+ if s <= train_timesteps:
+ return s
+ return PPO_N_STEPS[0]
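+ # Non-PPO models: evaluate about once per collective pass over the training data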
+ return max(1, train_timesteps // max(1, self.n_envs))
+
def get_callbacks(
self, eval_freq: int, data_path: str, trial: Optional[Trial] = None
) -> list[BaseCallback]:
)
callbacks.append(self.eval_callback)
else:
- trial_data_path = f"{data_path}/trial_{trial.number}"
+ trial_data_path = f"{data_path}/hyperopt/trial_{trial.number}"
self.optuna_callback = MaskableTrialEvalCallback(
self.eval_env,
trial,
**model_params,
)
- callbacks = self.get_callbacks(
- max(1, train_timesteps // self.n_envs), str(dk.data_path)
- )
+ eval_freq = self.get_eval_freq(train_timesteps, model_params)
+ callbacks = self.get_callbacks(eval_freq, str(dk.data_path))
try:
model.learn(total_timesteps=total_timesteps, callback=callbacks)
finally:
if self.rl_config_optuna.get("per_pair", False)
else self.get_storage()
)
- eval_freq = max(1, len(train_df) // self.n_envs)
- max_resource = max(1, (total_timesteps + eval_freq - 1) // eval_freq)
+ if "PPO" in self.model_type:
+ resource_eval_freq = min(PPO_N_STEPS)
+ else:
+ resource_eval_freq = self.get_eval_freq(len(train_df))
+ max_resource = max(
+ 1,
+ (total_timesteps // self.n_envs + resource_eval_freq - 1)
+ // resource_eval_freq,
+ )
min_resource = min(3, max_resource)
study: Study = create_study(
study_name=study_name,
),
gc_after_trial=True,
show_progress_bar=self.rl_config.get("progress_bar", False),
+ # SB3 is not fully thread safe, so run trials sequentially
n_jobs=1,
)
except KeyboardInterrupt:
raise RuntimeError("Environments not set. Cannot run HPO trial")
if "PPO" in self.model_type:
params = sample_params_ppo(trial, self.n_envs)
- if params.get("n_steps", 0) > total_timesteps:
- raise TrialPruned("n_steps is greater than total_timesteps")
+ n_steps = params.get("n_steps", 0)
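+ # One PPO rollout collects n_steps * n_envs timesteps; prune configs that
+ # could never complete a single rollout within the training budget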
+ if n_steps * self.n_envs > total_timesteps:
+ raise TrialPruned("n_steps * n_envs is greater than total_timesteps")
elif "QRDQN" in self.model_type:
params = sample_params_qrdqn(trial)
elif "DQN" in self.model_type:
**params,
)
- callbacks = self.get_callbacks(
- max(1, len(train_df) // self.n_envs), str(dk.data_path), trial
- )
+ eval_freq = self.get_eval_freq(len(train_df), params)
+ callbacks = self.get_callbacks(eval_freq, str(dk.data_path), trial)
try:
model.learn(total_timesteps=total_timesteps, callback=callbacks)
except AssertionError:
self._last_closed_trade_tick = self._current_tick
self._last_trade_tick = None
- def execute_trade(self, action: int) -> None:
+ def execute_trade(self, action: int) -> Optional[str]:
"""
Execute trade based on the given action
"""
# Force exit trade
if self._force_action:
- self._exit_trade() # Exit trade due to force action
- self.append_trade_history(f"{self._force_action.name}")
+ self._exit_trade()
self.tensorboard_log(
f"{self._force_action.name}", category="actions/force"
)
- return None
+ return f"{self._force_action.name}"
if not self.is_tradesignal(action):
return None
# Enter trade based on action
if action in (Actions.Long_enter.value, Actions.Short_enter.value):
self._enter_trade(action)
- self.append_trade_history(f"{self._position.name}_enter")
+ return f"{self._position.name}_enter"
# Exit trade based on action
if action in (Actions.Long_exit.value, Actions.Short_exit.value):
self._exit_trade()
- self.append_trade_history(f"{self._last_closed_position.name}_exit")
+ return f"{self._last_closed_position.name}_exit"
+
+ return None
def step(
self, action: int
"""
self._current_tick += 1
self._update_unrealized_total_profit()
- pnl = self.get_unrealized_profit()
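+ # Snapshot PnL before executing the action so trade events record the pre-action profit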
+ previous_pnl = self.get_unrealized_profit()
self._update_portfolio_log_returns()
self._force_action = self._get_force_action()
reward = self.calculate_reward(action)
self.total_reward += reward
self.tensorboard_log(Actions._member_names_[action], category="actions")
- self.execute_trade(action)
+ trade_type = self.execute_trade(action)
+ if trade_type is not None:
+ self.append_trade_history(
+ trade_type, self.current_price(), previous_pnl
+ )
self._position_history.append(self._position)
info = {
"tick": self._current_tick,
"force_action": (
self._force_action.name if self._force_action else None
),
- "pnl": round(pnl, 5),
+ "previous_pnl": round(previous_pnl, 5),
+ "pnl": round(self.get_unrealized_profit(), 5),
"reward": round(reward, 5),
"total_reward": round(self.total_reward, 5),
"total_profit": round(self._total_profit, 5),
info,
)
- def append_trade_history(self, trade_type: str) -> None:
+ def append_trade_history(
+ self, trade_type: str, price: float, profit: float
+ ) -> None:
self.trade_history.append(
{
"tick": self._current_tick,
"type": trade_type.lower(),
- "price": self.current_price(),
- "profit": self.get_unrealized_profit(),
+ "price": price,
+ "profit": profit,
}
)
def get_env_history(self) -> DataFrame:
"""
- Get environment data from the first to the last trade
+ Get environment data aligned on ticks, including optional trade events
"""
- if not self.history or not self.trade_history:
- logger.warning("history or trade_history is empty")
+ if not self.history:
+ logger.warning("history is empty")
return DataFrame()
_history_df = DataFrame.from_dict(self.history)
- _trade_history_df = DataFrame.from_dict(self.trade_history)
-
- if (
- "tick" not in _history_df.columns
- or "tick" not in _trade_history_df.columns
- ):
- logger.warning("'tick' column is missing from history or trade_history")
+ if "tick" not in _history_df.columns:
+ logger.warning("'tick' column is missing from history")
return DataFrame()
- _rollout_history = merge(
- _history_df, _trade_history_df, on="tick", how="left"
- ).ffill()
- _price_history = (
- self.prices.iloc[_rollout_history.tick].copy().reset_index()
- )
- history = merge(
- _rollout_history, _price_history, left_index=True, right_index=True
- )
+ if self.trade_history:
+ _trade_history_df = DataFrame.from_dict(self.trade_history)
+ if "tick" in _trade_history_df.columns:
+ _rollout_history = merge(
+ _history_df, _trade_history_df, on="tick", how="left"
+ )
+ else:
+ _rollout_history = _history_df.copy()
+ else:
+ _rollout_history = _history_df.copy()
+
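+ # Prefer an index-based merge on "tick" (assumes self.prices is tick-indexed);
+ # fall back to positional lookup via iloc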
+ try:
+ history = merge(
+ _rollout_history,
+ self.prices,
+ left_on="tick",
+ right_index=True,
+ how="left",
+ )
+ except Exception:
+ try:
+ _price_history = (
+ self.prices.iloc[_rollout_history.tick]
+ .copy()
+ .reset_index(drop=True)
+ )
+ history = merge(
+ _rollout_history,
+ _price_history,
+ left_index=True,
+ right_index=True,
+ )
+ except Exception as e:
+ logger.error(
+ "Failed to merge history with prices: %r", e, exc_info=True
+ )
+ return DataFrame()
return history
def get_env_plot(self) -> plt.Figure:
def transform_y_offset(ax, offset):
return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
- def plot_markers(ax, ticks, marker, color, size, offset):
+ def plot_markers(ax, xs, ys, marker, color, size, offset):
ax.plot(
- ticks,
+ xs,
+ ys,
marker=marker,
color=color,
markersize=size,
fillstyle="full",
transform=transform_y_offset(ax, offset),
linestyle="none",
+ zorder=3,
)
- plt.style.use("dark_background")
- fig, axs = plt.subplots(
- nrows=5,
- ncols=1,
- figsize=(14, 8),
- height_ratios=[5, 1, 1, 1, 1],
- sharex=True,
- )
+ with plt.style.context("dark_background"):
+ fig, axs = plt.subplots(
+ nrows=5,
+ ncols=1,
+ figsize=(14, 8),
+ height_ratios=[5, 1, 1, 1, 1],
+ sharex=True,
+ )
- if len(self.trade_history) == 0:
- return fig
-
- history = self.get_env_history()
- if len(history) == 0:
- return fig
-
- history_price = history.get("price")
- if history_price is None or len(history_price) == 0:
- return fig
- history_type = history.get("type")
- if history_type is None or len(history_type) == 0:
- return fig
- history_open = history.get("open")
- if history_open is None or len(history_open) == 0:
- return fig
-
- enter_long_prices = history.loc[history_type == "long_enter", "price"]
- enter_short_prices = history.loc[history_type == "short_enter", "price"]
- exit_long_prices = history.loc[history_type == "long_exit", "price"]
- exit_short_prices = history.loc[history_type == "short_exit", "price"]
- take_profit_prices = history.loc[history_type == "take_profit", "price"]
- stop_loss_prices = history.loc[history_type == "stop_loss", "price"]
- timeout_prices = history.loc[history_type == "timeout", "price"]
-
- axs[0].plot(history_open, linewidth=1, color="orchid")
-
- plot_markers(axs[0], enter_long_prices, "^", "forestgreen", 5, -0.1)
- plot_markers(axs[0], enter_short_prices, "v", "firebrick", 5, 0.1)
- plot_markers(axs[0], exit_long_prices, ".", "cornflowerblue", 4, 0.1)
- plot_markers(axs[0], exit_short_prices, ".", "thistle", 4, -0.1)
- plot_markers(axs[0], take_profit_prices, "*", "lime", 8, 0.1)
- plot_markers(axs[0], stop_loss_prices, "x", "red", 8, -0.1)
- plot_markers(axs[0], timeout_prices, "1", "yellow", 8, 0.0)
-
- axs[1].set_ylabel("pnl")
- axs[1].plot(history.get("pnl"), linewidth=1, color="gray")
- axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
- axs[1].axhline(y=self.take_profit, label="tp", alpha=0.33, color="green")
- axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
-
- axs[2].set_ylabel("reward")
- axs[2].plot(history.get("reward"), linewidth=1, color="gray")
- axs[2].axhline(y=0, label="0", alpha=0.33)
-
- axs[3].set_ylabel("total_profit")
- axs[3].plot(history.get("total_profit"), linewidth=1, color="gray")
- axs[3].axhline(y=1, label="0", alpha=0.33)
-
- axs[4].set_ylabel("total_reward")
- axs[4].plot(history.get("total_reward"), linewidth=1, color="gray")
- axs[4].axhline(y=0, label="0", alpha=0.33)
- axs[4].set_xlabel("tick")
+ history = self.get_env_history()
+ if len(history) == 0:
+ return fig
+
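+ # Limit plotting to the most recent plot_window ticks (0 disables the limit)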
+ plot_window = int(self.rl_config.get("plot_window", 2000))
+ if plot_window > 0 and len(history) > plot_window:
+ history = history.iloc[-plot_window:]
+
+ ticks = history.get("tick")
+ history_open = history.get("open")
+ if (
+ ticks is None
+ or len(ticks) == 0
+ or history_open is None
+ or len(history_open) == 0
+ ):
+ return fig
+
+ axs[0].plot(ticks, history_open, linewidth=1, color="orchid", zorder=1)
+
+ history_type = history.get("type")
+ history_price = history.get("price")
+ if history_type is not None and history_price is not None:
+ # (type, marker, color, size, y-offset) for each trade event
+ marker_specs = (
+ ("long_enter", "^", "forestgreen", 5, -0.1),
+ ("short_enter", "v", "firebrick", 5, 0.1),
+ ("long_exit", ".", "cornflowerblue", 4, 0.1),
+ ("short_exit", ".", "thistle", 4, -0.1),
+ ("take_profit", "*", "lime", 8, 0.1),
+ ("stop_loss", "x", "red", 8, -0.1),
+ ("timeout", "1", "yellow", 8, 0.0),
+ )
+ for trade_type, marker, color, size, offset in marker_specs:
+ mask = history_type == trade_type
+ plot_markers(
+ axs[0],
+ ticks[mask],
+ history.loc[mask, "price"],
+ marker,
+ color,
+ size,
+ offset,
+ )
+
+ axs[1].set_ylabel("pnl")
+ pnl_series = history.get("pnl")
+ if pnl_series is not None and len(pnl_series) > 0:
+ axs[1].plot(
+ ticks,
+ pnl_series,
+ linewidth=1,
+ color="gray",
+ label="pnl",
+ )
+ previous_pnl_series = history.get("previous_pnl")
+ if previous_pnl_series is not None and len(previous_pnl_series) > 0:
+ axs[1].plot(
+ ticks,
+ previous_pnl_series,
+ linewidth=1,
+ color="deepskyblue",
+ label="previous_pnl",
+ )
+ if (pnl_series is not None and len(pnl_series) > 0) or (
+ previous_pnl_series is not None and len(previous_pnl_series) > 0
+ ):
+ axs[1].legend(loc="upper left", fontsize=8)
+ axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
+ axs[1].axhline(
+ y=self.take_profit, label="tp", alpha=0.33, color="green"
+ )
+ axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
+
+ axs[2].set_ylabel("reward")
+ reward_series = history.get("reward")
+ if reward_series is not None and len(reward_series) > 0:
+ axs[2].plot(ticks, reward_series, linewidth=1, color="gray")
+ axs[2].axhline(y=0, label="0", alpha=0.33)
+
+ axs[3].set_ylabel("total_profit")
+ total_profit_series = history.get("total_profit")
+ if total_profit_series is not None and len(total_profit_series) > 0:
+ axs[3].plot(ticks, total_profit_series, linewidth=1, color="gray")
+ axs[3].axhline(y=1, label="1", alpha=0.33)
+
+ axs[4].set_ylabel("total_reward")
+ total_reward_series = history.get("total_reward")
+ if total_reward_series is not None and len(total_reward_series) > 0:
+ axs[4].plot(ticks, total_reward_series, linewidth=1, color="gray")
+ axs[4].axhline(y=0, label="0", alpha=0.33)
+ axs[4].set_xlabel("tick")
_borders = ["top", "right", "bottom", "left"]
for _ax in axs:
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
self.eval_idx += 1
last_mean_reward = getattr(self, "last_mean_reward", np.nan)
- if not isinstance(last_mean_reward, (int, float)) or not np.isfinite(
- float(last_mean_reward)
- ):
+ try:
+ last_mean_reward = float(last_mean_reward)
+ except Exception as e:
+ logger.warning(
+ "Optuna: invalid last_mean_reward at eval %s: %r", self.eval_idx, e
+ )
self.is_pruned = True
return False
- if hasattr(self.trial, "report"):
- try:
- self.trial.report(last_mean_reward, self.eval_idx)
- except Exception:
+ if not np.isfinite(last_mean_reward):
+ logger.warning(
+ "Optuna: non-finite last_mean_reward at eval %s", self.eval_idx
+ )
+ self.is_pruned = True
+ return False
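+ # trial.report can raise on storage/backend errors; treat failures as
+ # pruning so the trial ends cleanly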
+ try:
+ self.trial.report(last_mean_reward, self.eval_idx)
+ except Exception as e:
+ logger.warning(
+ "Optuna: trial.report failed at eval %s: %r", self.eval_idx, e
+ )
+ self.is_pruned = True
+ return False
+ try:
+ if self.trial.should_prune():
+ logger.info(
+ "Optuna: trial pruned at eval %s (score=%.5f)",
+ self.eval_idx,
+ last_mean_reward,
+ )
self.is_pruned = True
return False
- if hasattr(self.trial, "should_prune") and self.trial.should_prune():
+ except Exception as e:
+ logger.warning(
+ "Optuna: should_prune failed at eval %s: %r", self.eval_idx, e
+ )
self.is_pruned = True
return False
return dst_copy
-@lru_cache(maxsize=64)
+@lru_cache(maxsize=128)
def linear_schedule(initial_value: float) -> Callable[[float], float]:
def func(progress_remaining: float) -> float:
return progress_remaining * initial_value
return model_params
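+# Candidate PPO rollout lengths, shared by the sampler and the eval-frequency logic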
+PPO_N_STEPS: tuple[int, ...] = (512, 1024, 2048, 4096)
+
+
def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
"""
Sampler for PPO hyperparams
"""
- n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
+ n_steps = trial.suggest_categorical("n_steps", list(PPO_N_STEPS))
batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
if batch_size > n_steps:
raise TrialPruned("batch_size must be less than or equal to n_steps")