]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(reforcexy): make evaluation lighter
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 27 Sep 2025 18:56:45 +0000 (20:56 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 27 Sep 2025 18:56:45 +0000 (20:56 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/config-template.json
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 6dbef2982172234f92547240ae9bc3b1ef538649..852a146a0cf4459d2db1ecf99004676365b19082 100644 (file)
       "max_training_drawdown_pct": 0.02,
       "max_trade_duration_candles": 96, // Timeout exit value used with force_actions
       "force_actions": false, // Utilize minimal_roi, stoploss, and max_trade_duration_candles as TP/SL/Timeout in the environment
-      "n_envs": 32, // Number of DummyVecEnv or SubProcVecEnv environments
+      "n_envs": 32, // Number of DummyVecEnv or SubProcVecEnv training environments
       "multiprocessing": false, // Use SubprocVecEnv if n_envs>1 (otherwise DummyVecEnv)
       "frame_stacking": 2, // Number of VecFrameStack stacks (set > 1 to use)
       "lr_schedule": false, // Enable learning rate linear schedule
index 3fc9bf8e7ecc9182913f9b6ab64a22c1d6d0816a..cbbc8b9dda62c62561c39506ccbf2e68012b2368 100644 (file)
@@ -66,9 +66,9 @@ logger = logging.getLogger(__name__)
 
 
 class ForceActions(IntEnum):
-    Take_profit = 1
-    Stop_loss = 2
-    Timeout = 3
+    Take_profit = 0
+    Stop_loss = 1
+    Timeout = 2
 
 
 class ReforceXY(BaseReinforcementLearningModel):
@@ -87,21 +87,26 @@ class ReforceXY(BaseReinforcementLearningModel):
                 ...
                 "max_trade_duration_candles": 96,   // Timeout exit value used with force_actions
                 "force_actions": false,             // Utilize minimal_roi, stoploss, and max_trade_duration_candles as TP/SL/Timeout in the environment
-                "n_envs": 1,                        // Number of DummyVecEnv or SubProcVecEnv environments
+                "n_envs": 1,                        // Number of DummyVecEnv or SubProcVecEnv training environments
+                "n_eval_envs": 1,                   // Number of DummyVecEnv or SubProcVecEnv evaluation environments
                 "multiprocessing": false,           // Use SubprocVecEnv if n_envs>1 (otherwise DummyVecEnv)
+                "eval_multiprocessing": false,      // Use SubprocVecEnv if n_eval_envs>1 (otherwise DummyVecEnv)
                 "frame_stacking": 0,                // Number of VecFrameStack stacks (set > 1 to use)
                 "inference_masking": true,          // Enable action masking during inference
                 "lr_schedule": false,               // Enable learning rate linear schedule
                 "cr_schedule": false,               // Enable clip range linear schedule
-                "n_steps_eval": 10_000,             // Number of environment steps between evaluations
+                "n_eval_steps": 10_000,             // Number of environment steps between evaluations
+                "n_eval_episodes": 5,               // Number of episodes per evaluation
                 "max_no_improvement_evals": 0,      // Maximum consecutive evaluations without a new best model
                 "min_evals": 0,                     // Number of evaluations before start to count evaluations without improvements
                 "check_envs": true,                 // Check that an environment follows Gym API
+                "tensorboard_throttle": 1,          // Number of training calls between tensorboard logs
                 "plot_new_best": false,             // Enable tensorboard rollout plot upon finding a new best model
+                "plot_window": 2000,                // Environment history window used for tensorboard rollout plot
             },
             "rl_config_optuna": {
                 "enabled": false,                   // Enable optuna hyperopt
-                "per_pair: false,                   // Enable per pair hyperopt
+                "per_pair": false,                  // Enable per pair hyperopt
                 "n_trials": 100,
                 "n_startup_trials": 15,
                 "timeout_hours": 0,
@@ -115,6 +120,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         - pip install optuna-dashboard
     """
 
+    _action_masks_cache: Dict[Tuple[str, int, Optional[int]], NDArray[np.bool_]] = {}
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pairs: List[str] = self.config.get("exchange", {}).get("pair_whitelist")
@@ -130,9 +137,14 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
         self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
         self.n_envs: int = self.rl_config.get("n_envs", 1)
+        self.n_eval_envs: int = self.rl_config.get("n_eval_envs", 1)
         self.multiprocessing: bool = self.rl_config.get("multiprocessing", False)
+        self.eval_multiprocessing: bool = self.rl_config.get(
+            "eval_multiprocessing", False
+        )
         self.frame_stacking: int = self.rl_config.get("frame_stacking", 0)
-        self.n_steps_eval: int = self.rl_config.get("n_steps_eval", 10_000)
+        self.n_eval_steps: int = self.rl_config.get("n_eval_steps", 10_000)
+        self.n_eval_episodes: int = self.rl_config.get("n_eval_episodes", 5)
         self.max_no_improvement_evals: int = self.rl_config.get(
             "max_no_improvement_evals", 0
         )
@@ -188,9 +200,18 @@ class ReforceXY(BaseReinforcementLearningModel):
         position: Positions,
         force_action: Optional[ForceActions] = None,
     ) -> NDArray[np.bool_]:
-        is_short_allowed = ReforceXY._is_short_allowed(trading_mode)
         position = ReforceXY._normalize_position(position)
 
+        cache_key = (
+            trading_mode,
+            position.value,
+            force_action.value if force_action else None,
+        )
+        if cache_key in ReforceXY._action_masks_cache:
+            return ReforceXY._action_masks_cache[cache_key]
+
+        is_short_allowed = ReforceXY._is_short_allowed(trading_mode)
+
         action_masks = np.zeros(len(Actions), dtype=np.bool_)
 
         if force_action is not None and position in (Positions.Long, Positions.Short):
@@ -198,7 +219,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                 action_masks[Actions.Long_exit.value] = True
             else:
                 action_masks[Actions.Short_exit.value] = True
-            return action_masks
+            ReforceXY._action_masks_cache[cache_key] = action_masks
+            return ReforceXY._action_masks_cache[cache_key]
 
         action_masks[Actions.Neutral.value] = True
         if position == Positions.Neutral:
@@ -210,7 +232,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         elif position == Positions.Short:
             action_masks[Actions.Short_exit.value] = True
 
-        return action_masks
+        ReforceXY._action_masks_cache[cache_key] = action_masks
+        return ReforceXY._action_masks_cache[cache_key]
 
     def unset_unsupported(self) -> None:
         """
@@ -220,11 +243,21 @@ class ReforceXY(BaseReinforcementLearningModel):
         if not isinstance(self.n_envs, int) or self.n_envs < 1:
             logger.warning("Invalid n_envs=%s. Forcing n_envs=1", self.n_envs)
             self.n_envs = 1
+        if not isinstance(self.n_eval_envs, int) or self.n_eval_envs < 1:
+            logger.warning(
+                "Invalid n_eval_envs=%s. Forcing n_eval_envs=1", self.n_eval_envs
+            )
+            self.n_eval_envs = 1
         if self.multiprocessing and self.n_envs <= 1:
             logger.warning(
                 "User tried to use multiprocessing with n_envs=1. Deactivating multiprocessing"
             )
             self.multiprocessing = False
+        if self.eval_multiprocessing and self.n_eval_envs <= 1:
+            logger.warning(
+                "User tried to use eval_multiprocessing with n_eval_envs=1. Deactivating eval_multiprocessing"
+            )
+            self.eval_multiprocessing = False
         if self.multiprocessing and self.plot_new_best:
             logger.warning(
                 "User tried to use plot_new_best with multiprocessing. Deactivating plot_new_best"
@@ -242,12 +275,18 @@ class ReforceXY(BaseReinforcementLearningModel):
                 self.frame_stacking,
             )
             self.frame_stacking = 0
-        if self.n_steps_eval <= 0:
+        if self.n_eval_steps <= 0:
+            logger.warning(
+                "Invalid n_eval_steps=%s. Forcing n_eval_steps=10_000",
+                self.n_eval_steps,
+            )
+            self.n_eval_steps = 10_000
+        if self.n_eval_episodes <= 0:
             logger.warning(
-                "Invalid n_steps_eval=%s. Forcing n_steps_eval=10_000",
-                self.n_steps_eval,
+                "Invalid n_eval_episodes=%s. Forcing n_eval_episodes=5",
+                self.n_eval_episodes,
             )
-            self.n_steps_eval = 10_000
+            self.n_eval_episodes = 5
         if self.continual_learning and self.frame_stacking:
             logger.warning(
                 "User tried to use continual_learning with frame_stacking. \
@@ -429,9 +468,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 )
         else:
             if hyperopt and hyperopt_reduction_factor > 1.0:
-                eval_freq = int(self.n_steps_eval / hyperopt_reduction_factor)
+                eval_freq = int(self.n_eval_steps / hyperopt_reduction_factor)
             else:
-                eval_freq = self.n_steps_eval
+                eval_freq = self.n_eval_steps
             eval_freq = max(1, (eval_freq + self.n_envs - 1) // self.n_envs)
 
         return min(eval_freq, total_timesteps)
@@ -476,6 +515,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         if not trial:
             self.eval_callback = MaskableEvalCallback(
                 eval_env,
+                n_eval_episodes=self.n_eval_episodes,
                 eval_freq=eval_freq,
                 deterministic=True,
                 render=False,
@@ -491,6 +531,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             self.optuna_eval_callback = MaskableTrialEvalCallback(
                 eval_env,
                 trial,
+                n_eval_episodes=self.n_eval_episodes,
                 eval_freq=eval_freq,
                 deterministic=True,
                 render=False,
@@ -537,6 +578,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         )
         logger.info("Test: %s steps (%s days)", test_timesteps, test_days)
         logger.info("Multiprocessing: %s", self.multiprocessing)
+        logger.info("Eval multiprocessing: %s", self.eval_multiprocessing)
         logger.info("Frame stacking: %s", self.frame_stacking)
         logger.info("Action masking: %s", self.action_masking)
         logger.info("Hyperopt: %s", self.hyperopt)
@@ -942,14 +984,9 @@ class ReforceXY(BaseReinforcementLearningModel):
             prices_train, prices_test = self.build_ohlc_price_dataframes(
                 dk.data_dictionary, dk.pair, dk
             )
-        seed = (
-            (
-                self.get_model_params().get("seed", 42)
-                + (trial.number if trial is not None else 0)
-            )
-            if seed is None
-            else seed
-        )
+        seed = self.get_model_params().get("seed", 42) if seed is None else seed
+        if trial is not None:
+            seed += trial.number
         set_random_seed(seed)
         env_info = self.pack_env_dict(dk.pair) if env_info is None else env_info
         env_prefix = f"trial_{trial.number}_" if trial is not None else ""
@@ -976,14 +1013,16 @@ class ReforceXY(BaseReinforcementLearningModel):
                 prices_test,
                 env_info=env_info,
             )
-            for i in range(self.n_envs)
+            for i in range(self.n_eval_envs)
         ]
 
         if self.multiprocessing and self.n_envs > 1:
             train_env = SubprocVecEnv(train_fns, start_method="spawn")
-            eval_env = SubprocVecEnv(eval_fns, start_method="spawn")
         else:
             train_env = DummyVecEnv(train_fns)
+        if self.eval_multiprocessing and self.n_eval_envs > 1:
+            eval_env = SubprocVecEnv(eval_fns, start_method="spawn")
+        else:
             eval_env = DummyVecEnv(eval_fns)
 
         if self.frame_stacking: