From dff013818953ea8d29b38ab940eccac944132052 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 18 Sep 2025 23:09:23 +0200 Subject: [PATCH] perf(reforcexy): refine optuna trial pruning condition MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index f98cade..cd7aa37 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -854,6 +854,15 @@ class ReforceXY(BaseReinforcementLearningModel): else: raise NotImplementedError + if "DQN" in self.model_type: + batch_size = params.get("batch_size") + gradient_steps = params.get("gradient_steps") + buffer_size = params.get("buffer_size") + if (batch_size * gradient_steps) > buffer_size: + raise TrialPruned( + "batch_size * gradient_steps is greater than buffer_size" + ) + # Ensure that the sampled parameters take precedence params = deepmerge(self.get_model_params(), params) @@ -2249,9 +2258,9 @@ def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]: n_steps = trial.suggest_categorical("n_steps", list(PPO_N_STEPS)) batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024]) if batch_size > n_steps: - raise TrialPruned("batch_size must be less than or equal to n_steps") + raise TrialPruned("batch_size is greater than n_steps") if (n_steps * n_envs) % batch_size != 0: - raise TrialPruned("(n_steps * n_envs) not divisible by batch_size") + raise TrialPruned("n_steps * n_envs is not divisible by batch_size") return convert_optuna_params_to_model_params( "PPO", { @@ -2311,7 +2320,7 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]: "learning_starts", [500, 1000, 2000, 3000, 4000, 5000, 8000, 10000] ) if learning_starts >= buffer_size: - raise TrialPruned("learning_starts must be less than buffer_size") + raise TrialPruned("learning_starts is greater than or equal to buffer_size") return { "train_freq": trial.suggest_categorical( "train_freq", [2, 4, 8, 16, 128, 256, 512, 1024] -- 2.43.0