From ff582b332d831974a80656365e11afce9074de6a Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sun, 16 Feb 2025 18:45:10 +0100 Subject: [PATCH] refactor(reforcexy): do hyperopt on a per pair basis, take 2 MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/config-template.json | 3 +- ReforceXY/user_data/freqaimodels/ReforceXY.py | 32 +++++++++++-------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/ReforceXY/user_data/config-template.json b/ReforceXY/user_data/config-template.json index 8e498f1..2d7e4dc 100644 --- a/ReforceXY/user_data/config-template.json +++ b/ReforceXY/user_data/config-template.json @@ -179,8 +179,7 @@ "plot_new_best": false // Enable tensorboard rollout plot upon finding a new best model }, "rl_config_optuna": { - "enabled": false, // Enable optuna hyperopt - "n_jobs": 6, + "enabled": true, // Enable optuna hyperopt "n_trials": 100, "n_startup_trials": 10, "timeout_hours": 0 diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 4c1cc2f..3609cbb 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -83,7 +83,6 @@ class ReforceXY(BaseReinforcementLearningModel): }, "rl_config_optuna": { "enabled": false, // Enable optuna hyperopt - "n_jobs": 1, "n_trials": 100, "n_startup_trials": 10, "timeout_hours": 0, @@ -122,7 +121,7 @@ class ReforceXY(BaseReinforcementLearningModel): self.optuna_n_startup_trials: int = self.rl_config_optuna.get( "n_startup_trials", 10 ) - self.optuna_trial_params: list = [] + self.optuna_trial_params: Dict[str, list] = {} self.optuna_callback: Optional[MaskableTrialEvalCallback] = None self.unset_unsupported() @@ -452,7 +451,7 @@ class ReforceXY(BaseReinforcementLearningModel): """ study_name = str(dk.pair) storage_dir = str(dk.full_path) - storage_backend = self.rl_config_optuna.get("storage", "file") + storage_backend = self.rl_config_optuna.get("storage", "sqlite") if storage_backend == "sqlite": storage = f"sqlite:///{storage_dir}/optuna-{dk.pair.split('/')[0]}.sqlite" elif storage_backend == "file": @@ -484,27 +483,29 @@ class ReforceXY(BaseReinforcementLearningModel): ), gc_after_trial=True, show_progress_bar=self.rl_config.get("progress_bar", False), - n_jobs=self.rl_config_optuna.get("n_jobs", 1), + n_jobs=1, ) except KeyboardInterrupt: pass - # FIXME: ensure that best trial params are handled on a per pair basis - logger.info("------------ Hyperopt results ------------") + logger.info("------------ Hyperopt results %s ------------", dk.pair) logger.info( "Best trial: %s. Score: %s", study.best_trial.number, study.best_trial.value ) logger.info( - "Best trial params: %s", self.optuna_trial_params[study.best_trial.number] + "Best trial params: %s", + self.optuna_trial_params[dk.pair][study.best_trial.number], ) - logger.info("-----------------------------------------") + logger.info("---------------------------------------------") - best_trial_path = Path(dk.full_path / "hyperopt_best_trial.json") - logger.info("dumping json to %s", best_trial_path) + best_trial_path = Path( + dk.full_path / f"{dk.pair.split('/')[0]}_hyperopt_best_params.json" + ) + logger.info("dumping to %s JSON file", best_trial_path) with best_trial_path.open("w", encoding="utf-8") as write_file: json.dump(study.best_trial.params, write_file, indent=4) - return self.optuna_trial_params[study.best_trial.number] + return self.optuna_trial_params[dk.pair][study.best_trial.number] def objective( self, trial: Trial, train_df, total_timesteps: int, dk: FreqaiDataKitchen @@ -528,10 +529,13 @@ class ReforceXY(BaseReinforcementLearningModel): dk.full_path / "tensorboard" / dk.pair.split("/")[0] ) - logger.info("------------ Hyperopt ------------") - logger.info("------------ Trial %s ------------", trial.number) + logger.info( + "------------ Hyperopt trial %d %s ------------", trial.number, dk.pair + ) logger.info("Trial %s params: %s", trial.number, params) - self.optuna_trial_params.append(params) + if dk.pair not in self.optuna_trial_params: + self.optuna_trial_params[dk.pair] = [] + self.optuna_trial_params[dk.pair].append(params) model = self.MODELCLASS( self.policy_type, -- 2.43.0