From: Jérôme Benoit Date: Mon, 15 Sep 2025 18:41:11 +0000 (+0200) Subject: refactor: improve a bit training reproducibility X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=26d3c4e813e1f5ae9836ac500a0ca35c9fa9386d;p=freqai-strategies.git refactor: improve a bit training reproducibility Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 1025268..6289765 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -261,6 +261,9 @@ class ReforceXY(BaseReinforcementLearningModel): model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters) + if model_params.get("seed") is None: + model_params["seed"] = 42 + if self.lr_schedule: lr = model_params.get("learning_rate", 0.0003) if isinstance(lr, (int, float)): diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 68354f4..0208770 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -290,7 +290,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): y_test = data_dictionary.get("test_labels") test_weights = data_dictionary.get("test_weights") - model_training_parameters = self.model_training_parameters + model_training_parameters = copy.deepcopy(self.model_training_parameters) start_time = time.time() if self._optuna_hyperopt: diff --git a/quickadapter/user_data/strategies/Utils.py b/quickadapter/user_data/strategies/Utils.py index 7400285..69101e3 100644 --- a/quickadapter/user_data/strategies/Utils.py +++ b/quickadapter/user_data/strategies/Utils.py @@ -743,6 +743,9 @@ def fit_regressor( init_model: Any = None, callbacks: Optional[list[Callable]] = None, ) -> Any: + if model_training_parameters.get("seed") is None: + model_training_parameters["seed"] = 1 + if regressor == "xgboost": from xgboost import XGBRegressor