From 01341f6906629e8de4446245fce8031aad806135 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 18 Sep 2025 12:39:52 +0200 Subject: [PATCH] fix(reforcexy): double evaluation during hyperopt MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 50 +++++++++---------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index f978803..d146982 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -375,33 +375,31 @@ class ReforceXY(BaseReinforcementLearningModel): self.progressbar_callback = ProgressBarCallback() callbacks.append(self.progressbar_callback) - self.eval_callback = MaskableEvalCallback( - self.eval_env, - eval_freq=eval_freq, - deterministic=True, - render=False, - best_model_save_path=data_path, - use_masking=self.is_maskable, - callback_on_new_best=rollout_plot_callback, - callback_after_eval=no_improvement_callback, - verbose=verbose, - ) - callbacks.append(self.eval_callback) - if not trial: - return callbacks - - self.optuna_callback = MaskableTrialEvalCallback( - self.eval_env, - trial, - eval_freq=eval_freq, - deterministic=True, - render=False, - best_model_save_path=data_path, - use_masking=self.is_maskable, - verbose=verbose, - ) - callbacks.append(self.optuna_callback) + self.eval_callback = MaskableEvalCallback( + self.eval_env, + eval_freq=eval_freq, + deterministic=True, + render=False, + best_model_save_path=data_path, + use_masking=self.is_maskable, + callback_on_new_best=rollout_plot_callback, + callback_after_eval=no_improvement_callback, + verbose=verbose, + ) + callbacks.append(self.eval_callback) + else: + self.optuna_callback = MaskableTrialEvalCallback( + self.eval_env, + trial, + eval_freq=eval_freq, + deterministic=True, + render=False, + best_model_save_path=data_path, + use_masking=self.is_maskable, + verbose=verbose, + ) + callbacks.append(self.optuna_callback) return callbacks def fit( -- 2.43.0