]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(reforcexy): ensure total timesteps is aligned to model rollout with
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 26 Sep 2025 19:31:38 +0000 (21:31 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 26 Sep 2025 19:31:38 +0000 (21:31 +0200)
PPO

Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 60cd86a1457569d66fde7846ec93639fdb43a290..150dc5ded6c4b49f17c0c1ef7b4e3ec262c286c0 100644 (file)
@@ -555,7 +555,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         logger.info("%s params: %s", self.model_type, model_params)
 
         if "PPO" in self.model_type:
-            min_timesteps = 2 * model_params.get("n_steps", 0) * self.n_envs
+            n_steps = model_params.get("n_steps", 0)
+            min_timesteps = 2 * n_steps * self.n_envs
             if total_timesteps <= min_timesteps:
                 logger.warning(
                     "total_timesteps=%s is less than or equal to 2*n_steps*n_envs=%s. This may lead to suboptimal training results for model %s",
@@ -563,6 +564,19 @@ class ReforceXY(BaseReinforcementLearningModel):
                     min_timesteps,
                     self.model_type,
                 )
+            if n_steps > 0:
+                rollout = n_steps * self.n_envs
+                aligned_total_timesteps = (
+                    (total_timesteps + rollout - 1) // rollout
+                ) * rollout
+                if aligned_total_timesteps != total_timesteps:
+                    total_timesteps = aligned_total_timesteps
+                    logger.info(
+                        "Train: aligned total %s steps (%s days) for model %s",
+                        total_timesteps,
+                        steps_to_days(total_timesteps, self.config.get("timeframe")),
+                        self.model_type,
+                    )
 
         if self.activate_tensorboard:
             tensorboard_log_path = Path(
@@ -1010,9 +1024,18 @@ class ReforceXY(BaseReinforcementLearningModel):
         # Ensure that the sampled parameters take precedence
         params = deepmerge(self.get_model_params(), params)
         params["seed"] = params.get("seed", 42) + trial.number
-
         logger.info("Trial %s params: %s", trial.number, params)
 
+        if "PPO" in self.model_type:
+            n_steps = params.get("n_steps", 0)
+            if n_steps > 0:
+                rollout = n_steps * self.n_envs
+                aligned_total_timesteps = (
+                    (total_timesteps + rollout - 1) // rollout
+                ) * rollout
+                if aligned_total_timesteps != total_timesteps:
+                    total_timesteps = aligned_total_timesteps
+
         nan_encountered = False
 
         if self.activate_tensorboard: