From 85129672f965951154190d2b8a052b8c3761e641 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 22 Dec 2025 23:16:46 +0100 Subject: [PATCH] feat(ReforceXY): add purge_period to optuna config to periodically purge optuna studies MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/config-template.json | 1 - ReforceXY/user_data/freqaimodels/ReforceXY.py | 91 ++++++++++++++++++- .../freqaimodels/QuickAdapterRegressorV3.py | 4 +- 3 files changed, 88 insertions(+), 8 deletions(-) diff --git a/ReforceXY/user_data/config-template.json b/ReforceXY/user_data/config-template.json index 0a58257..34f497c 100644 --- a/ReforceXY/user_data/config-template.json +++ b/ReforceXY/user_data/config-template.json @@ -180,7 +180,6 @@ }, "rl_config_optuna": { "enabled": true, // Enable optuna hyperopt - "per_pair": true, // Enable per pair hyperopt "n_trials": 100, "n_startup_trials": 15, "timeout_hours": 0 diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index db83ef1..2b68140 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -136,6 +136,7 @@ class ReforceXY(BaseReinforcementLearningModel): "warm_start": false, // If true, enqueue previous best params if exists "sampler": "tpe", // Optuna sampler (tpe|auto) "storage": "sqlite", // Optuna storage backend (sqlite|file) + "purge_period": 0, // Purge Optuna study every X retrains (0 disables) "seed": 42, // RNG seed } } @@ -307,6 +308,9 @@ class ReforceXY(BaseReinforcementLearningModel): self.optuna_n_startup_trials: int = self.rl_config_optuna.get( "n_startup_trials", 15 ) + self.optuna_purge_period: int = int( + self.rl_config_optuna.get("purge_period", 0) + ) self.optuna_eval_callback: Optional[MaskableTrialEvalCallback] = None self._model_params_cache: Optional[Dict[str, Any]] = None self.unset_unsupported() @@ -409,6 +413,24 @@ class ReforceXY(BaseReinforcementLearningModel): self.n_eval_episodes, ) self.n_eval_episodes = 5 + if ( + not isinstance(self.optuna_purge_period, int) + or self.optuna_purge_period < 0 + ): + logger.warning( + "Invalid purge_period=%s. Forcing purge_period=0", + self.optuna_purge_period, + ) + self.optuna_purge_period = 0 + if ( + self.rl_config_optuna.get("continuous", False) + and self.optuna_purge_period > 0 + ): + logger.warning( + "purge_period=%s has no effect when continuous=True. Forcing purge_period=0", + self.optuna_purge_period, + ) + self.optuna_purge_period = 0 add_state_info = self.rl_config.get("add_state_info", False) if not add_state_info: logger.warning( @@ -1079,8 +1101,50 @@ class ReforceXY(BaseReinforcementLearningModel): def delete_study(study_name: str, storage: BaseStorage) -> None: try: delete_study(study_name=study_name, storage=storage) - except Exception: - pass + except Exception as e: + logger.warning("Failed to delete study %s: %r", study_name, e) + + @staticmethod + def _sanitize_pair(pair: str) -> str: + """Normalize a trading pair into a safe key.""" + sanitized = pair.replace("/", "_").replace(":", "_") + return "".join(ch for ch in sanitized if ch.isalnum() or ch in ("_", "-", ".")) + + def _optuna_retrain_counters_path(self) -> Path: + return Path(self.full_path / "optuna-retrain-counters.json") + + def _load_optuna_retrain_counters(self) -> Dict[str, int]: + counters_path = self._optuna_retrain_counters_path() + if not counters_path.is_file(): + return {} + try: + with counters_path.open("r", encoding="utf-8") as read_file: + data: Dict[str, int] = json.load(read_file) + if isinstance(data, dict): + result: Dict[str, int] = {} + for key, value in data.items(): + if isinstance(key, str) and isinstance(value, int): + result[key] = value + return result + except Exception as e: + logger.warning("Failed to load optuna retrain counters: %r", e) + return {} + + def _save_optuna_retrain_counters(self, counters: Dict[str, int]) -> None: + counters_path = self._optuna_retrain_counters_path() + try: + with counters_path.open("w", encoding="utf-8") as write_file: + json.dump(counters, write_file, indent=4, sort_keys=True) + except Exception as e: + logger.warning("Failed to save optuna retrain counters: %r", e) + + def _increment_optuna_retrain_counter(self, pair: str) -> int: + pair = ReforceXY._sanitize_pair(pair) + counters = self._load_optuna_retrain_counters() + pair_count = int(counters.get(pair, 0)) + 1 + counters[pair] = pair_count + self._save_optuna_retrain_counters(counters) + return pair_count def create_storage(self, pair: str) -> BaseStorage: """ @@ -1182,9 +1246,24 @@ class ReforceXY(BaseReinforcementLearningModel): study_name = f"{identifier}-{dk.pair}" storage = self.create_storage(dk.pair) continuous = self.rl_config_optuna.get("continuous", False) - if continuous: + + pair_purge_count = self._increment_optuna_retrain_counter(dk.pair) + pair_purge_triggered = ( + self.optuna_purge_period > 0 + and pair_purge_count % self.optuna_purge_period == 0 + ) + + if continuous or pair_purge_triggered: ReforceXY.delete_study(study_name, storage) + if pair_purge_triggered: + logger.info( + "Hyperopt study %s purged on retrain %s (purge_period=%s)", + study_name, + pair_purge_count, + self.optuna_purge_period, + ) + reduction_factor = 3 n_envs = self.n_envs if ReforceXY._MODEL_TYPES[0] in self.model_type: # "PPO" @@ -1206,9 +1285,11 @@ class ReforceXY(BaseReinforcementLearningModel): ), direction=StudyDirection.MAXIMIZE, storage=storage, - load_if_exists=not continuous, + load_if_exists=not continuous and not pair_purge_triggered, ) - if self.rl_config_optuna.get("warm_start", False): + if ( + self.rl_config_optuna.get("warm_start", False) and not pair_purge_triggered + ) or pair_purge_triggered: best_trial_params = self.load_best_trial_params(dk.pair) if best_trial_params: study.enqueue_trial(best_trial_params) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index ac10393..dbb406b 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -2418,8 +2418,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) -> None: try: optuna.delete_study(study_name=study_name, storage=storage) - except Exception: - pass + except Exception as e: + logger.warning("Failed to delete study %s: %r", study_name, e) @staticmethod def optuna_load_study( -- 2.43.0