from collections.abc import Mapping
from enum import IntEnum
from pathlib import Path
-from statistics import stdev
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union
import matplotlib
from optuna.study import Study, StudyDirection
from pandas import DataFrame, concat, merge
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
+from sb3_contrib.common.maskable.utils import is_masking_supported
from stable_baselines3.common.callbacks import (
BaseCallback,
ProgressBarCallback,
- pip install optuna-dashboard
"""
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
self.pairs: list[str] = self.config.get("exchange", {}).get("pair_whitelist")
if not self.pairs:
raise ValueError(
self.action_masking: bool = (
self.model_type == "MaskablePPO"
) # Enable action masking
- self.rl_config["action_masking"] = self.action_masking
+ self.rl_config.setdefault("action_masking", self.action_masking)
self.inference_masking: bool = self.rl_config.get("inference_masking", True)
self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
"max_no_improvement_evals", 0
)
self.min_evals: int = self.rl_config.get("min_evals", 0)
+ self.rl_config.setdefault("tensorboard_throttle", 1)
self.plot_new_best: bool = self.rl_config.get("plot_new_best", False)
self.check_envs: bool = self.rl_config.get("check_envs", True)
self.progressbar_callback: Optional[ProgressBarCallback] = None
self.optuna_n_startup_trials: int = self.rl_config_optuna.get(
"n_startup_trials", 15
)
- self.optuna_callback: Optional[MaskableTrialEvalCallback] = None
+ self.optuna_eval_callback: Optional[MaskableTrialEvalCallback] = None
self._model_params_cache: Optional[Dict[str, Any]] = None
self.unset_unsupported()
if self.check_envs:
logger.info("Checking environments")
- _train_env_check = self.MyRLEnv(
+ _train_env_check = MyRLEnv(
df=train_df,
prices=prices_train,
id="train_env_check",
check_env(_train_env_check)
finally:
_train_env_check.close()
- _eval_env_check = self.MyRLEnv(
+ _eval_env_check = MyRLEnv(
df=test_df,
prices=prices_test,
id="eval_env_check",
logger.info("Populating environments: %s", self.n_envs)
train_fns = [
make_env(
- self.MyRLEnv,
+ MyRLEnv,
f"train_env{i}",
i,
seed,
]
eval_fns = [
make_env(
- self.MyRLEnv,
+ MyRLEnv,
f"eval_env{i}",
i,
seed + 10_000,
)
if self.activate_tensorboard:
- info_callback = InfoMetricsCallback(actions=Actions, verbose=verbose)
+ info_callback = InfoMetricsCallback(
+ actions=Actions,
+ verbose=verbose,
+ throttle=self.rl_config.get("tensorboard_throttle", 1),
+ )
callbacks.append(info_callback)
if self.plot_new_best:
rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
self.progressbar_callback = ProgressBarCallback()
callbacks.append(self.progressbar_callback)
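+ # Request masked evaluation only when action masking is enabled and the
+ # eval env actually supports it (is_masking_supported probes the wrapped env)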
+ use_masking = self.action_masking and is_masking_supported(self.eval_env)
if not trial:
self.eval_callback = MaskableEvalCallback(
self.eval_env,
deterministic=True,
render=False,
best_model_save_path=data_path,
- use_masking=self.action_masking,
+ use_masking=use_masking,
callback_on_new_best=rollout_plot_callback,
callback_after_eval=no_improvement_callback,
verbose=verbose,
callbacks.append(self.eval_callback)
else:
trial_data_path = f"{data_path}/hyperopt/trial_{trial.number}"
- self.optuna_callback = MaskableTrialEvalCallback(
+ self.optuna_eval_callback = MaskableTrialEvalCallback(
self.eval_env,
trial,
eval_freq=eval_freq,
deterministic=True,
render=False,
best_model_save_path=trial_data_path,
- use_masking=self.action_masking,
+ use_masking=use_masking,
verbose=verbose,
)
- callbacks.append(self.optuna_callback)
+ callbacks.append(self.optuna_eval_callback)
return callbacks
def fit(
if self.action_masking and self.inference_masking:
action_masks_param["action_masks"] = ReforceXY.get_action_masks(
- simulated_position, None
+ simulated_position
)
action, _ = model.predict(
if nan_encountered:
raise TrialPruned("NaN encountered during training")
- if self.optuna_callback.is_pruned:
+ if self.optuna_eval_callback.is_pruned:
raise TrialPruned()
- return self.optuna_callback.best_mean_reward
+ return self.optuna_eval_callback.last_mean_reward
def close_envs(self) -> None:
"""
finally:
self.eval_env = None
- class MyRLEnv(Base5ActionRLEnv):
+
+def make_env(
+ MyRLEnv: Type[BaseEnvironment],
+ env_id: str,
+ rank: int,
+ seed: int,
+ train_df: DataFrame,
+ price: DataFrame,
+ env_info: Dict[str, Any],
+) -> Callable[[], BaseEnvironment]:
+ """
+ Utility function for multiprocessed env.
+
+ :param MyRLEnv: (Type[BaseEnvironment]) environment class to instantiate
+ :param env_id: (str) the environment ID
+ :param rank: (int) index of the subprocess
+ :param seed: (int) the initial seed for RNG
+ :param train_df: (DataFrame) feature dataframe for the environment
+ :param price: (DataFrame) aligned price dataframe
+ :param env_info: (dict) all required arguments to instantiate the environment
+ :return:
+ (Callable[[], BaseEnvironment]) closure that, when called, instantiates and returns the environment
+ """
+
+ def _init() -> BaseEnvironment:
+ return MyRLEnv(
+ df=train_df, prices=price, id=env_id, seed=seed + rank, **env_info
+ )
+
+ return _init
+
+
+class MyRLEnv(Base5ActionRLEnv):
+ """
+ Custom 5-action trading environment with action masking and force exits
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._set_observation_space()
+ self.action_masking: bool = self.rl_config.get("action_masking", False)
+ self.force_actions: bool = self.rl_config.get("force_actions", False)
+ self._force_action: Optional[ForceActions] = None
+ self.take_profit: float = self.config.get("minimal_roi", {}).get("0", 0.03)
+ self.stop_loss: float = self.config.get("stoploss", -0.02)
+ self.timeout: int = self.rl_config.get("max_trade_duration_candles", 128)
+ self._last_closed_position: Optional[Positions] = None
+ self._last_closed_trade_tick: int = 0
+ if self.force_actions:
+ logger.info(
+ "%s - take_profit: %s, stop_loss: %s, timeout: %s candles (%s days), observation_space: %s",
+ self.id,
+ self.take_profit,
+ self.stop_loss,
+ self.timeout,
+ steps_to_days(self.timeout, self.config.get("timeframe")),
+ self.observation_space,
+ )
+
+ def _set_observation_space(self) -> None:
"""
- Env
+ Set the observation space
"""
+ signal_features = self.signal_features.shape[1]
+ if self.add_state_info:
+ # STATE_INFO
+ self.state_features = ["pnl", "position", "trade_duration"]
+ self.total_features = signal_features + len(self.state_features)
+ else:
+ self.state_features = []
+ self.total_features = signal_features
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._set_observation_space()
- self.action_masking: bool = self.rl_config.get("action_masking", False)
- self.force_actions: bool = self.rl_config.get("force_actions", False)
- self._force_action: Optional[ForceActions] = None
- self.take_profit: float = self.config.get("minimal_roi", {}).get("0", 0.03)
- self.stop_loss: float = self.config.get("stoploss", -0.02)
- self.timeout: int = self.rl_config.get("max_trade_duration_candles", 128)
- self._last_closed_position: Optional[Positions] = None
- self._last_closed_trade_tick: int = 0
- if self.force_actions:
- logger.info(
- "%s - take_profit: %s, stop_loss: %s, timeout: %s candles (%s days), observation_space: %s",
- self.id,
- self.take_profit,
- self.stop_loss,
- self.timeout,
- steps_to_days(self.timeout, self.config.get("timeframe")),
- self.observation_space,
- )
+ self.shape = (self.window_size, self.total_features)
+ self.observation_space = Box(
+ low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
+ )
- def _set_observation_space(self) -> None:
- """
- Set the observation space
- """
- signal_features = self.signal_features.shape[1]
- if self.add_state_info:
- # STATE_INFO
- self.state_features = ["pnl", "position", "trade_duration"]
- self.total_features = signal_features + len(self.state_features)
- else:
- self.state_features = []
- self.total_features = signal_features
+ def _is_valid(self, action: int) -> bool:
+ return ReforceXY.get_action_masks(self._position, self._force_action)[action]
- self.shape = (self.window_size, self.total_features)
- self.observation_space = Box(
- low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
- )
+ def reset_env(
+ self,
+ df: DataFrame,
+ prices: DataFrame,
+ window_size: int,
+ reward_kwargs: Dict[str, Any],
+ starting_point=True,
+ ) -> None:
+ """
+ Resets the environment when the agent fails
+ """
+ super().reset_env(df, prices, window_size, reward_kwargs, starting_point)
+ self._set_observation_space()
- def _is_valid(self, action: int) -> bool:
- return ReforceXY.get_action_masks(self._position, self._force_action)[
- action
- ]
-
- def reset_env(
- self,
- df: DataFrame,
- prices: DataFrame,
- window_size: int,
- reward_kwargs: Dict[str, Any],
- starting_point=True,
- ) -> None:
- """
- Resets the environment when the agent fails
- """
- super().reset_env(df, prices, window_size, reward_kwargs, starting_point)
- self._set_observation_space()
-
- def reset(
- self, seed=None, **kwargs
- ) -> Tuple[NDArray[np.float32], Dict[str, Any]]:
- """
- Reset is called at the beginning of every episode
- """
- observation, history = super().reset(seed, **kwargs)
- self._force_action: Optional[ForceActions] = None
- self._last_closed_position: Optional[Positions] = None
- self._last_closed_trade_tick: int = 0
- return observation, history
-
- def _get_exit_reward_factor(
- self,
- factor: float,
- pnl: float,
- trade_duration: int,
- max_trade_duration: int,
- ) -> float:
- """
- Compute the reward factor at trade exit
- """
- if trade_duration <= max_trade_duration:
- factor *= 1.5
- elif trade_duration > max_trade_duration:
- factor *= 0.5
- if pnl > self.profit_aim * self.rr:
- factor *= float(
- self.rl_config.get("model_reward_parameters", {}).get(
- "win_reward_factor", 2.0
- )
- )
- return factor
-
- def calculate_reward(self, action: int) -> float:
- """
- An example reward function. This is the one function that users will likely
- wish to inject their own creativity into.
-
- Warning!
- This is function is a showcase of functionality designed to show as many possible
- environment control features as possible. It is also designed to run quickly
- on small computers. This is a benchmark, it is *not* for live production.
-
- :param action: int = The action made by the agent for the current candle.
- :return:
- float = the reward to give to the agent for current step (used for optimization
- of weights in NN)
- """
- # first, penalize if the action is not valid
- if not self.action_masking and not self._is_valid(action):
- self.tensorboard_log("invalid", category="actions")
- return self.rl_config.get("model_reward_parameters", {}).get(
- "invalid_action", -2.0
+ def reset(self, seed=None, **kwargs) -> Tuple[NDArray[np.float32], Dict[str, Any]]:
+ """
+ Reset is called at the beginning of every episode
+ """
+ observation, history = super().reset(seed, **kwargs)
+ self._force_action: Optional[ForceActions] = None
+ self._last_closed_position: Optional[Positions] = None
+ self._last_closed_trade_tick: int = 0
+ return observation, history
+
+ def _get_exit_reward_factor(
+ self,
+ factor: float,
+ pnl: float,
+ trade_duration: int,
+ max_trade_duration: int,
+ ) -> float:
+ """
+ Compute the reward factor at trade exit
+ """
+ if trade_duration <= max_trade_duration:
+ factor *= 1.5
+ else:
+ factor *= 0.5
+ if pnl > self.profit_aim * self.rr:
+ factor *= float(
+ self.rl_config.get("model_reward_parameters", {}).get(
+ "win_reward_factor", 2.0
)
+ )
+ return factor
- pnl = self.get_unrealized_profit()
- # mrr = self.get_most_recent_return()
- # mrp = self.get_most_recent_profit()
+ def calculate_reward(self, action: int) -> float:
+ """
+ An example reward function. This is the one function that users will likely
+ wish to inject their own creativity into.
- max_trade_duration = max(1, self.timeout)
- trade_duration = self.get_trade_duration()
+ Warning!
+ This function is a showcase designed to demonstrate as many environment
+ control features as possible. It is also designed to run quickly on small
+ computers. This is a benchmark; it is *not* for live production.
- factor = 100.0
+ :param action: int = The action made by the agent for the current candle.
+ :return:
+ float = the reward to give to the agent for current step (used for optimization
+ of weights in NN)
+ """
+ # first, penalize if the action is not valid
+ if not self.action_masking and not self._is_valid(action):
+ self.tensorboard_log("invalid", category="actions")
+ return self.rl_config.get("model_reward_parameters", {}).get(
+ "invalid_action", -2.0
+ )
- # Force exits
- if self._force_action in (
- ForceActions.Take_profit,
- ForceActions.Stop_loss,
- ForceActions.Timeout,
- ):
- return pnl * self._get_exit_reward_factor(
- factor, pnl, trade_duration, max_trade_duration
- )
+ pnl = self.get_unrealized_profit()
+ # mrr = self.get_most_recent_return()
+ # mrp = self.get_most_recent_profit()
- # # you can use feature values from dataframe
- # rsi_now = self.get_feature_value(
- # name="%-rsi",
- # period=8,
- # pair=self.pair,
- # timeframe=self.config.get("timeframe"),
- # raw=True,
- # )
-
- # # reward agent for entering trades when RSI is low
- # if (
- # action in (Actions.Long_enter.value, Actions.Short_enter.value)
- # and self._position == Positions.Neutral
- # ):
- # if rsi_now < 40:
- # factor = 40 / rsi_now
- # else:
- # factor = 1
- # return 25.0 * factor
-
- # discourage agent from sitting idle too long
- if action == Actions.Neutral.value and self._position == Positions.Neutral:
- return -0.01 * self.get_idle_duration() ** 1.05
-
- # pnl and duration aware agent reward while sitting in position
- if (
- self._position in (Positions.Short, Positions.Long)
- and action == Actions.Neutral.value
- ):
- duration_fraction = trade_duration / max_trade_duration
- max_pnl = max(self.get_most_recent_max_pnl(), pnl)
+ max_trade_duration = max(1, self.timeout)
+ trade_duration = self.get_trade_duration()
- if max_pnl > 0:
- drawdown_penalty = (
- 0.0025 * factor * (max_pnl - pnl) * duration_fraction
- )
- else:
- drawdown_penalty = 0.0
+ factor = 100.0
- lambda1 = 0.05
- lambda2 = 0.1
- if pnl >= 0:
- if duration_fraction <= 1.0:
- duration_penalty_factor = 1.0
- else:
- duration_penalty_factor = 1.0 / (
- 1.0 + lambda1 * (duration_fraction - 1.0)
- )
- return (
- factor * pnl * duration_penalty_factor
- - lambda2 * duration_fraction
- - drawdown_penalty
- )
+ # Force exits
+ if self._force_action in (
+ ForceActions.Take_profit,
+ ForceActions.Stop_loss,
+ ForceActions.Timeout,
+ ):
+ return pnl * self._get_exit_reward_factor(
+ factor, pnl, trade_duration, max_trade_duration
+ )
+
+ # # you can use feature values from dataframe
+ # rsi_now = self.get_feature_value(
+ # name="%-rsi",
+ # period=8,
+ # pair=self.pair,
+ # timeframe=self.config.get("timeframe"),
+ # raw=True,
+ # )
+
+ # # reward agent for entering trades when RSI is low
+ # if (
+ # action in (Actions.Long_enter.value, Actions.Short_enter.value)
+ # and self._position == Positions.Neutral
+ # ):
+ # if rsi_now < 40:
+ # factor = 40 / rsi_now
+ # else:
+ # factor = 1
+ # return 25.0 * factor
+
+ # discourage agent from sitting idle too long
+ if action == Actions.Neutral.value and self._position == Positions.Neutral:
+ return -0.01 * self.get_idle_duration() ** 1.05
+
+ # pnl and duration aware agent reward while sitting in position
+ if (
+ self._position in (Positions.Short, Positions.Long)
+ and action == Actions.Neutral.value
+ ):
+ duration_fraction = trade_duration / max_trade_duration
+ max_pnl = max(self.get_most_recent_max_pnl(), pnl)
+
+ if max_pnl > 0:
+ drawdown_penalty = 0.0025 * factor * (max_pnl - pnl) * duration_fraction
+ else:
+ drawdown_penalty = 0.0
+
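+ # lambda1 shapes the duration effect: it dampens positive pnl once past the
+ # timeout and amplifies negative pnl with holding time; lambda2 adds a
+ # linear time cost, doubled for losing positions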
+ lambda1 = 0.05
+ lambda2 = 0.1
+ if pnl >= 0:
+ if duration_fraction <= 1.0:
+ duration_penalty_factor = 1.0
else:
- return (
- factor * pnl * (1.0 + lambda1 * duration_fraction)
- - 2.0 * lambda2 * duration_fraction
- - drawdown_penalty
+ duration_penalty_factor = 1.0 / (
+ 1.0 + lambda1 * (duration_fraction - 1.0)
)
-
- # close long
- if action == Actions.Long_exit.value and self._position == Positions.Long:
- return pnl * self._get_exit_reward_factor(
- factor, pnl, trade_duration, max_trade_duration
+ return (
+ factor * pnl * duration_penalty_factor
+ - lambda2 * duration_fraction
+ - drawdown_penalty
)
-
- # close short
- if action == Actions.Short_exit.value and self._position == Positions.Short:
- return pnl * self._get_exit_reward_factor(
- factor, pnl, trade_duration, max_trade_duration
+ else:
+ return (
+ factor * pnl * (1.0 + lambda1 * duration_fraction)
+ - 2.0 * lambda2 * duration_fraction
+ - drawdown_penalty
)
- return 0.0
+ # close long
+ if action == Actions.Long_exit.value and self._position == Positions.Long:
+ return pnl * self._get_exit_reward_factor(
+ factor, pnl, trade_duration, max_trade_duration
+ )
- def _get_observation(self) -> NDArray[np.float32]:
- """
- This may or may not be independent of action types, user can inherit
- this in their custom "MyRLEnv"
- """
- features_window = self.signal_features[
- (self._current_tick - self.window_size) : self._current_tick
- ]
- if self.add_state_info:
- features_and_state = DataFrame(
- np.zeros(
- (len(features_window), len(self.state_features)),
- dtype=np.float32,
- ),
- columns=self.state_features,
- index=features_window.index,
- )
- # STATE_INFO
- features_and_state["pnl"] = self.get_unrealized_profit()
- features_and_state["position"] = self._position.value
- features_and_state["trade_duration"] = self.get_trade_duration()
+ # close short
+ if action == Actions.Short_exit.value and self._position == Positions.Short:
+ return pnl * self._get_exit_reward_factor(
+ factor, pnl, trade_duration, max_trade_duration
+ )
- features_and_state = concat(
- [features_window, features_and_state], axis=1
- )
- return features_and_state.to_numpy(dtype=np.float32)
- else:
- return features_window.to_numpy(dtype=np.float32)
+ return 0.0
- def _get_force_action(self) -> Optional[ForceActions]:
- if not self.force_actions or self._position == Positions.Neutral:
- return None
+ def _get_observation(self) -> NDArray[np.float32]:
+ """
+ Build the observation window. This may or may not be independent of action
+ types; users can override this in their custom "MyRLEnv"
+ """
+ features_window = self.signal_features[
+ (self._current_tick - self.window_size) : self._current_tick
+ ]
+ if self.add_state_info:
+ features_and_state = DataFrame(
+ np.zeros(
+ (len(features_window), len(self.state_features)),
+ dtype=np.float32,
+ ),
+ columns=self.state_features,
+ index=features_window.index,
+ )
+ # STATE_INFO
+ features_and_state["pnl"] = self.get_unrealized_profit()
+ features_and_state["position"] = self._position.value
+ features_and_state["trade_duration"] = self.get_trade_duration()
- trade_duration = self.get_trade_duration()
- if trade_duration <= 1:
- return None
- if trade_duration >= self.timeout:
- return ForceActions.Timeout
-
- pnl = self.get_unrealized_profit()
- if pnl >= self.take_profit:
- return ForceActions.Take_profit
- if pnl <= self.stop_loss:
- return ForceActions.Stop_loss
+ features_and_state = concat([features_window, features_and_state], axis=1)
+ return features_and_state.to_numpy(dtype=np.float32)
+ else:
+ return features_window.to_numpy(dtype=np.float32)
+
+ def _get_force_action(self) -> Optional[ForceActions]:
+ if not self.force_actions or self._position == Positions.Neutral:
return None
- def _get_position(self, action: int) -> Positions:
- return {
- Actions.Long_enter.value: Positions.Long,
- Actions.Short_enter.value: Positions.Short,
- }[action]
-
- def _enter_trade(self, action: int) -> None:
- self._position = self._get_position(action)
- self._last_trade_tick = self._current_tick
-
- def _exit_trade(self) -> None:
- self._update_total_profit()
- self._last_closed_position = self._position
- self._position = Positions.Neutral
- self._last_closed_trade_tick = self._current_tick
- self._last_trade_tick = None
-
- def execute_trade(self, action: int) -> Optional[str]:
- """
- Execute trade based on the given action
- """
- # Force exit trade
- if self._force_action:
- self._exit_trade()
- self.tensorboard_log(
- f"{self._force_action.name}", category="actions/force"
- )
- return f"{self._force_action.name}"
+ trade_duration = self.get_trade_duration()
+ if trade_duration <= 1:
+ return None
+ if trade_duration >= self.timeout:
+ return ForceActions.Timeout
+
+ pnl = self.get_unrealized_profit()
+ if pnl >= self.take_profit:
+ return ForceActions.Take_profit
+ if pnl <= self.stop_loss:
+ return ForceActions.Stop_loss
+ return None
- if not self.is_tradesignal(action):
- return None
+ def _get_position(self, action: int) -> Positions:
+ return {
+ Actions.Long_enter.value: Positions.Long,
+ Actions.Short_enter.value: Positions.Short,
+ }[action]
+
+ def _enter_trade(self, action: int) -> None:
+ self._position = self._get_position(action)
+ self._last_trade_tick = self._current_tick
+
+ def _exit_trade(self) -> None:
+ self._update_total_profit()
+ self._last_closed_position = self._position
+ self._position = Positions.Neutral
+ self._last_closed_trade_tick = self._current_tick
+ self._last_trade_tick = None
+
+ def execute_trade(self, action: int) -> Optional[str]:
+ """
+ Execute trade based on the given action
+ """
+ # Force exit trade
+ if self._force_action:
+ self._exit_trade()
+ self.tensorboard_log(f"{self._force_action.name}", category="actions/force")
+ return f"{self._force_action.name}"
- # Enter trade based on action
- if action in (Actions.Long_enter.value, Actions.Short_enter.value):
- self._enter_trade(action)
- return f"{self._position.name}_enter"
+ if not self.is_tradesignal(action):
+ return None
- # Exit trade based on action
- if action in (Actions.Long_exit.value, Actions.Short_exit.value):
- self._exit_trade()
- return f"{self._last_closed_position.name}_exit"
+ # Enter trade based on action
+ if action in (Actions.Long_enter.value, Actions.Short_enter.value):
+ self._enter_trade(action)
+ return f"{self._position.name}_enter"
- return None
+ # Exit trade based on action
+ if action in (Actions.Long_exit.value, Actions.Short_exit.value):
+ self._exit_trade()
+ return f"{self._last_closed_position.name}_exit"
- def step(
- self, action: int
- ) -> Tuple[NDArray[np.float32], float, bool, bool, Dict[str, Any]]:
- """
- Take a step in the environment based on the provided action
- """
- self._current_tick += 1
- self._update_unrealized_total_profit()
- previous_pnl = self.get_unrealized_profit()
- self._update_portfolio_log_returns()
- self._force_action = self._get_force_action()
- reward = self.calculate_reward(action)
- self.total_reward += reward
- self.tensorboard_log(Actions._member_names_[action], category="actions")
- trade_type = self.execute_trade(action)
- if trade_type is not None:
- self.append_trade_history(
- trade_type, self.current_price(), previous_pnl
- )
- self._position_history.append(self._position)
- info = {
+ return None
+
+ def step(
+ self, action: int
+ ) -> Tuple[NDArray[np.float32], float, bool, bool, Dict[str, Any]]:
+ """
+ Take a step in the environment based on the provided action
+ """
+ self._current_tick += 1
+ self._update_unrealized_total_profit()
+ previous_pnl = self.get_unrealized_profit()
+ self._update_portfolio_log_returns()
+ self._force_action = self._get_force_action()
+ reward = self.calculate_reward(action)
+ self.total_reward += reward
+ self.tensorboard_log(Actions._member_names_[action], category="actions")
+ trade_type = self.execute_trade(action)
+ if trade_type is not None:
+ self.append_trade_history(trade_type, self.current_price(), previous_pnl)
+ self._position_history.append(self._position)
+ info = {
+ "tick": self._current_tick,
+ "position": self._position.value,
+ "action": action,
+ "force_action": (self._force_action.name if self._force_action else None),
+ "previous_pnl": round(previous_pnl, 5),
+ "pnl": round(self.get_unrealized_profit(), 5),
+ "reward": round(reward, 5),
+ "total_reward": round(self.total_reward, 5),
+ "total_profit": round(self._total_profit, 5),
+ "idle_duration": self.get_idle_duration(),
+ "trade_duration": self.get_trade_duration(),
+ "trade_count": int(len(self.trade_history) // 2),
+ }
+ self._update_history(info)
+ return (
+ self._get_observation(),
+ reward,
+ self.is_terminated(),
+ self.is_truncated(),
+ info,
+ )
+
+ def append_trade_history(
+ self, trade_type: str, price: float, profit: float
+ ) -> None:
+ self.trade_history.append(
+ {
"tick": self._current_tick,
- "position": self._position.value,
- "action": action,
- "force_action": (
- self._force_action.name if self._force_action else None
- ),
- "previous_pnl": round(previous_pnl, 5),
- "pnl": round(self.get_unrealized_profit(), 5),
- "reward": round(reward, 5),
- "total_reward": round(self.total_reward, 5),
- "total_profit": round(self._total_profit, 5),
- "idle_duration": self.get_idle_duration(),
- "trade_duration": self.get_trade_duration(),
- "trade_count": int(len(self.trade_history) // 2),
+ "type": trade_type.lower(),
+ "price": price,
+ "profit": profit,
}
- self._update_history(info)
- return (
- self._get_observation(),
- reward,
- self.is_terminated(),
- self.is_truncated(),
- info,
- )
-
- def append_trade_history(
- self, trade_type: str, price: float, profit: float
- ) -> None:
- self.trade_history.append(
- {
- "tick": self._current_tick,
- "type": trade_type.lower(),
- "price": price,
- "profit": profit,
- }
- )
+ )
- def is_terminated(self) -> bool:
- return bool(
- self._current_tick == self._end_tick
- or self._total_profit <= self.max_drawdown
- or self._total_unrealized_profit <= self.max_drawdown
- )
+ def is_terminated(self) -> bool:
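+ # episode ends at the last tick or when realized/unrealized total
+ # profit breaches the max_drawdown limit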
+ return bool(
+ self._current_tick == self._end_tick
+ or self._total_profit <= self.max_drawdown
+ or self._total_unrealized_profit <= self.max_drawdown
+ )
- def is_truncated(self) -> bool:
- return False
+ def is_truncated(self) -> bool:
+ return False
- def is_tradesignal(self, action: int) -> bool:
- """
- Determine if the action is entry or exit
- """
- return (
- (
- action in (Actions.Short_enter.value, Actions.Long_enter.value)
- and self._position == Positions.Neutral
- )
- or (
- action == Actions.Long_exit.value
- and self._position == Positions.Long
- )
- or (
- action == Actions.Short_exit.value
- and self._position == Positions.Short
- )
+ def is_tradesignal(self, action: int) -> bool:
+ """
+ Determine if the action is entry or exit
+ """
+ return (
+ (
+ action in (Actions.Short_enter.value, Actions.Long_enter.value)
+ and self._position == Positions.Neutral
)
-
- def action_masks(self) -> NDArray[bool]:
- return ReforceXY.get_action_masks(self._position, self._force_action)
-
- def get_feature_value(
- self,
- name: str,
- period: int = 0,
- shift: int = 0,
- pair: str = "",
- timeframe: str = "",
- raw: bool = True,
- ) -> float:
- """
- Get feature value
- """
- feature_parts = [name]
- if period:
- feature_parts.append(f"_{period}")
- if shift:
- feature_parts.append(f"_shift-{shift}")
- if pair:
- feature_parts.append(f"_{pair.replace(':', '')}")
- if timeframe:
- feature_parts.append(f"_{timeframe}")
- feature_col = "".join(feature_parts)
-
- if not raw:
- return self.signal_features[feature_col].iloc[self._current_tick]
- return self.raw_features[feature_col].iloc[self._current_tick]
-
- def get_idle_duration(self) -> int:
- if self._position != Positions.Neutral:
- return 0
- if not self._last_closed_trade_tick:
- return self._current_tick - self._start_tick
- return self._current_tick - self._last_closed_trade_tick
-
- def get_most_recent_max_pnl(self) -> float:
- """
- Calculate the most recent maximum unrealized profit if in a trade
- """
- if self._last_trade_tick is None:
- return 0.0
- if self._position == Positions.Neutral:
- return 0.0
- pnl_history = self.history.get("pnl")
- if not pnl_history or len(pnl_history) == 0:
- return 0.0
-
- pnl_history = np.asarray(pnl_history)
- ticks = self.history.get("tick")
- if not ticks:
- return 0.0
- ticks = np.asarray(ticks)
- start = np.searchsorted(ticks, self._last_trade_tick, side="left")
- trade_pnl_history = pnl_history[start:]
- if trade_pnl_history.size == 0:
- return 0.0
- return np.max(trade_pnl_history)
-
- def get_most_recent_return(self) -> float:
- """
- Calculate the tick to tick return if in a trade.
- Return is generated from rising prices in Long and falling prices in Short positions.
- The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
- """
- if self._last_trade_tick is None:
- return 0.0
- if self._position == Positions.Neutral:
- return 0.0
- elif self._position == Positions.Long:
- current_price = self.current_price()
- previous_price = self.previous_price()
- if (
- self._position_history[self._current_tick - 1] == Positions.Short
- or self._position_history[self._current_tick - 1]
- == Positions.Neutral
- ):
- previous_price = self.add_entry_fee(previous_price)
- return np.log(current_price) - np.log(previous_price)
- elif self._position == Positions.Short:
- current_price = self.current_price()
- previous_price = self.previous_price()
- if (
- self._position_history[self._current_tick - 1] == Positions.Long
- or self._position_history[self._current_tick - 1]
- == Positions.Neutral
- ):
- previous_price = self.add_exit_fee(previous_price)
- return np.log(previous_price) - np.log(current_price)
- return 0.0
-
- def _update_portfolio_log_returns(self):
- self.portfolio_log_returns[self._current_tick] = (
- self.get_most_recent_return()
+ or (action == Actions.Long_exit.value and self._position == Positions.Long)
+ or (
+ action == Actions.Short_exit.value and self._position == Positions.Short
)
+ )
- def get_most_recent_profit(self) -> float:
- """
- Calculate the tick to tick unrealized profit if in a trade
- """
- if self._last_trade_tick is None:
- return 0.0
- if self._position == Positions.Neutral:
- return 0.0
- elif self._position == Positions.Long:
- current_price = self.add_exit_fee(self.current_price())
- previous_price = self.add_entry_fee(self.previous_price())
- return (current_price - previous_price) / previous_price
- elif self._position == Positions.Short:
- current_price = self.add_entry_fee(self.current_price())
- previous_price = self.add_exit_fee(self.previous_price())
- return (previous_price - current_price) / previous_price
+ def action_masks(self) -> NDArray[np.bool_]:
+ return ReforceXY.get_action_masks(self._position, self._force_action)
+
+ def get_feature_value(
+ self,
+ name: str,
+ period: int = 0,
+ shift: int = 0,
+ pair: str = "",
+ timeframe: str = "",
+ raw: bool = True,
+ ) -> float:
+ """
+ Get feature value
+ """
+ feature_parts = [name]
+ if period:
+ feature_parts.append(f"_{period}")
+ if shift:
+ feature_parts.append(f"_shift-{shift}")
+ if pair:
+ feature_parts.append(f"_{pair.replace(':', '')}")
+ if timeframe:
+ feature_parts.append(f"_{timeframe}")
+ feature_col = "".join(feature_parts)
+
+ if not raw:
+ return self.signal_features[feature_col].iloc[self._current_tick]
+ return self.raw_features[feature_col].iloc[self._current_tick]
+
+ def get_idle_duration(self) -> int:
+ if self._position != Positions.Neutral:
+ return 0
+ if not self._last_closed_trade_tick:
+ return self._current_tick - self._start_tick
+ return self._current_tick - self._last_closed_trade_tick
+
+ def get_most_recent_max_pnl(self) -> float:
+ """
+ Calculate the most recent maximum unrealized profit if in a trade
+ """
+ if self._last_trade_tick is None:
+ return 0.0
+ if self._position == Positions.Neutral:
+ return 0.0
+ pnl_history = self.history.get("pnl")
+ if not pnl_history or len(pnl_history) == 0:
return 0.0
- def previous_price(self) -> float:
- return self.prices.iloc[self._current_tick - 1].get("open")
+ pnl_history = np.asarray(pnl_history)
+ ticks = self.history.get("tick")
+ if not ticks:
+ return 0.0
+ ticks = np.asarray(ticks)
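+ # first history entry recorded at or after the current trade entry tick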
+ start = np.searchsorted(ticks, self._last_trade_tick, side="left")
+ trade_pnl_history = pnl_history[start:]
+ if trade_pnl_history.size == 0:
+ return 0.0
+ return np.max(trade_pnl_history)
- def get_env_history(self) -> DataFrame:
- """
- Get environment data aligned on ticks, including optional trade events
- """
- if not self.history:
- logger.warning("history is empty")
- return DataFrame()
+ def get_most_recent_return(self) -> float:
+ """
+ Calculate the tick-to-tick log return while in a trade.
+ Return is generated from rising prices in Long positions and falling prices in Short positions.
+ The corresponding entry/exit fee is applied to the previous price on the tick where the position was opened.
+ """
+ if self._last_trade_tick is None:
+ return 0.0
+ if self._position == Positions.Neutral:
+ return 0.0
+ elif self._position == Positions.Long:
+ current_price = self.current_price()
+ previous_price = self.previous_price()
+ if (
+ self._position_history[self._current_tick - 1] == Positions.Short
+ or self._position_history[self._current_tick - 1] == Positions.Neutral
+ ):
+ previous_price = self.add_entry_fee(previous_price)
+ return np.log(current_price) - np.log(previous_price)
+ elif self._position == Positions.Short:
+ current_price = self.current_price()
+ previous_price = self.previous_price()
+ if (
+ self._position_history[self._current_tick - 1] == Positions.Long
+ or self._position_history[self._current_tick - 1] == Positions.Neutral
+ ):
+ previous_price = self.add_exit_fee(previous_price)
+ return np.log(previous_price) - np.log(current_price)
+ return 0.0
- _history_df = DataFrame.from_dict(self.history)
- if "tick" not in _history_df.columns:
- logger.warning("'tick' column is missing from history")
- return DataFrame()
+ def _update_portfolio_log_returns(self):
+ self.portfolio_log_returns[self._current_tick] = self.get_most_recent_return()
- if self.trade_history:
- _trade_history_df = DataFrame.from_dict(self.trade_history)
- if "tick" in _trade_history_df.columns:
- _rollout_history = merge(
- _history_df, _trade_history_df, on="tick", how="left"
- )
- else:
- _rollout_history = _history_df.copy()
+ def get_most_recent_profit(self) -> float:
+ """
+ Calculate the tick to tick unrealized profit if in a trade
+ """
+ if self._last_trade_tick is None:
+ return 0.0
+ if self._position == Positions.Neutral:
+ return 0.0
+ elif self._position == Positions.Long:
+ current_price = self.add_exit_fee(self.current_price())
+ previous_price = self.add_entry_fee(self.previous_price())
+ return (current_price - previous_price) / previous_price
+ elif self._position == Positions.Short:
+ current_price = self.add_entry_fee(self.current_price())
+ previous_price = self.add_exit_fee(self.previous_price())
+ return (previous_price - current_price) / previous_price
+ return 0.0
+
+ def previous_price(self) -> float:
+ return self.prices.iloc[self._current_tick - 1].get("open")
+
+ def get_env_history(self) -> DataFrame:
+ """
+ Get environment data aligned on ticks, including optional trade events
+ """
+ if not self.history:
+ logger.warning("history is empty")
+ return DataFrame()
+
+ _history_df = DataFrame.from_dict(self.history)
+ if "tick" not in _history_df.columns:
+ logger.warning("'tick' column is missing from history")
+ return DataFrame()
+
+ if self.trade_history:
+ _trade_history_df = DataFrame.from_dict(self.trade_history)
+ if "tick" in _trade_history_df.columns:
+ _rollout_history = merge(
+ _history_df, _trade_history_df, on="tick", how="left"
+ )
else:
_rollout_history = _history_df.copy()
+ else:
+ _rollout_history = _history_df.copy()
+ try:
+ history = merge(
+ _rollout_history,
+ self.prices,
+ left_on="tick",
+ right_index=True,
+ how="left",
+ )
+ except Exception:
try:
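+ # fallback: select prices positionally by tick values when the
+ # index-based merge fails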
+ _price_history = (
+ self.prices.iloc[_rollout_history.tick]
+ .copy()
+ .reset_index(drop=True)
+ )
history = merge(
_rollout_history,
- self.prices,
- left_on="tick",
+ _price_history,
+ left_index=True,
right_index=True,
- how="left",
)
- except Exception:
- try:
- _price_history = (
- self.prices.iloc[_rollout_history.tick]
- .copy()
- .reset_index(drop=True)
- )
- history = merge(
- _rollout_history,
- _price_history,
- left_index=True,
- right_index=True,
- )
- except Exception as e:
- logger.error(
- f"Failed to merge history with prices: {repr(e)}",
- exc_info=True,
- )
- return DataFrame()
- return history
-
- def get_env_plot(self) -> plt.Figure:
- """
- Plot trades and environment data
- """
-
- def transform_y_offset(ax, offset):
- return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
-
- def plot_markers(ax, xs, ys, marker, color, size, offset):
- ax.plot(
- xs,
- ys,
- marker=marker,
- color=color,
- markersize=size,
- fillstyle="full",
- transform=transform_y_offset(ax, offset),
- linestyle="none",
- zorder=3,
+ except Exception as e:
+ logger.error(
+ f"Failed to merge history with prices: {repr(e)}",
+ exc_info=True,
)
+ return DataFrame()
+ return history
- with plt.style.context("dark_background"):
- fig, axs = plt.subplots(
- nrows=5,
- ncols=1,
- figsize=(14, 8),
- height_ratios=[5, 1, 1, 1, 1],
- sharex=True,
- )
+ def get_env_plot(self) -> plt.Figure:
+ """
+ Plot trades and environment data
+ """
- history = self.get_env_history()
- if len(history) == 0:
- return fig
-
- plot_window = int(self.rl_config.get("plot_window", 2000))
- if plot_window > 0 and len(history) > plot_window:
- history = history.iloc[-plot_window:]
-
- ticks = history.get("tick")
- history_open = history.get("open")
- if (
- ticks is None
- or len(ticks) == 0
- or history_open is None
- or len(history_open) == 0
- ):
- return fig
-
- axs[0].plot(ticks, history_open, linewidth=1, color="orchid", zorder=1)
-
- history_type = history.get("type")
- history_price = history.get("price")
- if history_type is not None and history_price is not None:
- enter_long_mask = history_type == "long_enter"
- enter_short_mask = history_type == "short_enter"
- exit_long_mask = history_type == "long_exit"
- exit_short_mask = history_type == "short_exit"
- take_profit_mask = history_type == "take_profit"
- stop_loss_mask = history_type == "stop_loss"
- timeout_mask = history_type == "timeout"
-
- enter_long_x = ticks[enter_long_mask]
- enter_short_x = ticks[enter_short_mask]
- exit_long_x = ticks[exit_long_mask]
- exit_short_x = ticks[exit_short_mask]
- take_profit_x = ticks[take_profit_mask]
- stop_loss_x = ticks[stop_loss_mask]
- timeout_x = ticks[timeout_mask]
-
- enter_long_y = history.loc[enter_long_mask, "price"]
- enter_short_y = history.loc[enter_short_mask, "price"]
- exit_long_y = history.loc[exit_long_mask, "price"]
- exit_short_y = history.loc[exit_short_mask, "price"]
- take_profit_y = history.loc[take_profit_mask, "price"]
- stop_loss_y = history.loc[stop_loss_mask, "price"]
- timeout_y = history.loc[timeout_mask, "price"]
-
- plot_markers(
- axs[0], enter_long_x, enter_long_y, "^", "forestgreen", 5, -0.1
- )
- plot_markers(
- axs[0], enter_short_x, enter_short_y, "v", "firebrick", 5, 0.1
- )
- plot_markers(
- axs[0], exit_long_x, exit_long_y, ".", "cornflowerblue", 4, 0.1
- )
- plot_markers(
- axs[0], exit_short_x, exit_short_y, ".", "thistle", 4, -0.1
- )
- plot_markers(
- axs[0], take_profit_x, take_profit_y, "*", "lime", 8, 0.1
- )
- plot_markers(axs[0], stop_loss_x, stop_loss_y, "x", "red", 8, -0.1)
- plot_markers(axs[0], timeout_x, timeout_y, "1", "yellow", 8, 0.0)
-
- axs[1].set_ylabel("pnl")
- pnl_series = history.get("pnl")
- if pnl_series is not None and len(pnl_series) > 0:
- axs[1].plot(
- ticks,
- pnl_series,
- linewidth=1,
- color="gray",
- label="pnl",
- )
- previous_pnl_series = history.get("previous_pnl")
- if previous_pnl_series is not None and len(previous_pnl_series) > 0:
- axs[1].plot(
- ticks,
- previous_pnl_series,
- linewidth=1,
- color="deepskyblue",
- label="previous_pnl",
- )
- if (pnl_series is not None and len(pnl_series) > 0) or (
- previous_pnl_series is not None and len(previous_pnl_series) > 0
- ):
- axs[1].legend(loc="upper left", fontsize=8)
- axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
- axs[1].axhline(
- y=self.take_profit, label="tp", alpha=0.33, color="green"
- )
- axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
-
- axs[2].set_ylabel("reward")
- reward_series = history.get("reward")
- if reward_series is not None and len(reward_series) > 0:
- axs[2].plot(ticks, reward_series, linewidth=1, color="gray")
- axs[2].axhline(y=0, label="0", alpha=0.33)
-
- axs[3].set_ylabel("total_profit")
- total_profit_series = history.get("total_profit")
- if total_profit_series is not None and len(total_profit_series) > 0:
- axs[3].plot(ticks, total_profit_series, linewidth=1, color="gray")
- axs[3].axhline(y=1, label="1", alpha=0.33)
-
- axs[4].set_ylabel("total_reward")
- total_reward_series = history.get("total_reward")
- if total_reward_series is not None and len(total_reward_series) > 0:
- axs[4].plot(ticks, total_reward_series, linewidth=1, color="gray")
- axs[4].axhline(y=0, label="0", alpha=0.33)
- axs[4].set_xlabel("tick")
-
- _borders = ["top", "right", "bottom", "left"]
- for _ax in axs:
- for _border in _borders:
- _ax.spines[_border].set_color("#5b5e4b")
-
- fig.suptitle(
- f"Total Reward: {self.total_reward:.2f} ~ "
- + f"Total Profit: {self._total_profit:.2f} ~ "
- + f"Trades: {int(len(self.trade_history) // 2)}",
+ def transform_y_offset(ax, offset):
+ return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
+
+ def plot_markers(ax, xs, ys, marker, color, size, offset):
+ ax.plot(
+ xs,
+ ys,
+ marker=marker,
+ color=color,
+ markersize=size,
+ fillstyle="full",
+ transform=transform_y_offset(ax, offset),
+ linestyle="none",
+ zorder=3,
)
- fig.tight_layout()
- return fig
- def close(self) -> None:
- plt.close()
- gc.collect()
- if th.cuda.is_available():
- th.cuda.empty_cache()
+ with plt.style.context("dark_background"):
+ fig, axs = plt.subplots(
+ nrows=5,
+ ncols=1,
+ figsize=(14, 8),
+ height_ratios=[5, 1, 1, 1, 1],
+ sharex=True,
+ )
+
+ history = self.get_env_history()
+ if len(history) == 0:
+ return fig
+
+ plot_window = int(self.rl_config.get("plot_window", 2000))
+ if plot_window > 0 and len(history) > plot_window:
+ history = history.iloc[-plot_window:]
+
+ ticks = history.get("tick")
+ history_open = history.get("open")
+ if (
+ ticks is None
+ or len(ticks) == 0
+ or history_open is None
+ or len(history_open) == 0
+ ):
+ return fig
+
+ axs[0].plot(ticks, history_open, linewidth=1, color="orchid", zorder=1)
+
+ history_type = history.get("type")
+ history_price = history.get("price")
+ if history_type is not None and history_price is not None:
+ enter_long_mask = history_type == "long_enter"
+ enter_short_mask = history_type == "short_enter"
+ exit_long_mask = history_type == "long_exit"
+ exit_short_mask = history_type == "short_exit"
+ take_profit_mask = history_type == "take_profit"
+ stop_loss_mask = history_type == "stop_loss"
+ timeout_mask = history_type == "timeout"
+
+ enter_long_x = ticks[enter_long_mask]
+ enter_short_x = ticks[enter_short_mask]
+ exit_long_x = ticks[exit_long_mask]
+ exit_short_x = ticks[exit_short_mask]
+ take_profit_x = ticks[take_profit_mask]
+ stop_loss_x = ticks[stop_loss_mask]
+ timeout_x = ticks[timeout_mask]
+
+ enter_long_y = history.loc[enter_long_mask, "price"]
+ enter_short_y = history.loc[enter_short_mask, "price"]
+ exit_long_y = history.loc[exit_long_mask, "price"]
+ exit_short_y = history.loc[exit_short_mask, "price"]
+ take_profit_y = history.loc[take_profit_mask, "price"]
+ stop_loss_y = history.loc[stop_loss_mask, "price"]
+ timeout_y = history.loc[timeout_mask, "price"]
+
+ plot_markers(
+ axs[0], enter_long_x, enter_long_y, "^", "forestgreen", 5, -0.1
+ )
+ plot_markers(
+ axs[0], enter_short_x, enter_short_y, "v", "firebrick", 5, 0.1
+ )
+ plot_markers(
+ axs[0], exit_long_x, exit_long_y, ".", "cornflowerblue", 4, 0.1
+ )
+ plot_markers(
+ axs[0], exit_short_x, exit_short_y, ".", "thistle", 4, -0.1
+ )
+ plot_markers(axs[0], take_profit_x, take_profit_y, "*", "lime", 8, 0.1)
+ plot_markers(axs[0], stop_loss_x, stop_loss_y, "x", "red", 8, -0.1)
+ plot_markers(axs[0], timeout_x, timeout_y, "1", "yellow", 8, 0.0)
+
+ axs[1].set_ylabel("pnl")
+ pnl_series = history.get("pnl")
+ if pnl_series is not None and len(pnl_series) > 0:
+ axs[1].plot(
+ ticks,
+ pnl_series,
+ linewidth=1,
+ color="gray",
+ label="pnl",
+ )
+ previous_pnl_series = history.get("previous_pnl")
+ if previous_pnl_series is not None and len(previous_pnl_series) > 0:
+ axs[1].plot(
+ ticks,
+ previous_pnl_series,
+ linewidth=1,
+ color="deepskyblue",
+ label="previous_pnl",
+ )
+ if (pnl_series is not None and len(pnl_series) > 0) or (
+ previous_pnl_series is not None and len(previous_pnl_series) > 0
+ ):
+ axs[1].legend(loc="upper left", fontsize=8)
+ axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
+ axs[1].axhline(y=self.take_profit, label="tp", alpha=0.33, color="green")
+ axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
+
+ axs[2].set_ylabel("reward")
+ reward_series = history.get("reward")
+ if reward_series is not None and len(reward_series) > 0:
+ axs[2].plot(ticks, reward_series, linewidth=1, color="gray")
+ axs[2].axhline(y=0, label="0", alpha=0.33)
+
+ axs[3].set_ylabel("total_profit")
+ total_profit_series = history.get("total_profit")
+ if total_profit_series is not None and len(total_profit_series) > 0:
+ axs[3].plot(ticks, total_profit_series, linewidth=1, color="gray")
+ axs[3].axhline(y=1, label="1", alpha=0.33)
+
+ axs[4].set_ylabel("total_reward")
+ total_reward_series = history.get("total_reward")
+ if total_reward_series is not None and len(total_reward_series) > 0:
+ axs[4].plot(ticks, total_reward_series, linewidth=1, color="gray")
+ axs[4].axhline(y=0, label="0", alpha=0.33)
+ axs[4].set_xlabel("tick")
+
+ _borders = ["top", "right", "bottom", "left"]
+ for _ax in axs:
+ for _border in _borders:
+ _ax.spines[_border].set_color("#5b5e4b")
+
+ fig.suptitle(
+ f"Total Reward: {self.total_reward:.2f} ~ "
+ + f"Total Profit: {self._total_profit:.2f} ~ "
+ + f"Trades: {int(len(self.trade_history) // 2)}",
+ )
+ fig.tight_layout()
+ return fig
+
+ def close(self) -> None:
+ plt.close()
+ gc.collect()
+ if th.cuda.is_available():
+ th.cuda.empty_cache()
class InfoMetricsCallback(TensorboardCallback):
Tensorboard callback
"""
+ def __init__(self, *args, throttle: int = 1, **kwargs):
+ super().__init__(*args, **kwargs)
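+ # coerce throttle to at least 1 (1 = log on every step)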
+ self.throttle = max(1, throttle)
+
def _on_training_start(self) -> None:
- _lr = self.model.learning_rate
- _lr = (
- float(_lr) if isinstance(_lr, (int, float, np.floating)) else "lr_schedule"
- )
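+ # SB3 stores learning_rate either as a constant or as a schedule callable;
+ # probe the callable at progress_remaining 1.0/0.0 for its initial/final
+ # values (a callable is assumed to be this model's linear schedule)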
+ lr_schedule = "unknown"
+ lr_iv = np.nan
+ lr_fv = np.nan
+ lr = getattr(self.model, "learning_rate", None)
+ if callable(lr):
+ lr_schedule = "linear"
+ try:
+ lr_iv = lr(1.0)
+ except Exception:
+ lr_iv = np.nan
+ try:
+ lr_fv = lr(0.0)
+ except Exception:
+ lr_fv = np.nan
+ elif isinstance(lr, (int, float)):
+ lr_schedule = "constant"
+ lr_iv = float(lr)
+ lr_fv = float(lr)
n_stack = 1
env = getattr(self, "training_env", None)
while env is not None:
"algorithm": self.model.__class__.__name__,
"n_envs": int(self.model.n_envs),
"n_stack": n_stack,
- "learning_rate": _lr,
+ "lr_schedule": lr_schedule,
+ "learning_rate_iv": lr_iv,
+ "learning_rate_fv": lr_fv,
"gamma": float(self.model.gamma),
"batch_size": int(self.model.batch_size),
}
if "PPO" in self.model.__class__.__name__:
- _cr = self.model.clip_range
- _cr = (
- float(_cr)
- if isinstance(_cr, (int, float, np.floating))
- else "cr_schedule"
- )
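+ # same probing for PPO's clip_range, which may also be constant or a schedule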
+ cr_schedule = "unknown"
+ cr_iv = np.nan
+ cr_fv = np.nan
+ cr = getattr(self.model, "clip_range", None)
+ if callable(cr):
+ cr_schedule = "linear"
+ try:
+ cr_iv = cr(1.0)
+ except Exception:
+ cr_iv = np.nan
+ try:
+ cr_fv = cr(0.0)
+ except Exception:
+ cr_fv = np.nan
+ elif isinstance(cr, (int, float)):
+ cr_schedule = "constant"
+ cr_iv = float(cr)
+ cr_fv = float(cr)
hparam_dict.update(
{
- "clip_range": _cr,
+ "cr_schedule": cr_schedule,
+ "clip_range_iv": cr_iv,
+ "clip_range_fv": cr_fv,
"gae_lambda": float(self.model.gae_lambda),
"n_steps": int(self.model.n_steps),
"n_epochs": int(self.model.n_epochs),
"exploration_rate": float(self.model.exploration_rate),
}
)
+ train_freq = getattr(self.model, "train_freq", None)
+ train_freq_val: Optional[int] = None
+ try:
+ if isinstance(train_freq, int):
+ train_freq_val = train_freq
+ elif isinstance(train_freq, (tuple, list)) and train_freq:
+ if isinstance(train_freq[0], int):
+ train_freq_val = train_freq[0]
+ elif hasattr(train_freq, "freq"):
+ freq = getattr(train_freq, "freq")
+ if isinstance(freq, int):
+ train_freq_val = freq
+ except Exception:
+ train_freq_val = None
+ if train_freq_val is not None:
+ hparam_dict.update({"train_freq": train_freq_val})
if "QRDQN" in self.model.__class__.__name__:
hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)})
- metric_dict = {
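+ # pre-register metrics with neutral defaults so they appear in the
+ # TensorBoard HParams tab alongside the hyperparameters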
+ metric_dict: Dict[str, float] = {
+ "eval/mean_reward": 0.0,
+ "eval/mean_reward_std": 0.0,
+ "rollout/ep_rew_mean": 0.0,
+ "rollout/ep_len_mean": 0.0,
+ "train/n_updates": 0,
+ "train/progress_done": 0.0,
+ "train/progress_remaining": 0.0,
+ "train/learning_rate": 0.0,
"info/total_reward": 0.0,
"info/total_profit": 1.0,
"info/trade_count": 0,
"info/trade_duration": 0,
}
+ if "PPO" in self.model.__class__.__name__:
+ metric_dict.update(
+ {
+ "train/approx_kl": 0.0,
+ "train/entropy_loss": 0.0,
+ "train/policy_gradient_loss": 0.0,
+ "train/clip_fraction": 0.0,
+ "train/clip_range": 0.0,
+ "train/value_loss": 0.0,
+ "train/explained_variance": 0.0,
+ }
+ )
+ if "DQN" in self.model.__class__.__name__:
+ metric_dict.update(
+ {
+ "train/loss": 0.0,
+ "train/exploration_rate": 0.0,
+ }
+ )
self.logger.record(
"hparams",
HParam(hparam_dict, metric_dict),
)
def _on_step(self) -> bool:
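+ # honor the logging throttle: skip all recording on off-step timesteps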
+ if self.throttle > 1 and (self.num_timesteps % self.throttle) != 0:
+ return True
+
def _is_numeric_non_bool(x: Any) -> bool:
return isinstance(
x, (int, float, np.integer, np.floating)
for k, values in numeric_acc.items():
if not values:
continue
- mean = sum(values) / len(values)
+ mean = np.mean(values)
aggregated_info[k] = mean
if len(values) > 1:
try:
- aggregated_info[f"{k}_std"] = stdev(values)
+ aggregated_info[f"{k}_std"] = np.std(values)
except Exception:
pass
except Exception:
pass
- total_timesteps = getattr(self.model, "_total_timesteps", None)
- if total_timesteps is not None and not np.isclose(total_timesteps, 0.0):
- try:
+ try:
+ total_timesteps = getattr(self.model, "_total_timesteps", None)
+ if total_timesteps is not None and not np.isclose(total_timesteps, 0.0):
progress_done = float(self.num_timesteps) / float(total_timesteps)
progress_done = np.clip(progress_done, 0.0, 1.0)
- except Exception:
+ else:
progress_done = 0.0
- else:
- progress_done = 0.0
- progress_remaining = 1.0 - progress_done
-
- try:
+ progress_remaining = 1.0 - progress_done
self.logger.record("train/progress_done", progress_done)
self.logger.record("train/progress_remaining", progress_remaining)
+ except Exception:
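+ # conservative fallback: treat training as just started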
+ progress_remaining = 1.0
+
+ try:
+ lr = getattr(self.model, "learning_rate", None)
+ if callable(lr):
+ lr = lr(progress_remaining)
+ if isinstance(lr, (int, float)) and np.isfinite(lr):
+ self.logger.record("train/learning_rate", float(lr))
except Exception:
pass
- lr_schedule = getattr(self.model, "lr_schedule", None)
- if callable(lr_schedule):
+ if "PPO" in self.model.__class__.__name__:
try:
- lr = lr_schedule(progress_remaining)
- self.logger.record("train/learning_rate", lr)
+ cr = getattr(self.model, "clip_range", None)
+ if callable(cr):
+ cr = cr(progress_remaining)
+ if isinstance(cr, (int, float)) and np.isfinite(cr):
+ self.logger.record("train/clip_range", float(cr))
except Exception:
pass
- else:
- lr = getattr(self.model, "learning_rate", None)
- if isinstance(lr, (int, float, np.floating)):
- self.logger.record("train/learning_rate", lr)
return True
def __init__(
self,
- eval_env,
- trial,
+ eval_env: BaseEnvironment,
+ trial: Trial,
n_eval_episodes: int = 10,
eval_freq: int = 2048,
deterministic: bool = True,
+ render: bool = False,
use_masking: bool = True,
+ best_model_save_path: Optional[str] = None,
+ callback_on_new_best: Optional[BaseCallback] = None,
+ callback_after_eval: Optional[BaseCallback] = None,
verbose: int = 0,
**kwargs,
):
n_eval_episodes=n_eval_episodes,
eval_freq=eval_freq,
deterministic=deterministic,
- verbose=verbose,
+ render=render,
+ best_model_save_path=best_model_save_path,
use_masking=use_masking,
+ callback_on_new_best=callback_on_new_best,
+ callback_after_eval=callback_after_eval,
+ verbose=verbose,
**kwargs,
)
self.trial = trial
return False
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
self.eval_idx += 1
- last_mean_reward = getattr(self, "last_mean_reward", np.nan)
try:
- last_mean_reward = float(last_mean_reward)
+ last_mean_reward = float(getattr(self, "last_mean_reward", np.nan))
except Exception as e:
logger.warning(
"Optuna: invalid last_mean_reward at eval %s: %r", self.eval_idx, e
)
self.is_pruned = True
return False
+ try:
+ best_mean_reward = float(getattr(self, "best_mean_reward", np.nan))
+ except Exception as e:
+ logger.warning(
+ "Optuna: invalid best_mean_reward at eval %s: %r",
+ self.eval_idx,
+ e,
+ )
+ best_mean_reward = np.nan
+ if np.isfinite(best_mean_reward):
+ try:
+ self.logger.record("eval/best_mean_reward", best_mean_reward)
+ except Exception:
+ pass
+ else:
+ logger.warning(
+ "Optuna: non-finite best_mean_reward at eval %s", self.eval_idx
+ )
try:
if self.trial.should_prune():
logger.info(
return True
-def make_env(
- MyRLEnv: Type[BaseEnvironment],
- env_id: str,
- rank: int,
- seed: int,
- train_df: DataFrame,
- price: DataFrame,
- env_info: Dict[str, Any],
-) -> Callable[[], BaseEnvironment]:
- """
- Utility function for multiprocessed env.
-
- :param MyRLEnv: (Type[BaseEnvironment]) environment class to instantiate
- :param env_id: (str) the environment ID
- :param rank: (int) index of the subprocess
- :param seed: (int) the initial seed for RNG
- :param train_df: (DataFrame) feature dataframe for the environment
- :param price: (DataFrame) aligned price dataframe
- :param env_info: (dict) all required arguments to instantiate the environment
- :return:
- (Callable[[], BaseEnvironment]) closure that when called instantiates and returns the environment
- """
-
- def _init() -> BaseEnvironment:
- return MyRLEnv(
- df=train_df, prices=price, id=env_id, seed=seed + rank, **env_info
- )
-
- return _init
-
-
def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively merge two dicts without mutating inputs"""
dst_copy = copy.deepcopy(dst)
def get_activation_fn(
activation_fn_name: Literal["tanh", "relu", "elu", "leaky_relu"],
-) -> type[th.nn.Module]:
+) -> Type[th.nn.Module]:
"""
Get activation function
"""
def get_optimizer_class(
optimizer_class_name: Literal["adam", "adamw"],
-) -> type[th.optim.Optimizer]:
+) -> Type[th.optim.Optimizer]:
"""
Get optimizer class
"""