Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(reforcexy): cache model params
author    Jérôme Benoit <jerome.benoit@piment-noir.org>
          Tue, 9 Sep 2025 11:08:12 +0000 (13:08 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
          Tue, 9 Sep 2025 11:08:12 +0000 (13:08 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 25c73866121dae55423a02baa4a8a01f051a98e9..9bee6541a9e3c88bb6b5ffddcc82ec2f8c7e5aa9 100644
@@ -140,6 +140,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             "n_startup_trials", 15
         )
         self.optuna_callback: Optional[MaskableTrialEvalCallback] = None
+        self._model_params_cache: Optional[Dict[str, Any]] = None
         self.unset_unsupported()
 
     def unset_unsupported(self) -> None:
@@ -227,6 +228,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         """
         Get model parameters
         """
+        if self._model_params_cache is not None:
+            return copy.deepcopy(self._model_params_cache)
+
         model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters)
 
         if self.lr_schedule:
@@ -244,11 +248,28 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
         if "PPO" in self.model_type:
-            if not isinstance(net_arch, dict):
+            if isinstance(net_arch, str):
+                net_arch = get_net_arch(self.model_type, net_arch)
+                if isinstance(net_arch, dict):
+                    model_params["policy_kwargs"]["net_arch"] = net_arch
+                else:
+                    model_params["policy_kwargs"]["net_arch"] = {
+                        "pi": net_arch,
+                        "vf": net_arch,
+                    }
+            elif isinstance(net_arch, list):
                 model_params["policy_kwargs"]["net_arch"] = {
                     "pi": net_arch,
                     "vf": net_arch,
                 }
+            elif isinstance(net_arch, dict):
+                if not ("pi" in net_arch and "vf" in net_arch):
+                    model_params["policy_kwargs"]["net_arch"] = {
+                        "pi": net_arch.get("pi", net_arch.get("vf", [128, 128])),
+                        "vf": net_arch.get("vf", net_arch.get("pi", [128, 128])),
+                    }
+                else:
+                    model_params["policy_kwargs"]["net_arch"] = net_arch
         else:
             model_params["policy_kwargs"]["net_arch"] = net_arch
 
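The hunk above normalizes net_arch into the {"pi": ..., "vf": ...} mapping that PPO policies expect, whatever form the user supplied (a string preset, a flat list of layer sizes, or a partial dict). The standalone sketch below reproduces that coercion outside the class; the preset table and the function name are illustrative stand-ins, not the repository's actual get_net_arch() implementation.

from typing import Dict, List, Optional, Union

# Illustrative stand-in for the presets get_net_arch() may resolve (assumption).
_NET_ARCH_PRESETS: Dict[str, List[int]] = {"small": [128, 128], "medium": [256, 256]}


def normalize_ppo_net_arch(
    net_arch: Union[str, List[int], Dict[str, List[int]]],
    default: Optional[List[int]] = None,
) -> Dict[str, List[int]]:
    """Coerce a string preset, a flat layer list, or a partial dict into {"pi", "vf"}."""
    default = default if default is not None else [128, 128]
    if isinstance(net_arch, str):
        # Resolve string presets to a list of layer sizes (hypothetical table above).
        net_arch = _NET_ARCH_PRESETS.get(net_arch, default)
    if isinstance(net_arch, list):
        # A bare list is shared between the policy and value networks.
        return {"pi": net_arch, "vf": net_arch}
    if isinstance(net_arch, dict):
        # Fill a missing "pi"/"vf" key from its sibling, falling back to the default.
        return {
            "pi": net_arch.get("pi", net_arch.get("vf", default)),
            "vf": net_arch.get("vf", net_arch.get("pi", default)),
        }
    return {"pi": default, "vf": default}


# A partial dict is completed rather than overwritten.
assert normalize_ppo_net_arch({"pi": [64, 64]}) == {"pi": [64, 64], "vf": [64, 64]}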
@@ -259,8 +280,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
             model_params.get("policy_kwargs", {}).get("optimizer_class", "adam")
         )
-
-        return model_params
+        self._model_params_cache = copy.deepcopy(model_params)
+        return copy.deepcopy(self._model_params_cache)
 
     def get_callbacks(
         self, eval_freq: int, data_path: str, trial: Trial = None
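Taken together, the hunks above implement a simple memoization: get_model_params() builds the parameter dict once, stores a deep copy in self._model_params_cache, and returns a fresh deep copy on every call so callers can mutate the result without poisoning the cache. Below is a minimal self-contained sketch of that pattern; the class and field names are invented for illustration.

import copy
from typing import Any, Dict, Optional


class ParamsProvider:
    """Illustrative sketch of the cache-and-deepcopy pattern used by get_model_params()."""

    def __init__(self, base_params: Dict[str, Any]) -> None:
        self._base_params = base_params
        self._params_cache: Optional[Dict[str, Any]] = None

    def get_params(self) -> Dict[str, Any]:
        if self._params_cache is not None:
            # Cache hit: skip the rebuild, but still return a defensive copy.
            return copy.deepcopy(self._params_cache)
        params = copy.deepcopy(self._base_params)
        # ... expensive normalization would happen here (schedules, net_arch, optimizer) ...
        self._params_cache = copy.deepcopy(params)
        return copy.deepcopy(self._params_cache)


provider = ParamsProvider({"learning_rate": 3e-4, "policy_kwargs": {"net_arch": [128, 128]}})
params = provider.get_params()
params["learning_rate"] = 1e-5  # mutating the caller's copy...
assert provider.get_params()["learning_rate"] == 3e-4  # ...leaves the cached dict intact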
@@ -1728,8 +1749,14 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
     exploration_initial_eps = trial.suggest_float(
         "exploration_initial_eps", exploration_final_eps, 1.0
     )
+    if exploration_initial_eps >= 0.9:
+        min_fraction = 0.2
+    elif (exploration_initial_eps - exploration_final_eps) > 0.5:
+        min_fraction = 0.15
+    else:
+        min_fraction = 0.05
     exploration_fraction = trial.suggest_float(
-        "exploration_fraction", 0.05, 0.9, step=0.02
+        "exploration_fraction", min_fraction, 0.9, step=0.02
     )
     buffer_size = trial.suggest_categorical(
         "buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]