Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): do hyperopt on a per pair basis, take 2
author: Jérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 16 Feb 2025 17:45:10 +0000 (18:45 +0100)
committer: Jérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 16 Feb 2025 17:45:10 +0000 (18:45 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/config-template.json
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 8e498f163e015ecf9269d960619fc2e78c0a9838..2d7e4dc8065571634bd9a49ceea74bc310712fa2 100644 (file)
       "plot_new_best": false // Enable tensorboard rollout plot upon finding a new best model
     },
     "rl_config_optuna": {
-      "enabled": false, // Enable optuna hyperopt
-      "n_jobs": 6,
+      "enabled": true, // Enable optuna hyperopt
       "n_trials": 100,
       "n_startup_trials": 10,
       "timeout_hours": 0
index 4c1cc2f7c9316dc49bdf8afdac44ea4c50efba33..3609cbb3218c3198fa41985c5ac5a0ae42d6895f 100644 (file)
@@ -83,7 +83,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             },
             "rl_config_optuna": {
                 "enabled": false,                   // Enable optuna hyperopt
-                "n_jobs": 1,
                 "n_trials": 100,
                 "n_startup_trials": 10,
                 "timeout_hours": 0,
@@ -122,7 +121,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.optuna_n_startup_trials: int = self.rl_config_optuna.get(
             "n_startup_trials", 10
         )
-        self.optuna_trial_params: list = []
+        self.optuna_trial_params: Dict[str, list] = {}
         self.optuna_callback: Optional[MaskableTrialEvalCallback] = None
         self.unset_unsupported()
 
@@ -452,7 +451,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         """
         study_name = str(dk.pair)
         storage_dir = str(dk.full_path)
-        storage_backend = self.rl_config_optuna.get("storage", "file")
+        storage_backend = self.rl_config_optuna.get("storage", "sqlite")
         if storage_backend == "sqlite":
             storage = f"sqlite:///{storage_dir}/optuna-{dk.pair.split('/')[0]}.sqlite"
         elif storage_backend == "file":
@@ -484,27 +483,29 @@ class ReforceXY(BaseReinforcementLearningModel):
                 ),
                 gc_after_trial=True,
                 show_progress_bar=self.rl_config.get("progress_bar", False),
-                n_jobs=self.rl_config_optuna.get("n_jobs", 1),
+                n_jobs=1,
             )
         except KeyboardInterrupt:
             pass
 
-        # FIXME: ensure that best trial params are handled on a per pair basis
-        logger.info("------------ Hyperopt results ------------")
+        logger.info("------------ Hyperopt results %s ------------", dk.pair)
         logger.info(
             "Best trial: %s. Score: %s", study.best_trial.number, study.best_trial.value
         )
         logger.info(
-            "Best trial params: %s", self.optuna_trial_params[study.best_trial.number]
+            "Best trial params: %s",
+            self.optuna_trial_params[dk.pair][study.best_trial.number],
         )
-        logger.info("-----------------------------------------")
+        logger.info("---------------------------------------------")
 
-        best_trial_path = Path(dk.full_path / "hyperopt_best_trial.json")
-        logger.info("dumping json to %s", best_trial_path)
+        best_trial_path = Path(
+            dk.full_path / f"{dk.pair.split('/')[0]}_hyperopt_best_params.json"
+        )
+        logger.info("dumping to %s JSON file", best_trial_path)
         with best_trial_path.open("w", encoding="utf-8") as write_file:
             json.dump(study.best_trial.params, write_file, indent=4)
 
-        return self.optuna_trial_params[study.best_trial.number]
+        return self.optuna_trial_params[dk.pair][study.best_trial.number]
 
     def objective(
         self, trial: Trial, train_df, total_timesteps: int, dk: FreqaiDataKitchen
@@ -528,10 +529,13 @@ class ReforceXY(BaseReinforcementLearningModel):
             dk.full_path / "tensorboard" / dk.pair.split("/")[0]
         )
 
-        logger.info("------------ Hyperopt ------------")
-        logger.info("------------ Trial %s ------------", trial.number)
+        logger.info(
+            "------------ Hyperopt trial %d %s ------------", trial.number, dk.pair
+        )
         logger.info("Trial %s params: %s", trial.number, params)
-        self.optuna_trial_params.append(params)
+        if dk.pair not in self.optuna_trial_params:
+            self.optuna_trial_params[dk.pair] = []
+        self.optuna_trial_params[dk.pair].append(params)
 
         model = self.MODELCLASS(
             self.policy_type,