From 4a2dff8be56ab20fdf2b7b9d4eca112658580823 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sun, 17 Aug 2025 21:49:24 +0200 Subject: [PATCH] perf(reforcexy): reduce optuna search space MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 76d723f..56017e8 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -1558,7 +1558,6 @@ def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]: """ return { "adam": th.optim.Adam, - "rmsprop": th.optim.RMSprop, }.get(optimizer_class_name, th.optim.Adam) @@ -1593,9 +1592,7 @@ def sample_params_ppo(trial: Trial) -> Dict[str, Any]: "activation_fn", ["tanh", "relu", "elu", "leaky_relu"] ) activation_fn = get_activation_fn(activation_fn_name) - optimizer_class_name = trial.suggest_categorical( - "optimizer_class", ["adam", "rmsprop"] - ) + optimizer_class_name = trial.suggest_categorical("optimizer_class", ["adam"]) optimizer_class = get_optimizer_class(optimizer_class_name) return { "n_steps": n_steps, @@ -1653,9 +1650,7 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]: "activation_fn", ["tanh", "relu", "elu", "leaky_relu"] ) activation_fn = get_activation_fn(activation_fn_name) - optimizer_class_name = trial.suggest_categorical( - "optimizer_class", ["adam", "rmsprop"] - ) + optimizer_class_name = trial.suggest_categorical("optimizer_class", ["adam"]) optimizer_class = get_optimizer_class(optimizer_class_name) return { "gamma": gamma, -- 2.43.0