From 9effb70e5bb65ef14cec8a3ce1a884fc08e1d6b1 Mon Sep 17 00:00:00 2001
From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?=
Date: Tue, 9 Sep 2025 03:24:20 +0200
Subject: [PATCH] perf(reforcexy): readd optuna params dependency properly
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jérôme Benoit
---
 ReforceXY/user_data/freqaimodels/ReforceXY.py | 30 +++++++++++++++----
 1 file changed, 25 insertions(+), 5 deletions(-)

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 3892d30..3b26b04 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -695,7 +695,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         Defines a single trial for hyperparameter optimization using Optuna
         """
         if "PPO" in self.model_type:
-            params = sample_params_ppo(trial)
+            params = sample_params_ppo(trial, self.n_envs)
             if params.get("n_steps", 0) > total_timesteps:
                 raise TrialPruned("n_steps exceeds total_timesteps")
         elif "QRDQN" in self.model_type:
@@ -1664,12 +1664,19 @@ def convert_optuna_params_to_model_params(
     return model_params
 
 
-def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
+def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
     """
     Sampler for PPO hyperparams
     """
     n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
-    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
+    n_batches = n_steps * max(1, n_envs)
+    batch_size_candidates = [64, 128, 256, 512, 1024]
+    batch_size_suggestions = [
+        b for b in batch_size_candidates if b <= n_batches and n_batches % b == 0
+    ]
+    if not batch_size_suggestions:
+        batch_size_suggestions = [b for b in batch_size_candidates if b <= n_batches]
+    batch_size = trial.suggest_categorical("batch_size", batch_size_suggestions)
     return convert_optuna_params_to_model_params(
         "PPO",
         {
@@ -1716,14 +1723,27 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
     exploration_initial_eps = trial.suggest_float(
         "exploration_initial_eps", exploration_final_eps, 1.0
     )
+    if exploration_initial_eps >= 0.9:
+        min_fraction = 0.2
+    elif (exploration_initial_eps - exploration_final_eps) > 0.5:
+        min_fraction = 0.15
+    else:
+        min_fraction = 0.05
     exploration_fraction = trial.suggest_float(
-        "exploration_fraction", 0.05, 0.9, step=0.02
+        "exploration_fraction", min_fraction, 0.9, step=0.02
     )
     buffer_size = trial.suggest_categorical(
         "buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]
     )
+    learning_starts_suggestions = [
+        v
+        for v in [500, 1000, 2000, 3000, 4000, 5000]
+        if v <= min(int(buffer_size * 0.05), 5000)
+    ]
+    if not learning_starts_suggestions:
+        learning_starts_suggestions = [500]
     learning_starts = trial.suggest_categorical(
-        "learning_starts", [500, 1000, 2000, 3000, 4000, 5000]
+        "learning_starts", learning_starts_suggestions
     )
     return {
         "train_freq": train_freq,
-- 
2.43.0
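
The dependent-sampling pattern this patch re-adds can be exercised outside freqtrade with a minimal sketch, assuming only that optuna is installed; toy_objective, N_ENVS, the candidate lists, and the dummy score below are illustrative stand-ins, not ReforceXY code. It shows the same two constraints: batch_size candidates restricted to divisors of the rollout size (n_steps * n_envs), and learning_starts capped relative to buffer_size.

import optuna

N_ENVS = 4  # assumed number of vectorized environments


def toy_objective(trial: optuna.Trial) -> float:
    n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
    # The PPO rollout buffer holds n_steps * n_envs transitions; only offer
    # batch sizes that divide it evenly so minibatching never truncates it.
    rollout_size = n_steps * max(1, N_ENVS)
    candidates = [64, 128, 256, 512, 1024]
    choices = [b for b in candidates if b <= rollout_size and rollout_size % b == 0]
    if not choices:
        # Fall back to any candidate that at least fits inside the buffer.
        choices = [b for b in candidates if b <= rollout_size]
    batch_size = trial.suggest_categorical("batch_size", choices)

    buffer_size = trial.suggest_categorical("buffer_size", [10_000, 50_000, 100_000])
    # Cap learning_starts at 5% of the replay buffer (and at 5000) so the
    # warm-up phase never consumes a large share of the buffer.
    cap = min(int(buffer_size * 0.05), 5000)
    starts = [v for v in [500, 1000, 2000, 3000, 4000, 5000] if v <= cap] or [500]
    learning_starts = trial.suggest_categorical("learning_starts", starts)

    # Dummy score; a real objective would train and evaluate a model here.
    return batch_size / rollout_size + learning_starts / buffer_size


study = optuna.create_study(direction="minimize")
study.optimize(toy_objective, n_trials=10)
print(study.best_params)

Because the choice lists depend on values suggested earlier in the same trial, the search space varies from trial to trial; Optuna's define-by-run API permits this, though samplers then tend to model such parameters independently rather than jointly.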