From: Jérôme Benoit
Date: Tue, 9 Sep 2025 02:10:30 +0000 (+0200)
Subject: perf(reforcexy): refine optuna search space validation
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=ff83fc8465f8a4cb33fe985261253b463c643301;p=freqai-strategies.git

perf(reforcexy): refine optuna search space validation

Signed-off-by: Jérôme Benoit
---

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 3892d30..25c7386 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -544,6 +544,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             else self.get_storage()
         )
         eval_freq = max(1, len(train_df) // self.n_envs)
+        max_resource = max(1, (total_timesteps + eval_freq - 1) // eval_freq)
+        min_resource = min(3, max_resource)
         study: Study = create_study(
             study_name=study_name,
             sampler=TPESampler(
@@ -553,8 +555,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                 seed=self.rl_config_optuna.get("seed", 42),
             ),
             pruner=HyperbandPruner(
-                min_resource=3,
-                max_resource=(total_timesteps + eval_freq - 1) // eval_freq,
+                min_resource=min_resource,
+                max_resource=max_resource,
                 reduction_factor=3,
             ),
             direction=StudyDirection.MAXIMIZE,
@@ -695,9 +697,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         Defines a single trial for hyperparameter optimization using Optuna
         """
         if "PPO" in self.model_type:
-            params = sample_params_ppo(trial)
+            params = sample_params_ppo(trial, self.n_envs)
             if params.get("n_steps", 0) > total_timesteps:
-                raise TrialPruned("n_steps exceeds total_timesteps")
+                raise TrialPruned("n_steps is greater than total_timesteps")
         elif "QRDQN" in self.model_type:
             params = sample_params_qrdqn(trial)
         elif "DQN" in self.model_type:
@@ -733,8 +735,14 @@ class ReforceXY(BaseReinforcementLearningModel):
         try:
             model.learn(total_timesteps=total_timesteps, callback=callbacks)
         except AssertionError:
-            logger.warning("Optuna encountered NaN")
+            logger.warning("Optuna encountered NaN (AssertionError)")
             nan_encountered = True
+        except ValueError as e:
+            if "NaN" in str(e):
+                logger.warning("Optuna encountered NaN (ValueError)")
+                nan_encountered = True
+            else:
+                raise
         finally:
             if self.progressbar_callback:
                 self.progressbar_callback.on_training_end()
@@ -743,7 +751,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             model.env.close()

         if nan_encountered:
-            return np.nan
+            raise TrialPruned("NaN encountered during training")

         if self.optuna_callback.is_pruned:
             raise TrialPruned()
@@ -1664,12 +1672,16 @@ def convert_optuna_params_to_model_params(
     return model_params


-def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
+def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
     """
     Sampler for PPO hyperparams
     """
     n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
     batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
+    if batch_size > n_steps:
+        raise TrialPruned("batch_size must be less than or equal to n_steps")
+    if (n_steps * n_envs) % batch_size != 0:
+        raise TrialPruned("(n_steps * n_envs) not divisible by batch_size")
     return convert_optuna_params_to_model_params(
         "PPO",
         {
@@ -1723,12 +1735,18 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
         "buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]
     )
     learning_starts = trial.suggest_categorical(
-        "learning_starts", [500, 1000, 2000, 3000, 4000, 5000]
+        "learning_starts", [500, 1000, 2000, 3000, 4000, 5000, 8000, 10000]
     )
+    if learning_starts >= buffer_size:
+        raise TrialPruned("learning_starts must be less than buffer_size")
     return {
         "train_freq": train_freq,
-        "gradient_steps": max(
-            train_freq // trial.suggest_categorical("subsample_steps", [2, 4, 8]), 1
+        "gradient_steps": min(
+            train_freq,
+            max(
+                train_freq // trial.suggest_categorical("subsample_steps", [2, 4, 8]),
+                1,
+            ),
         ),
         "gamma": trial.suggest_categorical(
             "gamma", [0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.997, 0.999, 0.9999]
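
Below is a minimal standalone sketch (not the repository's code) of the validation pattern these hunks introduce: incompatible hyperparameter combinations are rejected with optuna.TrialPruned before any training budget is spent, and the HyperbandPruner resource bounds are derived from the evaluation frequency. The N_ENVS, TOTAL_TIMESTEPS and EVAL_FREQ values, the sample_ppo_like_params helper and the dummy objective are illustrative assumptions, not code from ReforceXY.py.

import optuna
from optuna.pruners import HyperbandPruner
from optuna.samplers import TPESampler

N_ENVS = 4                 # assumed number of vectorized environments
TOTAL_TIMESTEPS = 100_000  # assumed training budget
EVAL_FREQ = 5_000          # assumed evaluation frequency


def sample_ppo_like_params(trial: optuna.Trial) -> dict:
    # Sample a PPO-style search space and prune combinations the learner
    # cannot use, mirroring the checks added to sample_params_ppo above.
    n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
    if batch_size > n_steps:
        raise optuna.TrialPruned("batch_size must be less than or equal to n_steps")
    if (n_steps * N_ENVS) % batch_size != 0:
        raise optuna.TrialPruned("(n_steps * n_envs) not divisible by batch_size")
    return {"n_steps": n_steps, "batch_size": batch_size}


def objective(trial: optuna.Trial) -> float:
    params = sample_ppo_like_params(trial)
    # Placeholder score; a real objective would train the RL model with
    # `params` and return its evaluation reward.
    return float(params["n_steps"] // params["batch_size"])


# Resource bounds mirror the first two hunks: max_resource is the number of
# evaluations that fit into the training budget, min_resource is capped by it.
max_resource = max(1, (TOTAL_TIMESTEPS + EVAL_FREQ - 1) // EVAL_FREQ)
min_resource = min(3, max_resource)

study = optuna.create_study(
    sampler=TPESampler(seed=42),
    pruner=HyperbandPruner(
        min_resource=min_resource,
        max_resource=max_resource,
        reduction_factor=3,
    ),
    direction="maximize",
)
study.optimize(objective, n_trials=20)

Raising TrialPruned marks such trials as pruned rather than failed, which keeps NaN objective values out of the study; this matches the commit's switch from `return np.nan` to `raise TrialPruned(...)` when training encounters NaNs.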