From 26d3c4e813e1f5ae9836ac500a0ca35c9fa9386d Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 15 Sep 2025 20:41:11 +0200 Subject: [PATCH] refactor: improve a bit training reproducibility MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 3 +++ quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py | 2 +- quickadapter/user_data/strategies/Utils.py | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 1025268..6289765 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -261,6 +261,9 @@ class ReforceXY(BaseReinforcementLearningModel): model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters) + if model_params.get("seed") is None: + model_params["seed"] = 42 + if self.lr_schedule: lr = model_params.get("learning_rate", 0.0003) if isinstance(lr, (int, float)): diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 68354f4..0208770 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -290,7 +290,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): y_test = data_dictionary.get("test_labels") test_weights = data_dictionary.get("test_weights") - model_training_parameters = self.model_training_parameters + model_training_parameters = copy.deepcopy(self.model_training_parameters) start_time = time.time() if self._optuna_hyperopt: diff --git a/quickadapter/user_data/strategies/Utils.py b/quickadapter/user_data/strategies/Utils.py index 7400285..69101e3 100644 --- a/quickadapter/user_data/strategies/Utils.py +++ b/quickadapter/user_data/strategies/Utils.py @@ -743,6 +743,9 @@ def fit_regressor( init_model: Any = None, callbacks: Optional[list[Callable]] = None, ) -> Any: + if model_training_parameters.get("seed") is None: + model_training_parameters["seed"] = 1 + if regressor == "xgboost": from xgboost import XGBRegressor -- 2.43.0