]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(reforcexy): refine optuna trial pruning condition
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 18 Sep 2025 21:09:23 +0000 (23:09 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 18 Sep 2025 21:09:23 +0000 (23:09 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index f98cade43d0e966cb4452b1aaacae4ce73a3ba77..cd7aa3721bcac29a4a35ee5f322e0db88250a02c 100644 (file)
@@ -854,6 +854,15 @@ class ReforceXY(BaseReinforcementLearningModel):
         else:
             raise NotImplementedError
 
+        if "DQN" in self.model_type:
+            batch_size = params.get("batch_size")
+            gradient_steps = params.get("gradient_steps")
+            buffer_size = params.get("buffer_size")
+            if (batch_size * gradient_steps) > buffer_size:
+                raise TrialPruned(
+                    "batch_size * gradient_steps is greater than buffer_size"
+                )
+
         # Ensure that the sampled parameters take precedence
         params = deepmerge(self.get_model_params(), params)
 
@@ -2249,9 +2258,9 @@ def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
     n_steps = trial.suggest_categorical("n_steps", list(PPO_N_STEPS))
     batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
     if batch_size > n_steps:
-        raise TrialPruned("batch_size must be less than or equal to n_steps")
+        raise TrialPruned("batch_size is greater than n_steps")
     if (n_steps * n_envs) % batch_size != 0:
-        raise TrialPruned("(n_steps * n_envs) not divisible by batch_size")
+        raise TrialPruned("n_steps * n_envs is not divisible by batch_size")
     return convert_optuna_params_to_model_params(
         "PPO",
         {
@@ -2311,7 +2320,7 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
         "learning_starts", [500, 1000, 2000, 3000, 4000, 5000, 8000, 10000]
     )
     if learning_starts >= buffer_size:
-        raise TrialPruned("learning_starts must be less than buffer_size")
+        raise TrialPruned("learning_starts is greater than or equal to buffer_size")
     return {
         "train_freq": trial.suggest_categorical(
             "train_freq", [2, 4, 8, 16, 128, 256, 512, 1024]