From e916ec810385755633087449c35c0c6738ea14d0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 3 Mar 2025 14:43:00 +0100 Subject: [PATCH] fix(reforcexy): ensure policy_kwargs dict is initialized MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index c180825..456fbcf 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -231,6 +231,8 @@ class ReforceXY(BaseReinforcementLearningModel): logger.info("Clip range linear schedule enabled, initial value: %s", _cr) net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128]) + if not model_params.get("policy_kwargs"): + model_params["policy_kwargs"] = {} model_params["policy_kwargs"].update( { "net_arch": net_arch, -- 2.43.0