From 7998113e658c5f4d5e6631b19c794bd5c0b11839 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sun, 14 Sep 2025 23:48:03 +0200 Subject: [PATCH] perf(reforcexy): default optuna params for PPO to AdamW optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 810169b..5bb3080 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -1733,13 +1733,14 @@ def get_activation_fn( def get_optimizer_class( - optimizer_class_name: Literal["adam"], + optimizer_class_name: Literal["adam", "adamw"], ) -> type[th.optim.Optimizer]: """ Get optimizer class """ return { "adam": th.optim.Adam, + "adamw": th.optim.AdamW, }.get(optimizer_class_name, th.optim.Adam) @@ -1873,7 +1874,7 @@ def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]: "activation_fn": trial.suggest_categorical( "activation_fn", ["tanh", "relu", "elu", "leaky_relu"] ), - "optimizer_class": trial.suggest_categorical("optimizer_class", ["adam"]), + "optimizer_class": trial.suggest_categorical("optimizer_class", ["adamw"]), }, ) -- 2.43.0