]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(reforcexy): refine optuna defaults
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 21 May 2025 09:05:01 +0000 (11:05 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 21 May 2025 09:05:01 +0000 (11:05 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/config-template.json
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 15fba2e2fe024f6027440ba6fc902a461656ad6b..3ab18679fac5429db99423f966778d9c5e983b79 100644 (file)
       "enabled": true, // Enable optuna hyperopt
       "per_pair": false, // Enable per pair hyperopt
       "n_trials": 100,
-      "n_startup_trials": 10,
+      "n_startup_trials": 15,
       "timeout_hours": 0
     }
   },
index 64c4be17f754bcf5ee8da74e3fc52f62a871e789..c9ff9c4f9228fc04694200489835ce528e24822f 100644 (file)
@@ -86,7 +86,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 "enabled": false,                   // Enable optuna hyperopt
                 "per_pair: false,                   // Enable per pair hyperopt
                 "n_trials": 100,
-                "n_startup_trials": 10,
+                "n_startup_trials": 15,
                 "timeout_hours": 0,
             }
         }
@@ -129,7 +129,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.optuna_timeout_hours: float = self.rl_config_optuna.get("timeout_hours", 0)
         self.optuna_n_trials: int = self.rl_config_optuna.get("n_trials", 100)
         self.optuna_n_startup_trials: int = self.rl_config_optuna.get(
-            "n_startup_trials", 10
+            "n_startup_trials", 15
         )
         self.optuna_callback: Optional[MaskableTrialEvalCallback] = None
         self.unset_unsupported()
@@ -724,7 +724,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             model.env.close()
 
         if nan_encountered:
-            return float("nan")
+            return np.nan
 
         if self.optuna_callback.is_pruned:
             raise TrialPruned()