From 91732a7d9c8625228035ac612349949519cab658 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 3 Mar 2025 22:13:48 +0100 Subject: [PATCH] fix(reforcexy): parse properly policy_kwargs tunable section MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index c485ed0..fa53dbf 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -242,13 +242,13 @@ class ReforceXY(BaseReinforcementLearningModel): else: model_params["policy_kwargs"]["net_arch"] = net_arch - model_params["policy_kwargs"]["activation_fn"] = model_params[ - "policy_kwargs" - ].get("activation_fn", "relu") + model_params["policy_kwargs"]["activation_fn"] = get_activation_fn( + model_params["policy_kwargs"].get("activation_fn", "relu") + ) - model_params["policy_kwargs"]["optimizer_class"] = model_params[ - "policy_kwargs" - ].get("optimizer_class", "adam") + model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class( + model_params["policy_kwargs"].get("optimizer_class", "adam") + ) return model_params -- 2.43.0