Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(reforcexy): make MyRLEnv a non nested class
author: Jérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 20 Sep 2025 16:35:34 +0000 (18:35 +0200)
committer: Jérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 20 Sep 2025 16:35:34 +0000 (18:35 +0200)
also fix another serious logic issue in the Optuna integration coming from
the original implementation.

Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py

index 86fdbef6583cc0bd4b60add9e59351941315d398..e3eb08835101763694b262016d5869df8953bf84 100644 (file)
@@ -8,7 +8,6 @@ from collections import defaultdict
 from collections.abc import Mapping
 from enum import IntEnum
 from pathlib import Path
-from statistics import stdev
 from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union
 
 import matplotlib
@@ -40,6 +39,7 @@ from optuna.storages.journal import JournalFileBackend
 from optuna.study import Study, StudyDirection
 from pandas import DataFrame, concat, merge
 from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
+from sb3_contrib.common.maskable.utils import is_masking_supported
 from stable_baselines3.common.callbacks import (
     BaseCallback,
     ProgressBarCallback,
@@ -110,8 +110,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         - pip install optuna-dashboard
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self.pairs: list[str] = self.config.get("exchange", {}).get("pair_whitelist")
         if not self.pairs:
             raise ValueError(
@@ -120,7 +120,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.action_masking: bool = (
             self.model_type == "MaskablePPO"
         )  # Enable action masking
-        self.rl_config["action_masking"] = self.action_masking
+        self.rl_config.setdefault("action_masking", self.action_masking)
         self.inference_masking: bool = self.rl_config.get("inference_masking", True)
         self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
         self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
@@ -131,6 +131,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             "max_no_improvement_evals", 0
         )
         self.min_evals: int = self.rl_config.get("min_evals", 0)
+        self.rl_config.setdefault("tensorboard_throttle", 1)
         self.plot_new_best: bool = self.rl_config.get("plot_new_best", False)
         self.check_envs: bool = self.rl_config.get("check_envs", True)
         self.progressbar_callback: Optional[ProgressBarCallback] = None
@@ -148,7 +149,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.optuna_n_startup_trials: int = self.rl_config_optuna.get(
             "n_startup_trials", 15
         )
-        self.optuna_callback: Optional[MaskableTrialEvalCallback] = None
+        self.optuna_eval_callback: Optional[MaskableTrialEvalCallback] = None
         self._model_params_cache: Optional[Dict[str, Any]] = None
         self.unset_unsupported()
 
@@ -251,7 +252,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         if self.check_envs:
             logger.info("Checking environments")
-            _train_env_check = self.MyRLEnv(
+            _train_env_check = MyRLEnv(
                 df=train_df,
                 prices=prices_train,
                 id="train_env_check",
@@ -262,7 +263,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 check_env(_train_env_check)
             finally:
                 _train_env_check.close()
-            _eval_env_check = self.MyRLEnv(
+            _eval_env_check = MyRLEnv(
                 df=test_df,
                 prices=prices_test,
                 id="eval_env_check",
@@ -277,7 +278,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         logger.info("Populating environments: %s", self.n_envs)
         train_fns = [
             make_env(
-                self.MyRLEnv,
+                MyRLEnv,
                 f"train_env{i}",
                 i,
                 seed,
@@ -289,7 +290,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         ]
         eval_fns = [
             make_env(
-                self.MyRLEnv,
+                MyRLEnv,
                 f"eval_env{i}",
                 i,
                 seed + 10_000,
@@ -456,7 +457,11 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
 
         if self.activate_tensorboard:
-            info_callback = InfoMetricsCallback(actions=Actions, verbose=verbose)
+            info_callback = InfoMetricsCallback(
+                actions=Actions,
+                verbose=verbose,
+                throttle=self.rl_config.get("tensorboard_throttle", 1),
+            )
             callbacks.append(info_callback)
             if self.plot_new_best:
                 rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
@@ -465,6 +470,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             self.progressbar_callback = ProgressBarCallback()
             callbacks.append(self.progressbar_callback)
 
+        use_masking = self.action_masking and is_masking_supported(self.eval_env)
         if not trial:
             self.eval_callback = MaskableEvalCallback(
                 self.eval_env,
@@ -472,7 +478,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 deterministic=True,
                 render=False,
                 best_model_save_path=data_path,
-                use_masking=self.action_masking,
+                use_masking=use_masking,
                 callback_on_new_best=rollout_plot_callback,
                 callback_after_eval=no_improvement_callback,
                 verbose=verbose,
@@ -480,17 +486,17 @@ class ReforceXY(BaseReinforcementLearningModel):
             callbacks.append(self.eval_callback)
         else:
             trial_data_path = f"{data_path}/hyperopt/trial_{trial.number}"
-            self.optuna_callback = MaskableTrialEvalCallback(
+            self.optuna_eval_callback = MaskableTrialEvalCallback(
                 self.eval_env,
                 trial,
                 eval_freq=eval_freq,
                 deterministic=True,
                 render=False,
                 best_model_save_path=trial_data_path,
-                use_masking=self.action_masking,
+                use_masking=use_masking,
                 verbose=verbose,
             )
-            callbacks.append(self.optuna_callback)
+            callbacks.append(self.optuna_eval_callback)
         return callbacks
 
     def fit(
@@ -659,7 +665,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             if self.action_masking and self.inference_masking:
                 action_masks_param["action_masks"] = ReforceXY.get_action_masks(
-                    simulated_position, None
+                    simulated_position
                 )
 
             action, _ = model.predict(
@@ -970,10 +976,10 @@ class ReforceXY(BaseReinforcementLearningModel):
         if nan_encountered:
             raise TrialPruned("NaN encountered during training")
 
-        if self.optuna_callback.is_pruned:
+        if self.optuna_eval_callback.is_pruned:
             raise TrialPruned()
 
-        return self.optuna_callback.best_mean_reward
+        return self.optuna_eval_callback.last_mean_reward
 
     def close_envs(self) -> None:
         """
@@ -990,727 +996,736 @@ class ReforceXY(BaseReinforcementLearningModel):
             finally:
                 self.eval_env = None
 
-    class MyRLEnv(Base5ActionRLEnv):
+
+def make_env(
+    MyRLEnv: Type[BaseEnvironment],
+    env_id: str,
+    rank: int,
+    seed: int,
+    train_df: DataFrame,
+    price: DataFrame,
+    env_info: Dict[str, Any],
+) -> Callable[[], BaseEnvironment]:
+    """
+    Utility function for multiprocessed env.
+
+    :param MyRLEnv: (Type[BaseEnvironment]) environment class to instantiate
+    :param env_id: (str) the environment ID
+    :param rank: (int) index of the subprocess
+    :param seed: (int) the initial seed for RNG
+    :param train_df: (DataFrame) feature dataframe for the environment
+    :param price: (DataFrame) aligned price dataframe
+    :param env_info: (dict) all required arguments to instantiate the environment
+    :return:
+    (Callable[[], BaseEnvironment]) closure that when called instantiates and returns the environment
+    """
+
+    def _init() -> BaseEnvironment:
+        return MyRLEnv(
+            df=train_df, prices=price, id=env_id, seed=seed + rank, **env_info
+        )
+
+    return _init
+
+
+MyRLEnv: Type[BaseEnvironment]
+
+
+class MyRLEnv(Base5ActionRLEnv):
+    """
+    Env
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._set_observation_space()
+        self.action_masking: bool = self.rl_config.get("action_masking", False)
+        self.force_actions: bool = self.rl_config.get("force_actions", False)
+        self._force_action: Optional[ForceActions] = None
+        self.take_profit: float = self.config.get("minimal_roi", {}).get("0", 0.03)
+        self.stop_loss: float = self.config.get("stoploss", -0.02)
+        self.timeout: int = self.rl_config.get("max_trade_duration_candles", 128)
+        self._last_closed_position: Optional[Positions] = None
+        self._last_closed_trade_tick: int = 0
+        if self.force_actions:
+            logger.info(
+                "%s - take_profit: %s, stop_loss: %s, timeout: %s candles (%s days), observation_space: %s",
+                self.id,
+                self.take_profit,
+                self.stop_loss,
+                self.timeout,
+                steps_to_days(self.timeout, self.config.get("timeframe")),
+                self.observation_space,
+            )
+
+    def _set_observation_space(self) -> None:
         """
-        Env
+        Set the observation space
         """
+        signal_features = self.signal_features.shape[1]
+        if self.add_state_info:
+            # STATE_INFO
+            self.state_features = ["pnl", "position", "trade_duration"]
+            self.total_features = signal_features + len(self.state_features)
+        else:
+            self.state_features = []
+            self.total_features = signal_features
 
-        def __init__(self, **kwargs):
-            super().__init__(**kwargs)
-            self._set_observation_space()
-            self.action_masking: bool = self.rl_config.get("action_masking", False)
-            self.force_actions: bool = self.rl_config.get("force_actions", False)
-            self._force_action: Optional[ForceActions] = None
-            self.take_profit: float = self.config.get("minimal_roi", {}).get("0", 0.03)
-            self.stop_loss: float = self.config.get("stoploss", -0.02)
-            self.timeout: int = self.rl_config.get("max_trade_duration_candles", 128)
-            self._last_closed_position: Optional[Positions] = None
-            self._last_closed_trade_tick: int = 0
-            if self.force_actions:
-                logger.info(
-                    "%s - take_profit: %s, stop_loss: %s, timeout: %s candles (%s days), observation_space: %s",
-                    self.id,
-                    self.take_profit,
-                    self.stop_loss,
-                    self.timeout,
-                    steps_to_days(self.timeout, self.config.get("timeframe")),
-                    self.observation_space,
-                )
+        self.shape = (self.window_size, self.total_features)
+        self.observation_space = Box(
+            low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
+        )
 
-        def _set_observation_space(self) -> None:
-            """
-            Set the observation space
-            """
-            signal_features = self.signal_features.shape[1]
-            if self.add_state_info:
-                # STATE_INFO
-                self.state_features = ["pnl", "position", "trade_duration"]
-                self.total_features = signal_features + len(self.state_features)
-            else:
-                self.state_features = []
-                self.total_features = signal_features
+    def _is_valid(self, action: int) -> bool:
+        return ReforceXY.get_action_masks(self._position, self._force_action)[action]
 
-            self.shape = (self.window_size, self.total_features)
-            self.observation_space = Box(
-                low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
-            )
+    def reset_env(
+        self,
+        df: DataFrame,
+        prices: DataFrame,
+        window_size: int,
+        reward_kwargs: Dict[str, Any],
+        starting_point=True,
+    ) -> None:
+        """
+        Resets the environment when the agent fails
+        """
+        super().reset_env(df, prices, window_size, reward_kwargs, starting_point)
+        self._set_observation_space()
 
-        def _is_valid(self, action: int) -> bool:
-            return ReforceXY.get_action_masks(self._position, self._force_action)[
-                action
-            ]
-
-        def reset_env(
-            self,
-            df: DataFrame,
-            prices: DataFrame,
-            window_size: int,
-            reward_kwargs: Dict[str, Any],
-            starting_point=True,
-        ) -> None:
-            """
-            Resets the environment when the agent fails
-            """
-            super().reset_env(df, prices, window_size, reward_kwargs, starting_point)
-            self._set_observation_space()
-
-        def reset(
-            self, seed=None, **kwargs
-        ) -> Tuple[NDArray[np.float32], Dict[str, Any]]:
-            """
-            Reset is called at the beginning of every episode
-            """
-            observation, history = super().reset(seed, **kwargs)
-            self._force_action: Optional[ForceActions] = None
-            self._last_closed_position: Optional[Positions] = None
-            self._last_closed_trade_tick: int = 0
-            return observation, history
-
-        def _get_exit_reward_factor(
-            self,
-            factor: float,
-            pnl: float,
-            trade_duration: int,
-            max_trade_duration: int,
-        ) -> float:
-            """
-            Compute the reward factor at trade exit
-            """
-            if trade_duration <= max_trade_duration:
-                factor *= 1.5
-            elif trade_duration > max_trade_duration:
-                factor *= 0.5
-            if pnl > self.profit_aim * self.rr:
-                factor *= float(
-                    self.rl_config.get("model_reward_parameters", {}).get(
-                        "win_reward_factor", 2.0
-                    )
-                )
-            return factor
-
-        def calculate_reward(self, action: int) -> float:
-            """
-            An example reward function. This is the one function that users will likely
-            wish to inject their own creativity into.
-
-            Warning!
-            This is function is a showcase of functionality designed to show as many possible
-            environment control features as possible. It is also designed to run quickly
-            on small computers. This is a benchmark, it is *not* for live production.
-
-            :param action: int = The action made by the agent for the current candle.
-            :return:
-            float = the reward to give to the agent for current step (used for optimization
-                    of weights in NN)
-            """
-            # first, penalize if the action is not valid
-            if not self.action_masking and not self._is_valid(action):
-                self.tensorboard_log("invalid", category="actions")
-                return self.rl_config.get("model_reward_parameters", {}).get(
-                    "invalid_action", -2.0
+    def reset(self, seed=None, **kwargs) -> Tuple[NDArray[np.float32], Dict[str, Any]]:
+        """
+        Reset is called at the beginning of every episode
+        """
+        observation, history = super().reset(seed, **kwargs)
+        self._force_action: Optional[ForceActions] = None
+        self._last_closed_position: Optional[Positions] = None
+        self._last_closed_trade_tick: int = 0
+        return observation, history
+
+    def _get_exit_reward_factor(
+        self,
+        factor: float,
+        pnl: float,
+        trade_duration: int,
+        max_trade_duration: int,
+    ) -> float:
+        """
+        Compute the reward factor at trade exit
+        """
+        if trade_duration <= max_trade_duration:
+            factor *= 1.5
+        elif trade_duration > max_trade_duration:
+            factor *= 0.5
+        if pnl > self.profit_aim * self.rr:
+            factor *= float(
+                self.rl_config.get("model_reward_parameters", {}).get(
+                    "win_reward_factor", 2.0
                 )
+            )
+        return factor
 
-            pnl = self.get_unrealized_profit()
-            # mrr = self.get_most_recent_return()
-            # mrp = self.get_most_recent_profit()
+    def calculate_reward(self, action: int) -> float:
+        """
+        An example reward function. This is the one function that users will likely
+        wish to inject their own creativity into.
 
-            max_trade_duration = max(1, self.timeout)
-            trade_duration = self.get_trade_duration()
+        Warning!
+        This is function is a showcase of functionality designed to show as many possible
+        environment control features as possible. It is also designed to run quickly
+        on small computers. This is a benchmark, it is *not* for live production.
 
-            factor = 100.0
+        :param action: int = The action made by the agent for the current candle.
+        :return:
+        float = the reward to give to the agent for current step (used for optimization
+                of weights in NN)
+        """
+        # first, penalize if the action is not valid
+        if not self.action_masking and not self._is_valid(action):
+            self.tensorboard_log("invalid", category="actions")
+            return self.rl_config.get("model_reward_parameters", {}).get(
+                "invalid_action", -2.0
+            )
 
-            # Force exits
-            if self._force_action in (
-                ForceActions.Take_profit,
-                ForceActions.Stop_loss,
-                ForceActions.Timeout,
-            ):
-                return pnl * self._get_exit_reward_factor(
-                    factor, pnl, trade_duration, max_trade_duration
-                )
+        pnl = self.get_unrealized_profit()
+        # mrr = self.get_most_recent_return()
+        # mrp = self.get_most_recent_profit()
 
-            # # you can use feature values from dataframe
-            # rsi_now = self.get_feature_value(
-            #     name="%-rsi",
-            #     period=8,
-            #     pair=self.pair,
-            #     timeframe=self.config.get("timeframe"),
-            #     raw=True,
-            # )
-
-            # # reward agent for entering trades when RSI is low
-            # if (
-            #     action in (Actions.Long_enter.value, Actions.Short_enter.value)
-            #     and self._position == Positions.Neutral
-            # ):
-            #     if rsi_now < 40:
-            #         factor = 40 / rsi_now
-            #     else:
-            #         factor = 1
-            #     return 25.0 * factor
-
-            # discourage agent from sitting idle too long
-            if action == Actions.Neutral.value and self._position == Positions.Neutral:
-                return -0.01 * self.get_idle_duration() ** 1.05
-
-            # pnl and duration aware agent reward while sitting in position
-            if (
-                self._position in (Positions.Short, Positions.Long)
-                and action == Actions.Neutral.value
-            ):
-                duration_fraction = trade_duration / max_trade_duration
-                max_pnl = max(self.get_most_recent_max_pnl(), pnl)
+        max_trade_duration = max(1, self.timeout)
+        trade_duration = self.get_trade_duration()
 
-                if max_pnl > 0:
-                    drawdown_penalty = (
-                        0.0025 * factor * (max_pnl - pnl) * duration_fraction
-                    )
-                else:
-                    drawdown_penalty = 0.0
+        factor = 100.0
 
-                lambda1 = 0.05
-                lambda2 = 0.1
-                if pnl >= 0:
-                    if duration_fraction <= 1.0:
-                        duration_penalty_factor = 1.0
-                    else:
-                        duration_penalty_factor = 1.0 / (
-                            1.0 + lambda1 * (duration_fraction - 1.0)
-                        )
-                    return (
-                        factor * pnl * duration_penalty_factor
-                        - lambda2 * duration_fraction
-                        - drawdown_penalty
-                    )
+        # Force exits
+        if self._force_action in (
+            ForceActions.Take_profit,
+            ForceActions.Stop_loss,
+            ForceActions.Timeout,
+        ):
+            return pnl * self._get_exit_reward_factor(
+                factor, pnl, trade_duration, max_trade_duration
+            )
+
+        # # you can use feature values from dataframe
+        # rsi_now = self.get_feature_value(
+        #     name="%-rsi",
+        #     period=8,
+        #     pair=self.pair,
+        #     timeframe=self.config.get("timeframe"),
+        #     raw=True,
+        # )
+
+        # # reward agent for entering trades when RSI is low
+        # if (
+        #     action in (Actions.Long_enter.value, Actions.Short_enter.value)
+        #     and self._position == Positions.Neutral
+        # ):
+        #     if rsi_now < 40:
+        #         factor = 40 / rsi_now
+        #     else:
+        #         factor = 1
+        #     return 25.0 * factor
+
+        # discourage agent from sitting idle too long
+        if action == Actions.Neutral.value and self._position == Positions.Neutral:
+            return -0.01 * self.get_idle_duration() ** 1.05
+
+        # pnl and duration aware agent reward while sitting in position
+        if (
+            self._position in (Positions.Short, Positions.Long)
+            and action == Actions.Neutral.value
+        ):
+            duration_fraction = trade_duration / max_trade_duration
+            max_pnl = max(self.get_most_recent_max_pnl(), pnl)
+
+            if max_pnl > 0:
+                drawdown_penalty = 0.0025 * factor * (max_pnl - pnl) * duration_fraction
+            else:
+                drawdown_penalty = 0.0
+
+            lambda1 = 0.05
+            lambda2 = 0.1
+            if pnl >= 0:
+                if duration_fraction <= 1.0:
+                    duration_penalty_factor = 1.0
                 else:
-                    return (
-                        factor * pnl * (1.0 + lambda1 * duration_fraction)
-                        - 2.0 * lambda2 * duration_fraction
-                        - drawdown_penalty
+                    duration_penalty_factor = 1.0 / (
+                        1.0 + lambda1 * (duration_fraction - 1.0)
                     )
-
-            # close long
-            if action == Actions.Long_exit.value and self._position == Positions.Long:
-                return pnl * self._get_exit_reward_factor(
-                    factor, pnl, trade_duration, max_trade_duration
+                return (
+                    factor * pnl * duration_penalty_factor
+                    - lambda2 * duration_fraction
+                    - drawdown_penalty
                 )
-
-            # close short
-            if action == Actions.Short_exit.value and self._position == Positions.Short:
-                return pnl * self._get_exit_reward_factor(
-                    factor, pnl, trade_duration, max_trade_duration
+            else:
+                return (
+                    factor * pnl * (1.0 + lambda1 * duration_fraction)
+                    - 2.0 * lambda2 * duration_fraction
+                    - drawdown_penalty
                 )
 
-            return 0.0
+        # close long
+        if action == Actions.Long_exit.value and self._position == Positions.Long:
+            return pnl * self._get_exit_reward_factor(
+                factor, pnl, trade_duration, max_trade_duration
+            )
 
-        def _get_observation(self) -> NDArray[np.float32]:
-            """
-            This may or may not be independent of action types, user can inherit
-            this in their custom "MyRLEnv"
-            """
-            features_window = self.signal_features[
-                (self._current_tick - self.window_size) : self._current_tick
-            ]
-            if self.add_state_info:
-                features_and_state = DataFrame(
-                    np.zeros(
-                        (len(features_window), len(self.state_features)),
-                        dtype=np.float32,
-                    ),
-                    columns=self.state_features,
-                    index=features_window.index,
-                )
-                # STATE_INFO
-                features_and_state["pnl"] = self.get_unrealized_profit()
-                features_and_state["position"] = self._position.value
-                features_and_state["trade_duration"] = self.get_trade_duration()
+        # close short
+        if action == Actions.Short_exit.value and self._position == Positions.Short:
+            return pnl * self._get_exit_reward_factor(
+                factor, pnl, trade_duration, max_trade_duration
+            )
 
-                features_and_state = concat(
-                    [features_window, features_and_state], axis=1
-                )
-                return features_and_state.to_numpy(dtype=np.float32)
-            else:
-                return features_window.to_numpy(dtype=np.float32)
+        return 0.0
 
-        def _get_force_action(self) -> Optional[ForceActions]:
-            if not self.force_actions or self._position == Positions.Neutral:
-                return None
+    def _get_observation(self) -> NDArray[np.float32]:
+        """
+        This may or may not be independent of action types, user can inherit
+        this in their custom "MyRLEnv"
+        """
+        features_window = self.signal_features[
+            (self._current_tick - self.window_size) : self._current_tick
+        ]
+        if self.add_state_info:
+            features_and_state = DataFrame(
+                np.zeros(
+                    (len(features_window), len(self.state_features)),
+                    dtype=np.float32,
+                ),
+                columns=self.state_features,
+                index=features_window.index,
+            )
+            # STATE_INFO
+            features_and_state["pnl"] = self.get_unrealized_profit()
+            features_and_state["position"] = self._position.value
+            features_and_state["trade_duration"] = self.get_trade_duration()
 
-            trade_duration = self.get_trade_duration()
-            if trade_duration <= 1:
-                return None
-            if trade_duration >= self.timeout:
-                return ForceActions.Timeout
-
-            pnl = self.get_unrealized_profit()
-            if pnl >= self.take_profit:
-                return ForceActions.Take_profit
-            if pnl <= self.stop_loss:
-                return ForceActions.Stop_loss
+            features_and_state = concat([features_window, features_and_state], axis=1)
+            return features_and_state.to_numpy(dtype=np.float32)
+        else:
+            return features_window.to_numpy(dtype=np.float32)
+
+    def _get_force_action(self) -> Optional[ForceActions]:
+        if not self.force_actions or self._position == Positions.Neutral:
             return None
 
-        def _get_position(self, action: int) -> Positions:
-            return {
-                Actions.Long_enter.value: Positions.Long,
-                Actions.Short_enter.value: Positions.Short,
-            }[action]
-
-        def _enter_trade(self, action: int) -> None:
-            self._position = self._get_position(action)
-            self._last_trade_tick = self._current_tick
-
-        def _exit_trade(self) -> None:
-            self._update_total_profit()
-            self._last_closed_position = self._position
-            self._position = Positions.Neutral
-            self._last_closed_trade_tick = self._current_tick
-            self._last_trade_tick = None
-
-        def execute_trade(self, action: int) -> Optional[str]:
-            """
-            Execute trade based on the given action
-            """
-            # Force exit trade
-            if self._force_action:
-                self._exit_trade()
-                self.tensorboard_log(
-                    f"{self._force_action.name}", category="actions/force"
-                )
-                return f"{self._force_action.name}"
+        trade_duration = self.get_trade_duration()
+        if trade_duration <= 1:
+            return None
+        if trade_duration >= self.timeout:
+            return ForceActions.Timeout
+
+        pnl = self.get_unrealized_profit()
+        if pnl >= self.take_profit:
+            return ForceActions.Take_profit
+        if pnl <= self.stop_loss:
+            return ForceActions.Stop_loss
+        return None
 
-            if not self.is_tradesignal(action):
-                return None
+    def _get_position(self, action: int) -> Positions:
+        return {
+            Actions.Long_enter.value: Positions.Long,
+            Actions.Short_enter.value: Positions.Short,
+        }[action]
+
+    def _enter_trade(self, action: int) -> None:
+        self._position = self._get_position(action)
+        self._last_trade_tick = self._current_tick
+
+    def _exit_trade(self) -> None:
+        self._update_total_profit()
+        self._last_closed_position = self._position
+        self._position = Positions.Neutral
+        self._last_closed_trade_tick = self._current_tick
+        self._last_trade_tick = None
+
+    def execute_trade(self, action: int) -> Optional[str]:
+        """
+        Execute trade based on the given action
+        """
+        # Force exit trade
+        if self._force_action:
+            self._exit_trade()
+            self.tensorboard_log(f"{self._force_action.name}", category="actions/force")
+            return f"{self._force_action.name}"
 
-            # Enter trade based on action
-            if action in (Actions.Long_enter.value, Actions.Short_enter.value):
-                self._enter_trade(action)
-                return f"{self._position.name}_enter"
+        if not self.is_tradesignal(action):
+            return None
 
-            # Exit trade based on action
-            if action in (Actions.Long_exit.value, Actions.Short_exit.value):
-                self._exit_trade()
-                return f"{self._last_closed_position.name}_exit"
+        # Enter trade based on action
+        if action in (Actions.Long_enter.value, Actions.Short_enter.value):
+            self._enter_trade(action)
+            return f"{self._position.name}_enter"
 
-            return None
+        # Exit trade based on action
+        if action in (Actions.Long_exit.value, Actions.Short_exit.value):
+            self._exit_trade()
+            return f"{self._last_closed_position.name}_exit"
 
-        def step(
-            self, action: int
-        ) -> Tuple[NDArray[np.float32], float, bool, bool, Dict[str, Any]]:
-            """
-            Take a step in the environment based on the provided action
-            """
-            self._current_tick += 1
-            self._update_unrealized_total_profit()
-            previous_pnl = self.get_unrealized_profit()
-            self._update_portfolio_log_returns()
-            self._force_action = self._get_force_action()
-            reward = self.calculate_reward(action)
-            self.total_reward += reward
-            self.tensorboard_log(Actions._member_names_[action], category="actions")
-            trade_type = self.execute_trade(action)
-            if trade_type is not None:
-                self.append_trade_history(
-                    trade_type, self.current_price(), previous_pnl
-                )
-            self._position_history.append(self._position)
-            info = {
+        return None
+
+    def step(
+        self, action: int
+    ) -> Tuple[NDArray[np.float32], float, bool, bool, Dict[str, Any]]:
+        """
+        Take a step in the environment based on the provided action
+
+        :param action: Action value; also used to index Actions._member_names_
+            for tensorboard logging
+        :return: Tuple of (observation, reward, terminated, truncated, info)
+        """
+        self._current_tick += 1
+        self._update_unrealized_total_profit()
+        # Unrealized profit captured before the action is executed
+        previous_pnl = self.get_unrealized_profit()
+        self._update_portfolio_log_returns()
+        # When set, execute_trade() exits the trade regardless of `action`
+        self._force_action = self._get_force_action()
+        # The reward is computed before the trade is executed
+        reward = self.calculate_reward(action)
+        self.total_reward += reward
+        self.tensorboard_log(Actions._member_names_[action], category="actions")
+        trade_type = self.execute_trade(action)
+        # Only executed trades (non-None trade type) are recorded in trade_history
+        if trade_type is not None:
+            self.append_trade_history(trade_type, self.current_price(), previous_pnl)
+        self._position_history.append(self._position)
+        info = {
+            "tick": self._current_tick,
+            "position": self._position.value,
+            "action": action,
+            "force_action": (self._force_action.name if self._force_action else None),
+            "previous_pnl": round(previous_pnl, 5),
+            # Read again after execute_trade(), unlike previous_pnl above
+            "pnl": round(self.get_unrealized_profit(), 5),
+            "reward": round(reward, 5),
+            "total_reward": round(self.total_reward, 5),
+            "total_profit": round(self._total_profit, 5),
+            "idle_duration": self.get_idle_duration(),
+            "trade_duration": self.get_trade_duration(),
+            # NOTE(review): // 2 presumably assumes one entry and one exit record
+            # per trade in trade_history - confirm
+            "trade_count": int(len(self.trade_history) // 2),
+        }
+        self._update_history(info)
+        return (
+            self._get_observation(),
+            reward,
+            self.is_terminated(),
+            self.is_truncated(),
+            info,
+        )
+
+    def append_trade_history(
+        self, trade_type: str, price: float, profit: float
+    ) -> None:
+        self.trade_history.append(
+            {
                 "tick": self._current_tick,
-                "position": self._position.value,
-                "action": action,
-                "force_action": (
-                    self._force_action.name if self._force_action else None
-                ),
-                "previous_pnl": round(previous_pnl, 5),
-                "pnl": round(self.get_unrealized_profit(), 5),
-                "reward": round(reward, 5),
-                "total_reward": round(self.total_reward, 5),
-                "total_profit": round(self._total_profit, 5),
-                "idle_duration": self.get_idle_duration(),
-                "trade_duration": self.get_trade_duration(),
-                "trade_count": int(len(self.trade_history) // 2),
+                "type": trade_type.lower(),
+                "price": price,
+                "profit": profit,
             }
-            self._update_history(info)
-            return (
-                self._get_observation(),
-                reward,
-                self.is_terminated(),
-                self.is_truncated(),
-                info,
-            )
-
-        def append_trade_history(
-            self, trade_type: str, price: float, profit: float
-        ) -> None:
-            self.trade_history.append(
-                {
-                    "tick": self._current_tick,
-                    "type": trade_type.lower(),
-                    "price": price,
-                    "profit": profit,
-                }
-            )
+        )
 
-        def is_terminated(self) -> bool:
-            return bool(
-                self._current_tick == self._end_tick
-                or self._total_profit <= self.max_drawdown
-                or self._total_unrealized_profit <= self.max_drawdown
-            )
+    def is_terminated(self) -> bool:
+        """
+        Determine if the episode is terminated: the current tick equals the end tick,
+        or the total profit or the total unrealized profit is at or below max_drawdown
+        """
+        return bool(
+            self._current_tick == self._end_tick
+            or self._total_profit <= self.max_drawdown
+            or self._total_unrealized_profit <= self.max_drawdown
+        )
 
-        def is_truncated(self) -> bool:
-            return False
+    def is_truncated(self) -> bool:
+        """
+        Always False: this environment never reports truncation
+        """
+        return False
 
-        def is_tradesignal(self, action: int) -> bool:
-            """
-            Determine if the action is entry or exit
-            """
-            return (
-                (
-                    action in (Actions.Short_enter.value, Actions.Long_enter.value)
-                    and self._position == Positions.Neutral
-                )
-                or (
-                    action == Actions.Long_exit.value
-                    and self._position == Positions.Long
-                )
-                or (
-                    action == Actions.Short_exit.value
-                    and self._position == Positions.Short
-                )
+    def is_tradesignal(self, action: int) -> bool:
+        """
+        Determine if the action is entry or exit
+        """
+        return (
+            (
+                action in (Actions.Short_enter.value, Actions.Long_enter.value)
+                and self._position == Positions.Neutral
             )
-
-        def action_masks(self) -> NDArray[bool]:
-            return ReforceXY.get_action_masks(self._position, self._force_action)
-
-        def get_feature_value(
-            self,
-            name: str,
-            period: int = 0,
-            shift: int = 0,
-            pair: str = "",
-            timeframe: str = "",
-            raw: bool = True,
-        ) -> float:
-            """
-            Get feature value
-            """
-            feature_parts = [name]
-            if period:
-                feature_parts.append(f"_{period}")
-            if shift:
-                feature_parts.append(f"_shift-{shift}")
-            if pair:
-                feature_parts.append(f"_{pair.replace(':', '')}")
-            if timeframe:
-                feature_parts.append(f"_{timeframe}")
-            feature_col = "".join(feature_parts)
-
-            if not raw:
-                return self.signal_features[feature_col].iloc[self._current_tick]
-            return self.raw_features[feature_col].iloc[self._current_tick]
-
-        def get_idle_duration(self) -> int:
-            if self._position != Positions.Neutral:
-                return 0
-            if not self._last_closed_trade_tick:
-                return self._current_tick - self._start_tick
-            return self._current_tick - self._last_closed_trade_tick
-
-        def get_most_recent_max_pnl(self) -> float:
-            """
-            Calculate the most recent maximum unrealized profit if in a trade
-            """
-            if self._last_trade_tick is None:
-                return 0.0
-            if self._position == Positions.Neutral:
-                return 0.0
-            pnl_history = self.history.get("pnl")
-            if not pnl_history or len(pnl_history) == 0:
-                return 0.0
-
-            pnl_history = np.asarray(pnl_history)
-            ticks = self.history.get("tick")
-            if not ticks:
-                return 0.0
-            ticks = np.asarray(ticks)
-            start = np.searchsorted(ticks, self._last_trade_tick, side="left")
-            trade_pnl_history = pnl_history[start:]
-            if trade_pnl_history.size == 0:
-                return 0.0
-            return np.max(trade_pnl_history)
-
-        def get_most_recent_return(self) -> float:
-            """
-            Calculate the tick to tick return if in a trade.
-            Return is generated from rising prices in Long and falling prices in Short positions.
-            The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
-            """
-            if self._last_trade_tick is None:
-                return 0.0
-            if self._position == Positions.Neutral:
-                return 0.0
-            elif self._position == Positions.Long:
-                current_price = self.current_price()
-                previous_price = self.previous_price()
-                if (
-                    self._position_history[self._current_tick - 1] == Positions.Short
-                    or self._position_history[self._current_tick - 1]
-                    == Positions.Neutral
-                ):
-                    previous_price = self.add_entry_fee(previous_price)
-                return np.log(current_price) - np.log(previous_price)
-            elif self._position == Positions.Short:
-                current_price = self.current_price()
-                previous_price = self.previous_price()
-                if (
-                    self._position_history[self._current_tick - 1] == Positions.Long
-                    or self._position_history[self._current_tick - 1]
-                    == Positions.Neutral
-                ):
-                    previous_price = self.add_exit_fee(previous_price)
-                return np.log(previous_price) - np.log(current_price)
-            return 0.0
-
-        def _update_portfolio_log_returns(self):
-            self.portfolio_log_returns[self._current_tick] = (
-                self.get_most_recent_return()
+            or (action == Actions.Long_exit.value and self._position == Positions.Long)
+            or (
+                action == Actions.Short_exit.value and self._position == Positions.Short
             )
+        )
 
-        def get_most_recent_profit(self) -> float:
-            """
-            Calculate the tick to tick unrealized profit if in a trade
-            """
-            if self._last_trade_tick is None:
-                return 0.0
-            if self._position == Positions.Neutral:
-                return 0.0
-            elif self._position == Positions.Long:
-                current_price = self.add_exit_fee(self.current_price())
-                previous_price = self.add_entry_fee(self.previous_price())
-                return (current_price - previous_price) / previous_price
-            elif self._position == Positions.Short:
-                current_price = self.add_entry_fee(self.current_price())
-                previous_price = self.add_exit_fee(self.previous_price())
-                return (previous_price - current_price) / previous_price
+    def action_masks(self) -> NDArray[bool]:
+        """
+        Return the action masks computed by ReforceXY.get_action_masks() from the
+        current position and the current force action
+        """
+        return ReforceXY.get_action_masks(self._position, self._force_action)
+
+    def get_feature_value(
+        self,
+        name: str,
+        period: int = 0,
+        shift: int = 0,
+        pair: str = "",
+        timeframe: str = "",
+        raw: bool = True,
+    ) -> float:
+        """
+        Get feature value at the current tick
+
+        The column name is built as
+        "{name}_{period}_shift-{shift}_{pair}_{timeframe}", where every falsy part
+        (0 or empty string) is left out together with its separator.
+
+        :param name: Leading part of the feature column name
+        :param period: Optional period suffix
+        :param shift: Optional shift suffix
+        :param pair: Optional pair suffix; ":" characters are removed
+        :param timeframe: Optional timeframe suffix
+        :param raw: Read from raw_features when True, from signal_features otherwise
+        :return: Value of the resolved column at the current tick
+        """
+        feature_parts = [name]
+        if period:
+            feature_parts.append(f"_{period}")
+        if shift:
+            feature_parts.append(f"_shift-{shift}")
+        if pair:
+            feature_parts.append(f"_{pair.replace(':', '')}")
+        if timeframe:
+            feature_parts.append(f"_{timeframe}")
+        feature_col = "".join(feature_parts)
+
+        # Both frames are indexed positionally by the current tick
+        if not raw:
+            return self.signal_features[feature_col].iloc[self._current_tick]
+        return self.raw_features[feature_col].iloc[self._current_tick]
+
+    def get_idle_duration(self) -> int:
+        """
+        Get the number of ticks spent without an open position
+
+        Returns 0 while a position is open. Otherwise counts the ticks since the
+        last closed trade tick, or since the start tick when no closed trade tick
+        is set.
+        """
+        if self._position != Positions.Neutral:
+            return 0
+        # Falsy check: a stored tick of 0 is handled like "no closed trade yet"
+        if not self._last_closed_trade_tick:
+            return self._current_tick - self._start_tick
+        return self._current_tick - self._last_closed_trade_tick
+
+    def get_most_recent_max_pnl(self) -> float:
+        """
+        Calculate the most recent maximum unrealized profit if in a trade
+        """
+        if self._last_trade_tick is None:
+            return 0.0
+        if self._position == Positions.Neutral:
+            return 0.0
+        pnl_history = self.history.get("pnl")
+        if not pnl_history or len(pnl_history) == 0:
             return 0.0
 
-        def previous_price(self) -> float:
-            return self.prices.iloc[self._current_tick - 1].get("open")
+        pnl_history = np.asarray(pnl_history)
+        ticks = self.history.get("tick")
+        if not ticks:
+            return 0.0
+        ticks = np.asarray(ticks)
+        start = np.searchsorted(ticks, self._last_trade_tick, side="left")
+        trade_pnl_history = pnl_history[start:]
+        if trade_pnl_history.size == 0:
+            return 0.0
+        return np.max(trade_pnl_history)
 
-        def get_env_history(self) -> DataFrame:
-            """
-            Get environment data aligned on ticks, including optional trade events
-            """
-            if not self.history:
-                logger.warning("history is empty")
-                return DataFrame()
+    def get_most_recent_return(self) -> float:
+        """
+        Calculate the tick to tick log return if in a trade.
+        Return is generated from rising prices in Long and falling prices in Short positions.
+        When the position stored at index (current tick - 1) of the position history
+        is not the current one, the previous price is first adjusted with
+        add_entry_fee() (Long) or add_exit_fee() (Short).
+        Returns 0.0 when no trade is open.
+        """
+        if self._last_trade_tick is None:
+            return 0.0
+        if self._position == Positions.Neutral:
+            return 0.0
+        elif self._position == Positions.Long:
+            current_price = self.current_price()
+            previous_price = self.previous_price()
+            # Previous recorded position was not Long: apply the entry fee
+            if (
+                self._position_history[self._current_tick - 1] == Positions.Short
+                or self._position_history[self._current_tick - 1] == Positions.Neutral
+            ):
+                previous_price = self.add_entry_fee(previous_price)
+            # Long gains when the price rises
+            return np.log(current_price) - np.log(previous_price)
+        elif self._position == Positions.Short:
+            current_price = self.current_price()
+            previous_price = self.previous_price()
+            # Previous recorded position was not Short: apply the exit fee
+            if (
+                self._position_history[self._current_tick - 1] == Positions.Long
+                or self._position_history[self._current_tick - 1] == Positions.Neutral
+            ):
+                previous_price = self.add_exit_fee(previous_price)
+            # Short gains when the price falls, hence the swapped operands
+            return np.log(previous_price) - np.log(current_price)
+        return 0.0
 
-            _history_df = DataFrame.from_dict(self.history)
-            if "tick" not in _history_df.columns:
-                logger.warning("'tick' column is missing from history")
-                return DataFrame()
+    def _update_portfolio_log_returns(self):
+        """
+        Store the most recent tick to tick return at the current tick index of
+        portfolio_log_returns
+        """
+        self.portfolio_log_returns[self._current_tick] = self.get_most_recent_return()
 
-            if self.trade_history:
-                _trade_history_df = DataFrame.from_dict(self.trade_history)
-                if "tick" in _trade_history_df.columns:
-                    _rollout_history = merge(
-                        _history_df, _trade_history_df, on="tick", how="left"
-                    )
-                else:
-                    _rollout_history = _history_df.copy()
+    def get_most_recent_profit(self) -> float:
+        """
+        Calculate the tick to tick unrealized profit if in a trade
+
+        The result is the relative change between the fee-adjusted previous price
+        and the fee-adjusted current price, signed so that a Long gains on rising
+        prices and a Short on falling prices. Returns 0.0 when no trade is open.
+        """
+        if self._last_trade_tick is None:
+            return 0.0
+        if self._position == Positions.Neutral:
+            return 0.0
+        elif self._position == Positions.Long:
+            # Long: entry fee on the previous price, exit fee on the current price
+            current_price = self.add_exit_fee(self.current_price())
+            previous_price = self.add_entry_fee(self.previous_price())
+            return (current_price - previous_price) / previous_price
+        elif self._position == Positions.Short:
+            # Short: the fee helpers are swapped compared to the Long branch
+            current_price = self.add_entry_fee(self.current_price())
+            previous_price = self.add_exit_fee(self.previous_price())
+            return (previous_price - current_price) / previous_price
+        return 0.0
+
+    def previous_price(self) -> float:
+        """
+        Get the "open" value of the prices row located one tick before the current tick
+        """
+        # .get() yields None instead of raising when the row has no "open" entry
+        return self.prices.iloc[self._current_tick - 1].get("open")
+
+    def get_env_history(self) -> DataFrame:
+        """
+        Get environment data aligned on ticks, including optional trade events
+        """
+        if not self.history:
+            logger.warning("history is empty")
+            return DataFrame()
+
+        _history_df = DataFrame.from_dict(self.history)
+        if "tick" not in _history_df.columns:
+            logger.warning("'tick' column is missing from history")
+            return DataFrame()
+
+        if self.trade_history:
+            _trade_history_df = DataFrame.from_dict(self.trade_history)
+            if "tick" in _trade_history_df.columns:
+                _rollout_history = merge(
+                    _history_df, _trade_history_df, on="tick", how="left"
+                )
             else:
                 _rollout_history = _history_df.copy()
+        else:
+            _rollout_history = _history_df.copy()
 
+        try:
+            history = merge(
+                _rollout_history,
+                self.prices,
+                left_on="tick",
+                right_index=True,
+                how="left",
+            )
+        except Exception:
             try:
+                _price_history = (
+                    self.prices.iloc[_rollout_history.tick]
+                    .copy()
+                    .reset_index(drop=True)
+                )
                 history = merge(
                     _rollout_history,
-                    self.prices,
-                    left_on="tick",
+                    _price_history,
+                    left_index=True,
                     right_index=True,
-                    how="left",
                 )
-            except Exception:
-                try:
-                    _price_history = (
-                        self.prices.iloc[_rollout_history.tick]
-                        .copy()
-                        .reset_index(drop=True)
-                    )
-                    history = merge(
-                        _rollout_history,
-                        _price_history,
-                        left_index=True,
-                        right_index=True,
-                    )
-                except Exception as e:
-                    logger.error(
-                        f"Failed to merge history with prices: {repr(e)}",
-                        exc_info=True,
-                    )
-                    return DataFrame()
-            return history
-
-        def get_env_plot(self) -> plt.Figure:
-            """
-            Plot trades and environment data
-            """
-
-            def transform_y_offset(ax, offset):
-                return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
-
-            def plot_markers(ax, xs, ys, marker, color, size, offset):
-                ax.plot(
-                    xs,
-                    ys,
-                    marker=marker,
-                    color=color,
-                    markersize=size,
-                    fillstyle="full",
-                    transform=transform_y_offset(ax, offset),
-                    linestyle="none",
-                    zorder=3,
+            except Exception as e:
+                logger.error(
+                    f"Failed to merge history with prices: {repr(e)}",
+                    exc_info=True,
                 )
+                return DataFrame()
+        return history
 
-            with plt.style.context("dark_background"):
-                fig, axs = plt.subplots(
-                    nrows=5,
-                    ncols=1,
-                    figsize=(14, 8),
-                    height_ratios=[5, 1, 1, 1, 1],
-                    sharex=True,
-                )
+    def get_env_plot(self) -> plt.Figure:
+        """
+        Plot trades and environment data
+        """
 
-                history = self.get_env_history()
-                if len(history) == 0:
-                    return fig
-
-                plot_window = int(self.rl_config.get("plot_window", 2000))
-                if plot_window > 0 and len(history) > plot_window:
-                    history = history.iloc[-plot_window:]
-
-                ticks = history.get("tick")
-                history_open = history.get("open")
-                if (
-                    ticks is None
-                    or len(ticks) == 0
-                    or history_open is None
-                    or len(history_open) == 0
-                ):
-                    return fig
-
-                axs[0].plot(ticks, history_open, linewidth=1, color="orchid", zorder=1)
-
-                history_type = history.get("type")
-                history_price = history.get("price")
-                if history_type is not None and history_price is not None:
-                    enter_long_mask = history_type == "long_enter"
-                    enter_short_mask = history_type == "short_enter"
-                    exit_long_mask = history_type == "long_exit"
-                    exit_short_mask = history_type == "short_exit"
-                    take_profit_mask = history_type == "take_profit"
-                    stop_loss_mask = history_type == "stop_loss"
-                    timeout_mask = history_type == "timeout"
-
-                    enter_long_x = ticks[enter_long_mask]
-                    enter_short_x = ticks[enter_short_mask]
-                    exit_long_x = ticks[exit_long_mask]
-                    exit_short_x = ticks[exit_short_mask]
-                    take_profit_x = ticks[take_profit_mask]
-                    stop_loss_x = ticks[stop_loss_mask]
-                    timeout_x = ticks[timeout_mask]
-
-                    enter_long_y = history.loc[enter_long_mask, "price"]
-                    enter_short_y = history.loc[enter_short_mask, "price"]
-                    exit_long_y = history.loc[exit_long_mask, "price"]
-                    exit_short_y = history.loc[exit_short_mask, "price"]
-                    take_profit_y = history.loc[take_profit_mask, "price"]
-                    stop_loss_y = history.loc[stop_loss_mask, "price"]
-                    timeout_y = history.loc[timeout_mask, "price"]
-
-                    plot_markers(
-                        axs[0], enter_long_x, enter_long_y, "^", "forestgreen", 5, -0.1
-                    )
-                    plot_markers(
-                        axs[0], enter_short_x, enter_short_y, "v", "firebrick", 5, 0.1
-                    )
-                    plot_markers(
-                        axs[0], exit_long_x, exit_long_y, ".", "cornflowerblue", 4, 0.1
-                    )
-                    plot_markers(
-                        axs[0], exit_short_x, exit_short_y, ".", "thistle", 4, -0.1
-                    )
-                    plot_markers(
-                        axs[0], take_profit_x, take_profit_y, "*", "lime", 8, 0.1
-                    )
-                    plot_markers(axs[0], stop_loss_x, stop_loss_y, "x", "red", 8, -0.1)
-                    plot_markers(axs[0], timeout_x, timeout_y, "1", "yellow", 8, 0.0)
-
-                axs[1].set_ylabel("pnl")
-                pnl_series = history.get("pnl")
-                if pnl_series is not None and len(pnl_series) > 0:
-                    axs[1].plot(
-                        ticks,
-                        pnl_series,
-                        linewidth=1,
-                        color="gray",
-                        label="pnl",
-                    )
-                previous_pnl_series = history.get("previous_pnl")
-                if previous_pnl_series is not None and len(previous_pnl_series) > 0:
-                    axs[1].plot(
-                        ticks,
-                        previous_pnl_series,
-                        linewidth=1,
-                        color="deepskyblue",
-                        label="previous_pnl",
-                    )
-                if (pnl_series is not None and len(pnl_series) > 0) or (
-                    previous_pnl_series is not None and len(previous_pnl_series) > 0
-                ):
-                    axs[1].legend(loc="upper left", fontsize=8)
-                axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
-                axs[1].axhline(
-                    y=self.take_profit, label="tp", alpha=0.33, color="green"
-                )
-                axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
-
-                axs[2].set_ylabel("reward")
-                reward_series = history.get("reward")
-                if reward_series is not None and len(reward_series) > 0:
-                    axs[2].plot(ticks, reward_series, linewidth=1, color="gray")
-                axs[2].axhline(y=0, label="0", alpha=0.33)
-
-                axs[3].set_ylabel("total_profit")
-                total_profit_series = history.get("total_profit")
-                if total_profit_series is not None and len(total_profit_series) > 0:
-                    axs[3].plot(ticks, total_profit_series, linewidth=1, color="gray")
-                axs[3].axhline(y=1, label="1", alpha=0.33)
-
-                axs[4].set_ylabel("total_reward")
-                total_reward_series = history.get("total_reward")
-                if total_reward_series is not None and len(total_reward_series) > 0:
-                    axs[4].plot(ticks, total_reward_series, linewidth=1, color="gray")
-                axs[4].axhline(y=0, label="0", alpha=0.33)
-                axs[4].set_xlabel("tick")
-
-            _borders = ["top", "right", "bottom", "left"]
-            for _ax in axs:
-                for _border in _borders:
-                    _ax.spines[_border].set_color("#5b5e4b")
-
-            fig.suptitle(
-                f"Total Reward: {self.total_reward:.2f} ~ "
-                + f"Total Profit: {self._total_profit:.2f} ~ "
-                + f"Trades: {int(len(self.trade_history) // 2)}",
+        def transform_y_offset(ax, offset):
+            return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
+
+        def plot_markers(ax, xs, ys, marker, color, size, offset):
+            ax.plot(
+                xs,
+                ys,
+                marker=marker,
+                color=color,
+                markersize=size,
+                fillstyle="full",
+                transform=transform_y_offset(ax, offset),
+                linestyle="none",
+                zorder=3,
             )
-            fig.tight_layout()
-            return fig
 
-        def close(self) -> None:
-            plt.close()
-            gc.collect()
-            if th.cuda.is_available():
-                th.cuda.empty_cache()
+        with plt.style.context("dark_background"):
+            fig, axs = plt.subplots(
+                nrows=5,
+                ncols=1,
+                figsize=(14, 8),
+                height_ratios=[5, 1, 1, 1, 1],
+                sharex=True,
+            )
+
+            history = self.get_env_history()
+            if len(history) == 0:
+                return fig
+
+            plot_window = int(self.rl_config.get("plot_window", 2000))
+            if plot_window > 0 and len(history) > plot_window:
+                history = history.iloc[-plot_window:]
+
+            ticks = history.get("tick")
+            history_open = history.get("open")
+            if (
+                ticks is None
+                or len(ticks) == 0
+                or history_open is None
+                or len(history_open) == 0
+            ):
+                return fig
+
+            axs[0].plot(ticks, history_open, linewidth=1, color="orchid", zorder=1)
+
+            history_type = history.get("type")
+            history_price = history.get("price")
+            if history_type is not None and history_price is not None:
+                enter_long_mask = history_type == "long_enter"
+                enter_short_mask = history_type == "short_enter"
+                exit_long_mask = history_type == "long_exit"
+                exit_short_mask = history_type == "short_exit"
+                take_profit_mask = history_type == "take_profit"
+                stop_loss_mask = history_type == "stop_loss"
+                timeout_mask = history_type == "timeout"
+
+                enter_long_x = ticks[enter_long_mask]
+                enter_short_x = ticks[enter_short_mask]
+                exit_long_x = ticks[exit_long_mask]
+                exit_short_x = ticks[exit_short_mask]
+                take_profit_x = ticks[take_profit_mask]
+                stop_loss_x = ticks[stop_loss_mask]
+                timeout_x = ticks[timeout_mask]
+
+                enter_long_y = history.loc[enter_long_mask, "price"]
+                enter_short_y = history.loc[enter_short_mask, "price"]
+                exit_long_y = history.loc[exit_long_mask, "price"]
+                exit_short_y = history.loc[exit_short_mask, "price"]
+                take_profit_y = history.loc[take_profit_mask, "price"]
+                stop_loss_y = history.loc[stop_loss_mask, "price"]
+                timeout_y = history.loc[timeout_mask, "price"]
+
+                plot_markers(
+                    axs[0], enter_long_x, enter_long_y, "^", "forestgreen", 5, -0.1
+                )
+                plot_markers(
+                    axs[0], enter_short_x, enter_short_y, "v", "firebrick", 5, 0.1
+                )
+                plot_markers(
+                    axs[0], exit_long_x, exit_long_y, ".", "cornflowerblue", 4, 0.1
+                )
+                plot_markers(
+                    axs[0], exit_short_x, exit_short_y, ".", "thistle", 4, -0.1
+                )
+                plot_markers(axs[0], take_profit_x, take_profit_y, "*", "lime", 8, 0.1)
+                plot_markers(axs[0], stop_loss_x, stop_loss_y, "x", "red", 8, -0.1)
+                plot_markers(axs[0], timeout_x, timeout_y, "1", "yellow", 8, 0.0)
+
+            axs[1].set_ylabel("pnl")
+            pnl_series = history.get("pnl")
+            if pnl_series is not None and len(pnl_series) > 0:
+                axs[1].plot(
+                    ticks,
+                    pnl_series,
+                    linewidth=1,
+                    color="gray",
+                    label="pnl",
+                )
+            previous_pnl_series = history.get("previous_pnl")
+            if previous_pnl_series is not None and len(previous_pnl_series) > 0:
+                axs[1].plot(
+                    ticks,
+                    previous_pnl_series,
+                    linewidth=1,
+                    color="deepskyblue",
+                    label="previous_pnl",
+                )
+            if (pnl_series is not None and len(pnl_series) > 0) or (
+                previous_pnl_series is not None and len(previous_pnl_series) > 0
+            ):
+                axs[1].legend(loc="upper left", fontsize=8)
+            axs[1].axhline(y=0, label="0", alpha=0.33, color="gray")
+            axs[1].axhline(y=self.take_profit, label="tp", alpha=0.33, color="green")
+            axs[1].axhline(y=self.stop_loss, label="sl", alpha=0.33, color="red")
+
+            axs[2].set_ylabel("reward")
+            reward_series = history.get("reward")
+            if reward_series is not None and len(reward_series) > 0:
+                axs[2].plot(ticks, reward_series, linewidth=1, color="gray")
+            axs[2].axhline(y=0, label="0", alpha=0.33)
+
+            axs[3].set_ylabel("total_profit")
+            total_profit_series = history.get("total_profit")
+            if total_profit_series is not None and len(total_profit_series) > 0:
+                axs[3].plot(ticks, total_profit_series, linewidth=1, color="gray")
+            axs[3].axhline(y=1, label="1", alpha=0.33)
+
+            axs[4].set_ylabel("total_reward")
+            total_reward_series = history.get("total_reward")
+            if total_reward_series is not None and len(total_reward_series) > 0:
+                axs[4].plot(ticks, total_reward_series, linewidth=1, color="gray")
+            axs[4].axhline(y=0, label="0", alpha=0.33)
+            axs[4].set_xlabel("tick")
+
+        _borders = ["top", "right", "bottom", "left"]
+        for _ax in axs:
+            for _border in _borders:
+                _ax.spines[_border].set_color("#5b5e4b")
+
+        fig.suptitle(
+            f"Total Reward: {self.total_reward:.2f} ~ "
+            + f"Total Profit: {self._total_profit:.2f} ~ "
+            + f"Trades: {int(len(self.trade_history) // 2)}",
+        )
+        fig.tight_layout()
+        return fig
+
+    def close(self) -> None:
+        """
+        Release resources held by the environment: close the matplotlib figure,
+        run garbage collection and empty the CUDA cache when CUDA is available
+        """
+        # Without arguments pyplot closes the current figure only
+        plt.close()
+        gc.collect()
+        if th.cuda.is_available():
+            th.cuda.empty_cache()
 
 
 class InfoMetricsCallback(TensorboardCallback):
@@ -1718,11 +1733,29 @@ class InfoMetricsCallback(TensorboardCallback):
     Tensorboard callback
     """
 
+    def __init__(self, *args, throttle: int = 1, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.throttle = 1 if throttle < 1 else throttle
+
     def _on_training_start(self) -> None:
-        _lr = self.model.learning_rate
-        _lr = (
-            float(_lr) if isinstance(_lr, (int, float, np.floating)) else "lr_schedule"
-        )
+        lr_schedule = "unknown"
+        lr_iv = np.nan
+        lr_fv = np.nan
+        lr = getattr(self.model, "learning_rate", None)
+        if callable(lr):
+            lr_schedule = "linear"
+            try:
+                lr_iv = lr(1.0)
+            except Exception:
+                lr_iv = np.nan
+            try:
+                lr_fv = lr(0.0)
+            except Exception:
+                lr_fv = np.nan
+        elif isinstance(lr, (int, float)):
+            lr_schedule = "constant"
+            lr_iv = float(lr)
+            lr_fv = float(lr)
         n_stack = 1
         env = getattr(self, "training_env", None)
         while env is not None:
@@ -1737,20 +1770,36 @@ class InfoMetricsCallback(TensorboardCallback):
             "algorithm": self.model.__class__.__name__,
             "n_envs": int(self.model.n_envs),
             "n_stack": n_stack,
-            "learning_rate": _lr,
+            "lr_schedule": lr_schedule,
+            "learning_rate_iv": lr_iv,
+            "learning_rate_fv": lr_fv,
             "gamma": float(self.model.gamma),
             "batch_size": int(self.model.batch_size),
         }
         if "PPO" in self.model.__class__.__name__:
-            _cr = self.model.clip_range
-            _cr = (
-                float(_cr)
-                if isinstance(_cr, (int, float, np.floating))
-                else "cr_schedule"
-            )
+            cr_schedule = "unknown"
+            cr_iv = np.nan
+            cr_fv = np.nan
+            cr = getattr(self.model, "clip_range", None)
+            if callable(cr):
+                cr_schedule = "linear"
+                try:
+                    cr_iv = cr(1.0)
+                except Exception:
+                    cr_iv = np.nan
+                try:
+                    cr_fv = cr(0.0)
+                except Exception:
+                    cr_fv = np.nan
+            elif isinstance(cr, (int, float)):
+                cr_schedule = "constant"
+                cr_iv = float(cr)
+                cr_fv = float(cr)
             hparam_dict.update(
                 {
-                    "clip_range": _cr,
+                    "cr_schedule": cr_schedule,
+                    "clip_range_iv": cr_iv,
+                    "clip_range_fv": cr_fv,
                     "gae_lambda": float(self.model.gae_lambda),
                     "n_steps": int(self.model.n_steps),
                     "n_epochs": int(self.model.n_epochs),
@@ -1776,14 +1825,57 @@ class InfoMetricsCallback(TensorboardCallback):
                     "exploration_rate": float(self.model.exploration_rate),
                 }
             )
+            train_freq = getattr(self.model, "train_freq", None)
+            train_freq_val: int | None = None
+            try:
+                if isinstance(train_freq, int):
+                    train_freq_val = train_freq
+                elif isinstance(train_freq, (tuple, list)) and train_freq:
+                    if isinstance(train_freq[0], int):
+                        train_freq_val = train_freq[0]
+                elif hasattr(train_freq, "freq"):
+                    freq = getattr(train_freq, "freq")
+                    if isinstance(freq, int):
+                        train_freq_val = freq
+            except Exception:
+                train_freq_val = None
+            if train_freq_val is not None:
+                hparam_dict.update({"train_freq": train_freq_val})
             if "QRDQN" in self.model.__class__.__name__:
                 hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)})
-        metric_dict = {
+        metric_dict: dict[str, float | int] = {
+            "eval/mean_reward": 0.0,
+            "eval/mean_reward_std": 0.0,
+            "rollout/ep_rew_mean": 0.0,
+            "rollout/ep_len_mean": 0.0,
+            "train/n_updates": 0,
+            "train/progress_done": 0.0,
+            "train/progress_remaining": 0.0,
+            "train/learning_rate": 0.0,
             "info/total_reward": 0.0,
             "info/total_profit": 1.0,
             "info/trade_count": 0,
             "info/trade_duration": 0,
         }
+        if "PPO" in self.model.__class__.__name__:
+            metric_dict.update(
+                {
+                    "train/approx_kl": 0.0,
+                    "train/entropy_loss": 0.0,
+                    "train/policy_gradient_loss": 0.0,
+                    "train/clip_fraction": 0.0,
+                    "train/clip_range": 0.0,
+                    "train/value_loss": 0.0,
+                    "train/explained_variance": 0.0,
+                }
+            )
+        if "DQN" in self.model.__class__.__name__:
+            metric_dict.update(
+                {
+                    "train/loss": 0.0,
+                    "train/exploration_rate": 0.0,
+                }
+            )
         self.logger.record(
             "hparams",
             HParam(hparam_dict, metric_dict),
@@ -1791,6 +1883,9 @@ class InfoMetricsCallback(TensorboardCallback):
         )
 
     def _on_step(self) -> bool:
+        if self.throttle > 1 and (self.num_timesteps % self.throttle) != 0:
+            return True
+
         def _is_numeric_non_bool(x: Any) -> bool:
             return isinstance(
                 x, (int, float, np.integer, np.floating)
@@ -1830,11 +1925,11 @@ class InfoMetricsCallback(TensorboardCallback):
             for k, values in numeric_acc.items():
                 if not values:
                     continue
-                mean = sum(values) / len(values)
+                mean = np.mean(values)
                 aggregated_info[k] = mean
                 if len(values) > 1:
                     try:
-                        aggregated_info[f"{k}_std"] = stdev(values)
+                        aggregated_info[f"{k}_std"] = np.std(values)
                     except Exception:
                         pass
 
@@ -2003,34 +2098,37 @@ class InfoMetricsCallback(TensorboardCallback):
                         except Exception:
                             pass
 
-        total_timesteps = getattr(self.model, "_total_timesteps", None)
-        if total_timesteps is not None and not np.isclose(total_timesteps, 0.0):
-            try:
+        try:
+            total_timesteps = getattr(self.model, "_total_timesteps", None)
+            if total_timesteps is not None and not np.isclose(total_timesteps, 0.0):
                 progress_done = float(self.num_timesteps) / float(total_timesteps)
                 progress_done = np.clip(progress_done, 0.0, 1.0)
-            except Exception:
+            else:
                 progress_done = 0.0
-        else:
-            progress_done = 0.0
-        progress_remaining = 1.0 - progress_done
-
-        try:
+            progress_remaining = 1.0 - progress_done
             self.logger.record("train/progress_done", progress_done)
             self.logger.record("train/progress_remaining", progress_remaining)
+        except Exception:
+            progress_remaining = 1.0
+
+        try:
+            lr = getattr(self.model, "learning_rate", None)
+            if callable(lr):
+                lr = lr(progress_remaining)
+            if isinstance(lr, (int, float)) and np.isfinite(lr):
+                self.logger.record("train/learning_rate", float(lr))
         except Exception:
             pass
 
-        lr_schedule = getattr(self.model, "lr_schedule", None)
-        if callable(lr_schedule):
+        if "PPO" in self.model.__class__.__name__:
             try:
-                lr = lr_schedule(progress_remaining)
-                self.logger.record("train/learning_rate", lr)
+                cr = getattr(self.model, "clip_range", None)
+                if callable(cr):
+                    cr = cr(progress_remaining)
+                if isinstance(cr, (int, float)) and np.isfinite(cr):
+                    self.logger.record("train/clip_range", float(cr))
             except Exception:
                 pass
-        else:
-            lr = getattr(self.model, "learning_rate", None)
-            if isinstance(lr, (int, float, np.floating)):
-                self.logger.record("train/learning_rate", lr)
 
         return True
 
@@ -2060,12 +2158,16 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
 
     def __init__(
         self,
-        eval_env,
-        trial,
+        eval_env: BaseEnvironment,
+        trial: Trial,
         n_eval_episodes: int = 10,
         eval_freq: int = 2048,
         deterministic: bool = True,
+        render: bool = False,
         use_masking: bool = True,
+        best_model_save_path: Optional[str] = None,
+        callback_on_new_best: Optional[BaseCallback] = None,
+        callback_after_eval: Optional[BaseCallback] = None,
         verbose: int = 0,
         **kwargs,
     ):
@@ -2074,8 +2176,12 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
             n_eval_episodes=n_eval_episodes,
             eval_freq=eval_freq,
             deterministic=deterministic,
-            verbose=verbose,
+            render=render,
+            best_model_save_path=best_model_save_path,
             use_masking=use_masking,
+            callback_on_new_best=callback_on_new_best,
+            callback_after_eval=callback_after_eval,
+            verbose=verbose,
             **kwargs,
         )
         self.trial = trial
@@ -2090,9 +2196,8 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
             return False
         if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
             self.eval_idx += 1
-            last_mean_reward = getattr(self, "last_mean_reward", np.nan)
             try:
-                last_mean_reward = float(last_mean_reward)
+                last_mean_reward = float(getattr(self, "last_mean_reward", np.nan))
             except Exception as e:
                 logger.warning(
                     "Optuna: invalid last_mean_reward at eval %s: %r", self.eval_idx, e
@@ -2113,6 +2218,23 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
                 )
                 self.is_pruned = True
                 return False
+            try:
+                best_mean_reward = float(getattr(self, "best_mean_reward", np.nan))
+            except Exception as e:
+                logger.warning(
+                    "Optuna: invalid best_mean_reward at eval %s: %r",
+                    self.eval_idx,
+                    e,
+                )
+            if np.isfinite(best_mean_reward):
+                try:
+                    self.logger.record("eval/best_mean_reward", best_mean_reward)
+                except Exception:
+                    pass
+            else:
+                logger.warning(
+                    "Optuna: non-finite best_mean_reward at eval %s", self.eval_idx
+                )
             try:
                 if self.trial.should_prune():
                     logger.info(
@@ -2132,37 +2254,6 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
         return True
 
 
-def make_env(
-    MyRLEnv: Type[BaseEnvironment],
-    env_id: str,
-    rank: int,
-    seed: int,
-    train_df: DataFrame,
-    price: DataFrame,
-    env_info: Dict[str, Any],
-) -> Callable[[], BaseEnvironment]:
-    """
-    Utility function for multiprocessed env.
-
-    :param MyRLEnv: (Type[BaseEnvironment]) environment class to instantiate
-    :param env_id: (str) the environment ID
-    :param rank: (int) index of the subprocess
-    :param seed: (int) the initial seed for RNG
-    :param train_df: (DataFrame) feature dataframe for the environment
-    :param price: (DataFrame) aligned price dataframe
-    :param env_info: (dict) all required arguments to instantiate the environment
-    :return:
-    (Callable[[], BaseEnvironment]) closure that when called instantiates and returns the environment
-    """
-
-    def _init() -> BaseEnvironment:
-        return MyRLEnv(
-            df=train_df, prices=price, id=env_id, seed=seed + rank, **env_info
-        )
-
-    return _init
-
-
 def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     """Recursively merge two dicts without mutating inputs"""
     dst_copy = copy.deepcopy(dst)
@@ -2245,7 +2336,7 @@ def get_net_arch(
 
 def get_activation_fn(
     activation_fn_name: Literal["tanh", "relu", "elu", "leaky_relu"],
-) -> type[th.nn.Module]:
+) -> Type[th.nn.Module]:
     """
     Get activation function
     """
@@ -2259,7 +2350,7 @@ def get_activation_fn(
 
 def get_optimizer_class(
     optimizer_class_name: Literal["adam", "adamw"],
-) -> type[th.optim.Optimizer]:
+) -> Type[th.optim.Optimizer]:
     """
     Get optimizer class
     """
index 977b32c683a51908f421c2a41138d67524bea59f..96a249f7a3d8a03329bf2163e0ea32f2da552ce3 100644 (file)
@@ -107,8 +107,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             ]
         return copy.deepcopy(self._optuna_label_candle_pool_full_cache[cache_key])
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self.pairs: list[str] = self.config.get("exchange", {}).get("pair_whitelist")
         if not self.pairs:
             raise ValueError(