From fca4568bd01187331aec36a3aef7e04cfef337fa Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Tue, 4 Mar 2025 11:04:52 +0100 Subject: [PATCH] perf(reforcexy): reduce optuna search space MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index e6559be..492b9c9 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -1453,7 +1453,6 @@ def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]: return { "adam": th.optim.Adam, "rmsprop": th.optim.RMSprop, - "sgd": th.optim.SGD, }[optimizer_class_name] @@ -1489,7 +1488,7 @@ def sample_params_ppo(trial: Trial) -> Dict[str, Any]: ) activation_fn = get_activation_fn(activation_fn_name) optimizer_class_name = trial.suggest_categorical( - "optimizer_class", ["adam", "rmsprop", "sgd"] + "optimizer_class", ["adam", "rmsprop"] ) optimizer_class = get_optimizer_class(optimizer_class_name) return { @@ -1549,7 +1548,7 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]: ) activation_fn = get_activation_fn(activation_fn_name) optimizer_class_name = trial.suggest_categorical( - "optimizer_class", ["adam", "rmsprop", "sgd"] + "optimizer_class", ["adam", "rmsprop"] ) optimizer_class = get_optimizer_class(optimizer_class_name) return { -- 2.43.0