]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): properly display HPO study best params
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 21 Feb 2025 11:44:37 +0000 (12:44 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 21 Feb 2025 11:44:37 +0000 (12:44 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py

index d12c85a60de08568895cf31b13c92525ec77b0d4..acd1d477aa017ac5714b0ee2741f6fa327a0170d 100644 (file)
@@ -125,9 +125,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.optuna_n_startup_trials: int = self.rl_config_optuna.get(
             "n_startup_trials", 10
         )
-        self.optuna_trial_params: Dict[str, list] = {}
-        for pair in self.pairs:
-            self.optuna_trial_params[pair] = []
         self.optuna_callback: Optional[MaskableTrialEvalCallback] = None
         self.unset_unsupported()
 
@@ -514,15 +511,12 @@ class ReforceXY(BaseReinforcementLearningModel):
         logger.info(
             "Best trial: %s. Score: %s", study.best_trial.number, study.best_trial.value
         )
-        logger.info(
-            "Best trial params: %s",
-            self.optuna_trial_params[dk.pair][study.best_trial.number],
-        )
+        logger.info("Best trial params: %s", study.best_trial.params)
         logger.info("-------------------------------------------------------")
 
         self.save_best_params(dk.pair, study.best_trial.params)
 
-        return self.optuna_trial_params[dk.pair][study.best_trial.number]
+        return study.best_trial.params
 
     def save_best_params(self, pair: str, best_params: Dict) -> None:
         """
@@ -531,7 +525,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         best_params_path = Path(
             self.full_path / f"{pair.split('/')[0]}_hyperopt_best_params.json"
         )
-        logger.info("saving to %s JSON file", best_params_path)
+        logger.info(f"{pair}: saving best params to %s JSON file", best_params_path)
         with best_params_path.open("w", encoding="utf-8") as write_file:
             json.dump(best_params, write_file, indent=4)
 
@@ -543,7 +537,9 @@ class ReforceXY(BaseReinforcementLearningModel):
             self.full_path / f"{pair.split('/')[0]}_hyperopt_best_params.json"
         )
         if best_params_path.is_file():
-            logger.info("loading from %s JSON file", best_params_path)
+            logger.info(
+                f"{pair}: loading best params from %s JSON file", best_params_path
+            )
             with best_params_path.open("r", encoding="utf-8") as read_file:
                 best_params = json.load(read_file)
             return best_params
@@ -575,7 +571,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             "------------ Hyperopt trial %d %s ------------", trial.number, dk.pair
         )
         logger.info("Trial %s params: %s", trial.number, params)
-        self.optuna_trial_params[dk.pair].append(params)
 
         model = self.MODELCLASS(
             self.policy_type,
index 5d909b32e7717aa40ac1f36fa106dda868014f13..d034a2727cd222fe46b1294212af7a2c9a2b5b64 100644 (file)
@@ -444,13 +444,17 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
                 return json.load(read_file)
         return None
 
-    def optuna_study_delete(self, study_name: str, storage) -> None:
+    def optuna_study_delete(
+        self, study_name: str, storage: optuna.storages.BaseStorage
+    ) -> None:
         try:
             optuna.delete_study(study_name=study_name, storage=storage)
         except Exception:
             pass
 
-    def optuna_study_load(self, study_name: str, storage) -> optuna.study.Study | None:
+    def optuna_study_load(
+        self, study_name: str, storage: optuna.storages.BaseStorage
+    ) -> optuna.study.Study | None:
         try:
             study = optuna.load_study(study_name=study_name, storage=storage)
         except Exception:
index 872bfa5d6d235f70d6beaaada3a658fb25694030..36db9bdb4d12cd251d92d6f7975dcdf44ac061fd 100644 (file)
@@ -445,13 +445,17 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
                 return json.load(read_file)
         return None
 
-    def optuna_study_delete(self, study_name: str, storage) -> None:
+    def optuna_study_delete(
+        self, study_name: str, storage: optuna.storages.BaseStorage
+    ) -> None:
         try:
             optuna.delete_study(study_name=study_name, storage=storage)
         except Exception:
             pass
 
-    def optuna_study_load(self, study_name: str, storage) -> optuna.study.Study | None:
+    def optuna_study_load(
+        self, study_name: str, storage: optuna.storages.BaseStorage
+    ) -> optuna.study.Study | None:
         try:
             study = optuna.load_study(study_name=study_name, storage=storage)
         except Exception: