From: Jérôme Benoit
Date: Tue, 9 Sep 2025 11:08:12 +0000 (+0200)
Subject: perf(reforcexy): cache model params
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=3609678fe54a0a45576e4621399a95fc770e91f6;p=freqai-strategies.git

perf(reforcexy): cache model params

Signed-off-by: Jérôme Benoit
---

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 25c7386..9bee654 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -140,6 +140,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             "n_startup_trials", 15
         )
         self.optuna_callback: Optional[MaskableTrialEvalCallback] = None
+        self._model_params_cache: Optional[Dict[str, Any]] = None
         self.unset_unsupported()
 
     def unset_unsupported(self) -> None:
@@ -227,6 +228,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         """
         Get model parameters
         """
+        if self._model_params_cache is not None:
+            return copy.deepcopy(self._model_params_cache)
+
         model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters)
 
         if self.lr_schedule:
@@ -244,11 +248,28 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
         if "PPO" in self.model_type:
-            if not isinstance(net_arch, dict):
+            if isinstance(net_arch, str):
+                net_arch = get_net_arch(self.model_type, net_arch)
+                if isinstance(net_arch, dict):
+                    model_params["policy_kwargs"]["net_arch"] = net_arch
+                else:
+                    model_params["policy_kwargs"]["net_arch"] = {
+                        "pi": net_arch,
+                        "vf": net_arch,
+                    }
+            elif isinstance(net_arch, list):
                 model_params["policy_kwargs"]["net_arch"] = {
                     "pi": net_arch,
                     "vf": net_arch,
                 }
+            elif isinstance(net_arch, dict):
+                if not ("pi" in net_arch and "vf" in net_arch):
+                    model_params["policy_kwargs"]["net_arch"] = {
+                        "pi": net_arch.get("pi", net_arch.get("vf", [128, 128])),
+                        "vf": net_arch.get("vf", net_arch.get("pi", [128, 128])),
+                    }
+                else:
+                    model_params["policy_kwargs"]["net_arch"] = net_arch
         else:
             model_params["policy_kwargs"]["net_arch"] = net_arch
 
@@ -259,8 +280,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
             model_params.get("policy_kwargs", {}).get("optimizer_class", "adam")
         )
-
-        return model_params
+        self._model_params_cache = copy.deepcopy(model_params)
+        return copy.deepcopy(self._model_params_cache)
 
     def get_callbacks(
         self, eval_freq: int, data_path: str, trial: Trial = None
@@ -1728,8 +1749,14 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
     exploration_initial_eps = trial.suggest_float(
         "exploration_initial_eps", exploration_final_eps, 1.0
     )
+    if exploration_initial_eps >= 0.9:
+        min_fraction = 0.2
+    elif (exploration_initial_eps - exploration_final_eps) > 0.5:
+        min_fraction = 0.15
+    else:
+        min_fraction = 0.05
     exploration_fraction = trial.suggest_float(
-        "exploration_fraction", 0.05, 0.9, step=0.02
+        "exploration_fraction", min_fraction, 0.9, step=0.02
     )
     buffer_size = trial.suggest_categorical(
         "buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]