]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): cleanup plotting and prediction code
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 1 Oct 2025 20:45:18 +0000 (22:45 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 1 Oct 2025 20:45:18 +0000 (22:45 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 3c207cb11b44c8f0d33d472932e9e7e0d1718991..ea3be6a9ecdb57f528f5a0330cc7a69cba798967 100644 (file)
@@ -25,6 +25,7 @@ from freqtrade.freqai.RL.BaseReinforcementLearningModel import (
 from freqtrade.freqai.tensorboard.TensorboardCallback import TensorboardCallback
 from freqtrade.strategy import timeframe_to_minutes
 from gymnasium.spaces import Box
+from matplotlib.lines import Line2D
 from numpy.typing import NDArray
 from optuna import Trial, TrialPruned, create_study, delete_study
 from optuna.exceptions import ExperimentalWarning
@@ -401,40 +402,46 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         if "PPO" in self.model_type:
             if isinstance(net_arch, str):
-                resolved_net_arch = get_net_arch(self.model_type, net_arch)
-                if isinstance(resolved_net_arch, dict):
-                    model_params["policy_kwargs"]["net_arch"] = resolved_net_arch
-                else:
-                    model_params["policy_kwargs"]["net_arch"] = {
-                        "pi": resolved_net_arch,
-                        "vf": resolved_net_arch,
-                    }
+                model_params["policy_kwargs"]["net_arch"] = get_net_arch(
+                    self.model_type, net_arch
+                )
             elif isinstance(net_arch, list):
                 model_params["policy_kwargs"]["net_arch"] = {
                     "pi": net_arch,
                     "vf": net_arch,
                 }
             elif isinstance(net_arch, dict):
-                pi: Optional[List[int]] = net_arch.get("pi")
-                vf: Optional[List[int]] = net_arch.get("vf")
-                if not isinstance(pi, list) or not isinstance(vf, list):
-                    model_params["policy_kwargs"]["net_arch"] = {
-                        "pi": pi
-                        if isinstance(pi, list)
-                        else (vf if isinstance(vf, list) else default_net_arch),
-                        "vf": vf
-                        if isinstance(vf, list)
-                        else (pi if isinstance(pi, list) else default_net_arch),
-                    }
-                else:
-                    model_params["policy_kwargs"]["net_arch"] = net_arch
+                pi = (
+                    net_arch.get("pi")
+                    if isinstance(net_arch.get("pi"), list)
+                    else default_net_arch
+                )
+                vf = (
+                    net_arch.get("vf")
+                    if isinstance(net_arch.get("vf"), list)
+                    else default_net_arch
+                )
+                model_params["policy_kwargs"]["net_arch"] = {"pi": pi, "vf": vf}
+            else:
+                logger.warning(
+                    "Unexpected net_arch type=%s, using default", type(net_arch)
+                )
+                model_params["policy_kwargs"]["net_arch"] = {
+                    "pi": default_net_arch,
+                    "vf": default_net_arch,
+                }
         else:
             if isinstance(net_arch, str):
                 model_params["policy_kwargs"]["net_arch"] = get_net_arch(
                     self.model_type, net_arch
                 )
-            else:
+            elif isinstance(net_arch, list):
                 model_params["policy_kwargs"]["net_arch"] = net_arch
+            else:
+                logger.warning(
+                    "Unexpected net_arch type=%s, using default", type(net_arch)
+                )
+                model_params["policy_kwargs"]["net_arch"] = default_net_arch
 
         model_params["policy_kwargs"]["activation_fn"] = get_activation_fn(
             model_params.get("policy_kwargs", {}).get("activation_fn", "relu")
@@ -453,6 +460,28 @@ class ReforceXY(BaseReinforcementLearningModel):
         hyperopt_reduction_factor: float = 4.0,
         model_params: Optional[Dict[str, Any]] = None,
     ) -> int:
+        """
+        Calculate evaluation frequency (number of steps between evaluations).
+
+        For PPO:
+        - Use n_steps from model_params if available
+        - Otherwise, select the largest value from PPO_N_STEPS that is <= total_timesteps
+
+        For DQN:
+        - Use n_eval_steps divided by n_envs (rounded up)
+
+        For hyperopt:
+        - Reduce eval_freq by hyperopt_reduction_factor to speed up trials
+
+        Args:
+            total_timesteps: Total training timesteps
+            hyperopt: If True, reduce eval_freq for hyperopt
+            hyperopt_reduction_factor: Reduction factor for hyperopt (default: 4.0)
+            model_params: Model parameters (to get n_steps for PPO)
+
+        Returns:
+            int: Evaluation frequency (capped at total_timesteps)
+        """
         if total_timesteps <= 0:
             return 1
         if "PPO" in self.model_type:
@@ -692,18 +721,14 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param dk: FreqaiDatakitchen = data kitchen for the current pair
         :param model: Any = the trained model used to inference the features.
         """
-
         virtual_position: Positions = Positions.Neutral
+        virtual_trade_duration: int = 0
         np_dataframe: NDArray[np.float32] = dataframe.to_numpy(
             dtype=np.float32, copy=False
         )
         n = int(np_dataframe.shape[0])
         window_length = int(self.CONV_WIDTH)
-        if self.rl_config.get("add_state_info", False) and not self.live:
-            static_state_block = np.tile(
-                np.array([0.0, float(Positions.Neutral.value), 0.0], dtype=np.float32),
-                (window_length, 1),
-            )
+        add_state_info = self.rl_config.get("add_state_info", False)
 
         def _update_virtual_position(action: int, position: Positions) -> Positions:
             if action == Actions.Long_enter.value and position == Positions.Neutral:
@@ -716,6 +741,19 @@ class ReforceXY(BaseReinforcementLearningModel):
                 return Positions.Neutral
             return position
 
+        def _update_virtual_trade_duration(
+            action: int,
+            virtual_position: Positions,
+            previous_virtual_position: Positions,
+            current_virtual_trade_duration: int,
+        ) -> int:
+            if virtual_position != Positions.Neutral:
+                if previous_virtual_position == Positions.Neutral:
+                    return 1
+                else:
+                    return current_virtual_trade_duration + 1
+            return 0
+
         frame_buffer: List[NDArray[np.float32]] = []
         zero_frame: Optional[NDArray[np.float32]] = None
 
@@ -725,7 +763,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             np_observation = np_dataframe[start_idx:end_idx, :]
             action_masks_param: Dict[str, Any] = {}
 
-            if self.rl_config.get("add_state_info", False):
+            if add_state_info:
                 if self.live:
                     position, pnl, trade_duration = self.get_state_info(dk.pair)
                     position = ReforceXY._normalize_position(position)
@@ -737,7 +775,17 @@ class ReforceXY(BaseReinforcementLearningModel):
                         (np_observation.shape[0], 1),
                     )
                 else:
-                    state_block = static_state_block
+                    state_block = np.tile(
+                        np.array(
+                            [
+                                0.0,
+                                float(virtual_position.value),
+                                float(virtual_trade_duration),
+                            ],
+                            dtype=np.float32,
+                        ),
+                        (np_observation.shape[0], 1),
+                    )
                 np_observation = np.concatenate([np_observation, state_block], axis=1)
 
             fb: List[NDArray[np.float32]] = frame_buffer
@@ -776,9 +824,16 @@ class ReforceXY(BaseReinforcementLearningModel):
         for start_idx in range(0, n - window_length + 1):
             action = _predict(start_idx)
             predicted_actions.append(action)
+            previous_virtual_position = virtual_position
             virtual_position = _update_virtual_position(action, virtual_position)
+            virtual_trade_duration = _update_virtual_trade_duration(
+                action,
+                virtual_position,
+                previous_virtual_position,
+                virtual_trade_duration,
+            )
 
-        pad_count = n - len(predicted_actions)
+        pad_count = max(0, n - len(predicted_actions))
         actions_list = ([np.nan] * pad_count) + predicted_actions
         actions = DataFrame({"action": actions_list}, index=dataframe.index)
 
@@ -1305,6 +1360,8 @@ class MyRLEnv(Base5ActionRLEnv):
         self._force_action: Optional[ForceActions] = None
         self._last_closed_position: Optional[Positions] = None
         self._last_closed_trade_tick: int = 0
+        self._max_unrealized_profit = -np.inf
+        self._min_unrealized_profit = np.inf
         return observation, history
 
     def _get_exit_factor(
@@ -1316,6 +1373,13 @@ class MyRLEnv(Base5ActionRLEnv):
         """
         Compute the reward factor at trade exit
         """
+        if (
+            not np.isfinite(factor)
+            or not np.isfinite(pnl)
+            or not np.isfinite(duration_ratio)
+        ):
+            return 0.0
+
         model_reward_parameters = self.rl_config.get("model_reward_parameters", {})
         exit_factor_mode = model_reward_parameters.get("exit_factor_mode", "piecewise")
 
@@ -1377,15 +1441,13 @@ class MyRLEnv(Base5ActionRLEnv):
         return factor
 
     def _get_pnl_factor(self, pnl: float, pnl_target: float) -> float:
+        if not np.isfinite(pnl) or not np.isfinite(pnl_target):
+            return 0.0
+
         model_reward_parameters = self.rl_config.get("model_reward_parameters", {})
 
         pnl_target_factor = 1.0
-        if (
-            np.isfinite(pnl_target)
-            and np.isfinite(pnl)
-            and pnl_target > 0.0
-            and pnl > pnl_target
-        ):
+        if pnl_target > 0.0 and pnl > pnl_target:
             win_reward_factor = float(
                 model_reward_parameters.get("win_reward_factor", 2.0)
             )
@@ -1506,6 +1568,8 @@ class MyRLEnv(Base5ActionRLEnv):
             holding_duration_ratio_grace = float(
                 model_reward_parameters.get("holding_duration_ratio_grace", 1.0)
             )
+            if holding_duration_ratio_grace <= 0.0:
+                holding_duration_ratio_grace = 1.0
             holding_penalty_scale = float(
                 model_reward_parameters.get("holding_penalty_scale", 0.3)
             )
@@ -1513,9 +1577,7 @@ class MyRLEnv(Base5ActionRLEnv):
                 model_reward_parameters.get("holding_penalty_power", 1.0)
             )
             if pnl >= pnl_target:
-                if duration_ratio <= holding_duration_ratio_grace and not np.isclose(
-                    holding_duration_ratio_grace, 0.0
-                ):
+                if duration_ratio <= holding_duration_ratio_grace:
                     effective_duration_ratio = (
                         duration_ratio / holding_duration_ratio_grace
                     )
@@ -1552,27 +1614,39 @@ class MyRLEnv(Base5ActionRLEnv):
         This may or may not be independent of action types, user can inherit
         this in their custom "MyRLEnv"
         """
-        features_window = self.signal_features[
-            (self._current_tick - self.window_size) : self._current_tick
-        ]
+        start_idx = max(0, self._current_tick - self.window_size)
+        end_idx = self._current_tick
+        features_window = self.signal_features.iloc[start_idx:end_idx]
+        if len(features_window) < self.window_size:
+            pad_size = self.window_size - len(features_window)
+            pad_df = DataFrame(
+                np.zeros((pad_size, features_window.shape[1]), dtype=np.float32),
+                columns=features_window.columns,
+            )
+            features_window = concat(
+                [pad_df, features_window], axis=0, ignore_index=True
+            )
+        features_window_array = features_window.to_numpy(dtype=np.float32)
         if self.add_state_info:
-            features_and_state = DataFrame(
-                np.zeros(
-                    (len(features_window), len(self.state_features)),
-                    dtype=np.float32,
-                ),
-                columns=self.state_features,
-                index=features_window.index,
+            return np.concatenate(
+                [
+                    features_window_array,
+                    np.tile(
+                        np.array(
+                            [
+                                self.get_unrealized_profit(),
+                                self._position.value,
+                                self.get_trade_duration(),
+                            ],
+                            dtype=np.float32,
+                        ),
+                        (self.window_size, 1),
+                    ),
+                ],
+                axis=1,
             )
-            # STATE_INFO
-            features_and_state["pnl"] = self.get_unrealized_profit()
-            features_and_state["position"] = self._position.value
-            features_and_state["trade_duration"] = self.get_trade_duration()
-
-            features_and_state = concat([features_window, features_and_state], axis=1)
-            return features_and_state.to_numpy(dtype=np.float32)
         else:
-            return features_window.to_numpy(dtype=np.float32)
+            return features_window_array
 
     def _get_force_action(self) -> Optional[ForceActions]:
         if not self.force_actions or self._position == Positions.Neutral:
@@ -1617,7 +1691,7 @@ class MyRLEnv(Base5ActionRLEnv):
         Execute trade based on the given action
         """
         # Force exit trade
-        if self._force_action:
+        if self._force_action is not None:
             self._exit_trade()
             self.tensorboard_log(f"{self._force_action.name}", category="actions/force")
             return f"{self._force_action.name}"
@@ -1903,23 +1977,6 @@ class MyRLEnv(Base5ActionRLEnv):
         """
         Plot trades and environment data
         """
-
-        def transform_y_offset(ax, offset):
-            return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
-
-        def plot_markers(ax, xs, ys, marker, color, size, offset):
-            ax.plot(
-                xs,
-                ys,
-                marker=marker,
-                color=color,
-                markersize=size,
-                fillstyle="full",
-                transform=transform_y_offset(ax, offset),
-                linestyle="none",
-                zorder=3,
-            )
-
         with plt.style.context("dark_background"):
             fig, axs = plt.subplots(
                 nrows=5,
@@ -1929,6 +1986,22 @@ class MyRLEnv(Base5ActionRLEnv):
                 sharex=True,
             )
 
+            def transform_y_offset(ax, offset):
+                return mtransforms.offset_copy(ax.transData, fig=fig, y=offset)
+
+            def plot_markers(ax, xs, ys, marker, color, size, offset):
+                ax.plot(
+                    xs,
+                    ys,
+                    marker=marker,
+                    color=color,
+                    markersize=size,
+                    fillstyle="full",
+                    transform=transform_y_offset(ax, offset),
+                    linestyle="none",
+                    zorder=3,
+                )
+
             history = self.get_env_history()
             if len(history) == 0:
                 return fig
@@ -1952,45 +2025,47 @@ class MyRLEnv(Base5ActionRLEnv):
             history_type = history.get("type")
             history_price = history.get("price")
             if history_type is not None and history_price is not None:
-                enter_long_mask = history_type == "long_enter"
-                enter_short_mask = history_type == "short_enter"
-                exit_long_mask = history_type == "long_exit"
-                exit_short_mask = history_type == "short_exit"
-                take_profit_mask = history_type == "take_profit"
-                stop_loss_mask = history_type == "stop_loss"
-                timeout_mask = history_type == "timeout"
-
-                enter_long_x = ticks[enter_long_mask]
-                enter_short_x = ticks[enter_short_mask]
-                exit_long_x = ticks[exit_long_mask]
-                exit_short_x = ticks[exit_short_mask]
-                take_profit_x = ticks[take_profit_mask]
-                stop_loss_x = ticks[stop_loss_mask]
-                timeout_x = ticks[timeout_mask]
-
-                enter_long_y = history.loc[enter_long_mask, "price"]
-                enter_short_y = history.loc[enter_short_mask, "price"]
-                exit_long_y = history.loc[exit_long_mask, "price"]
-                exit_short_y = history.loc[exit_short_mask, "price"]
-                take_profit_y = history.loc[take_profit_mask, "price"]
-                stop_loss_y = history.loc[stop_loss_mask, "price"]
-                timeout_y = history.loc[timeout_mask, "price"]
-
-                plot_markers(
-                    axs[0], enter_long_x, enter_long_y, "^", "forestgreen", 5, -0.1
-                )
-                plot_markers(
-                    axs[0], enter_short_x, enter_short_y, "v", "firebrick", 5, 0.1
-                )
-                plot_markers(
-                    axs[0], exit_long_x, exit_long_y, ".", "cornflowerblue", 4, 0.1
-                )
-                plot_markers(
-                    axs[0], exit_short_x, exit_short_y, ".", "thistle", 4, -0.1
-                )
-                plot_markers(axs[0], take_profit_x, take_profit_y, "*", "lime", 8, 0.1)
-                plot_markers(axs[0], stop_loss_x, stop_loss_y, "x", "red", 8, -0.1)
-                plot_markers(axs[0], timeout_x, timeout_y, "1", "yellow", 8, 0.0)
+                trade_markers_config = [
+                    ("long_enter", "^", "forestgreen", 5, -0.1, "Long enter"),
+                    ("short_enter", "v", "firebrick", 5, 0.1, "Short enter"),
+                    ("long_exit", ".", "cornflowerblue", 4, 0.1, "Long exit"),
+                    ("short_exit", ".", "thistle", 4, -0.1, "Short exit"),
+                    ("take_profit", "*", "lime", 8, 0.1, "Take profit"),
+                    ("stop_loss", "x", "red", 8, -0.1, "Stop loss"),
+                    ("timeout", "1", "yellow", 8, 0.0, "Timeout"),
+                ]
+
+                legend_scale_factor = 1.5
+                markers_legend = []
+
+                for (
+                    type_name,
+                    marker,
+                    color,
+                    size,
+                    offset,
+                    label,
+                ) in trade_markers_config:
+                    mask = history_type == type_name
+                    if mask.any():
+                        xs = ticks[mask]
+                        ys = history.loc[mask, "price"]
+
+                        plot_markers(axs[0], xs, ys, marker, color, size, offset)
+
+                    markers_legend.append(
+                        Line2D(
+                            [0],
+                            [0],
+                            marker=marker,
+                            color="w",
+                            markerfacecolor=color,
+                            markersize=size * legend_scale_factor,
+                            linestyle="none",
+                            label=label,
+                        )
+                    )
+                axs[0].legend(handles=markers_legend, loc="upper right", fontsize=8)
 
             axs[1].set_ylabel("pnl")
             pnl_series = history.get("pnl")
@@ -2773,9 +2848,21 @@ def convert_optuna_params_to_model_params(
     lr = get_schedule(optuna_params.get("lr_schedule", "constant"), float(lr))
 
     if "PPO" in model_type:
+        required_ppo_params = [
+            "clip_range",
+            "n_steps",
+            "batch_size",
+            "gamma",
+            "ent_coef",
+            "n_epochs",
+            "gae_lambda",
+            "max_grad_norm",
+            "vf_coef",
+        ]
+        for param in required_ppo_params:
+            if optuna_params.get(param) is None:
+                raise ValueError(f"missing '{param}' in optuna params for {model_type}")
         cr = optuna_params.get("clip_range")
-        if cr is None:
-            raise ValueError(f"missing 'clip_range' in optuna params for {model_type}")
         cr = get_schedule(optuna_params.get("cr_schedule", "constant"), float(cr))
 
         model_params.update(
@@ -2795,6 +2882,21 @@ def convert_optuna_params_to_model_params(
         if optuna_params.get("target_kl") is not None:
             model_params["target_kl"] = float(optuna_params.get("target_kl"))
     elif "DQN" in model_type:
+        required_dqn_params = [
+            "gamma",
+            "batch_size",
+            "buffer_size",
+            "train_freq",
+            "exploration_fraction",
+            "exploration_initial_eps",
+            "exploration_final_eps",
+            "target_update_interval",
+            "learning_starts",
+            "subsample_steps",
+        ]
+        for param in required_dqn_params:
+            if optuna_params.get(param) is None:
+                raise ValueError(f"missing '{param}' in optuna params for {model_type}")
         train_freq = optuna_params.get("train_freq")
         subsample_steps = optuna_params.get("subsample_steps")
         gradient_steps = compute_gradient_steps(train_freq, subsample_steps)