Piment Noir Git Repositories - freqai-strategies.git / commitdiff
perf(reforcexy): refine optuna search space validation
author	Jérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 9 Sep 2025 02:10:30 +0000 (04:10 +0200)
committer	Jérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 9 Sep 2025 02:10:30 +0000 (04:10 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 3892d30e7e810eb4181c8251911b593ad3d6b3f7..25c73866121dae55423a02baa4a8a01f051a98e9 100644
@@ -544,6 +544,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             else self.get_storage()
         )
         eval_freq = max(1, len(train_df) // self.n_envs)
+        max_resource = max(1, (total_timesteps + eval_freq - 1) // eval_freq)
+        min_resource = min(3, max_resource)
         study: Study = create_study(
             study_name=study_name,
             sampler=TPESampler(
@@ -553,8 +555,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                 seed=self.rl_config_optuna.get("seed", 42),
             ),
             pruner=HyperbandPruner(
-                min_resource=3,
-                max_resource=(total_timesteps + eval_freq - 1) // eval_freq,
+                min_resource=min_resource,
+                max_resource=max_resource,
                 reduction_factor=3,
             ),
             direction=StudyDirection.MAXIMIZE,
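
The two locals introduced above guard the HyperbandPruner against degenerate budgets: max_resource is the ceiling division of total_timesteps by eval_freq (the number of evaluation rounds that fit into the training budget), and min_resource is capped so it can never exceed max_resource when that budget is very small. A minimal sketch of the arithmetic, with illustrative numbers that are assumptions rather than values from the commit:

    # Illustrative numbers only; the concrete values are assumptions.
    total_timesteps = 10_000
    train_df_len = 1_200
    n_envs = 4

    eval_freq = max(1, train_df_len // n_envs)                             # 300
    # Ceiling division: evaluation rounds available as Hyperband budget.
    max_resource = max(1, (total_timesteps + eval_freq - 1) // eval_freq)  # 34
    # Keep the previous default of 3, but never above max_resource, so the
    # pruner configuration stays valid even for tiny training budgets.
    min_resource = min(3, max_resource)                                    # 3
    assert min_resource <= max_resource
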
@@ -695,9 +697,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         Defines a single trial for hyperparameter optimization using Optuna
         """
         if "PPO" in self.model_type:
-            params = sample_params_ppo(trial)
+            params = sample_params_ppo(trial, self.n_envs)
             if params.get("n_steps", 0) > total_timesteps:
-                raise TrialPruned("n_steps exceeds total_timesteps")
+                raise TrialPruned("n_steps is greater than total_timesteps")
         elif "QRDQN" in self.model_type:
             params = sample_params_qrdqn(trial)
         elif "DQN" in self.model_type:
@@ -733,8 +735,14 @@ class ReforceXY(BaseReinforcementLearningModel):
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         except AssertionError:
-            logger.warning("Optuna encountered NaN")
+            logger.warning("Optuna encountered NaN (AssertionError)")
             nan_encountered = True
+        except ValueError as e:
+            if "NaN" in str(e):
+                logger.warning("Optuna encountered NaN (ValueError)")
+                nan_encountered = True
+            else:
+                raise
         finally:
             if self.progressbar_callback:
                 self.progressbar_callback.on_training_end()
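
NaNs during training typically surface either as an AssertionError or as a ValueError whose message mentions "NaN"; the new branch catches only the NaN case and re-raises any other ValueError. A minimal sketch of that filtering pattern, where train_once is a hypothetical stand-in for model.learn(...):

    # train_once is a hypothetical stand-in for model.learn(...).
    import logging

    logger = logging.getLogger(__name__)

    def train_once() -> None:
        raise ValueError("Invalid value NaN in observation")  # illustrative message

    nan_encountered = False
    try:
        train_once()
    except AssertionError:
        logger.warning("Optuna encountered NaN (AssertionError)")
        nan_encountered = True
    except ValueError as e:
        if "NaN" in str(e):
            logger.warning("Optuna encountered NaN (ValueError)")
            nan_encountered = True
        else:
            raise  # unrelated ValueErrors still propagate
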
@@ -743,7 +751,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 model.env.close()
 
         if nan_encountered:
-            return np.nan
+            raise TrialPruned("NaN encountered during training")
 
         if self.optuna_callback.is_pruned:
             raise TrialPruned()
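
Raising TrialPruned here instead of returning np.nan records a failed run with state PRUNED rather than as a COMPLETE trial carrying a NaN objective value that the TPE sampler cannot rank. A small sketch of the resulting study bookkeeping, using an artificial objective:

    # Artificial objective; the failure condition stands in for NaN during training.
    import optuna
    from optuna.exceptions import TrialPruned
    from optuna.trial import TrialState

    def objective(trial: optuna.Trial) -> float:
        x = trial.suggest_float("x", -1.0, 1.0)
        if x < 0:
            # The commit's approach: mark the trial as pruned instead of returning NaN.
            raise TrialPruned("NaN encountered during training")
        return x

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=20)
    pruned = [t for t in study.trials if t.state == TrialState.PRUNED]
    print(f"{len(pruned)} trials recorded as PRUNED instead of NaN-valued COMPLETE")
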
@@ -1664,12 +1672,16 @@ def convert_optuna_params_to_model_params(
     return model_params
 
 
-def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
+def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
     """
     Sampler for PPO hyperparams
     """
     n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
     batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
+    if batch_size > n_steps:
+        raise TrialPruned("batch_size must be less than or equal to n_steps")
+    if (n_steps * n_envs) % batch_size != 0:
+        raise TrialPruned("(n_steps * n_envs) not divisible by batch_size")
     return convert_optuna_params_to_model_params(
         "PPO",
         {
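
The two new constraints encode PPO's minibatching rules: a minibatch cannot be larger than one environment's rollout, and the full rollout buffer of n_steps * n_envs transitions should split evenly into minibatches of batch_size. A quick enumeration of which sampled pairs survive, assuming n_envs = 4 for illustration:

    # Enumerate surviving (n_steps, batch_size) pairs, assuming n_envs = 4.
    n_envs = 4
    valid, pruned = [], []
    for n_steps in (512, 1024, 2048, 4096):
        for batch_size in (64, 128, 256, 512, 1024):
            if batch_size > n_steps:
                pruned.append((n_steps, batch_size, "batch_size > n_steps"))
            elif (n_steps * n_envs) % batch_size != 0:
                pruned.append((n_steps, batch_size, "rollout not divisible by batch_size"))
            else:
                valid.append((n_steps, batch_size))
    print(len(valid), "valid,", len(pruned), "pruned")

With the current power-of-two categorical choices only the first check can actually trip; the divisibility guard becomes relevant if the search space ever gains batch sizes that are not powers of two.
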
@@ -1723,12 +1735,18 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
         "buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]
     )
     learning_starts = trial.suggest_categorical(
-        "learning_starts", [500, 1000, 2000, 3000, 4000, 5000]
+        "learning_starts", [500, 1000, 2000, 3000, 4000, 5000, 8000, 10000]
     )
+    if learning_starts >= buffer_size:
+        raise TrialPruned("learning_starts must be less than buffer_size")
     return {
         "train_freq": train_freq,
-        "gradient_steps": max(
-            train_freq // trial.suggest_categorical("subsample_steps", [2, 4, 8]), 1
+        "gradient_steps": min(
+            train_freq,
+            max(
+                train_freq // trial.suggest_categorical("subsample_steps", [2, 4, 8]),
+                1,
+            ),
         ),
         "gamma": trial.suggest_categorical(
             "gamma", [0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.997, 0.999, 0.9999]