From cd7ce362486cbe991b404436b73e0c4f731c2aee Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 22 Sep 2025 00:48:47 +0200 Subject: [PATCH] refactor(reforcexy): refine variables and methods namespace MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 5609f60..886e858 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -527,24 +527,24 @@ class ReforceXY(BaseReinforcementLearningModel): start_time = time.time() if self.hyperopt: - best_trial_params = self.study(dk, total_timesteps) - if best_trial_params is None: + best_params = self.optimize(dk, total_timesteps) + if best_params is None: logger.error( "Hyperopt failed. Using default configured model params instead" ) - best_trial_params = self.get_model_params() - model_params = best_trial_params + best_params = self.get_model_params() + model_params = best_params else: model_params = self.get_model_params() logger.info("%s params: %s", self.model_type, model_params) if "PPO" in self.model_type: - min_steps = 2 * model_params.get("n_steps", 0) * self.n_envs - if total_timesteps < min_steps: + min_timesteps = 2 * model_params.get("n_steps", 0) * self.n_envs + if total_timesteps < min_timesteps: logger.warning( "total_timesteps=%s is less than 2*n_steps*n_envs=%s. This may lead to suboptimal training results", total_timesteps, - min_steps, + min_timesteps, ) if self.activate_tensorboard: @@ -718,7 +718,7 @@ class ReforceXY(BaseReinforcementLearningModel): except (ValueError, KeyError): return False - def study( + def optimize( self, dk: FreqaiDataKitchen, total_timesteps: int ) -> Optional[Dict[str, Any]]: """ -- 2.43.0