From 7fbfdd93722be34db275e6d752acc5bb7c5c444f Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Fri, 27 Jun 2025 20:30:45 +0200 Subject: [PATCH] fix(qav3): properly propagate expansion_factor tunable MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../user_data/freqaimodels/QuickAdapterRegressorV3.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 1276aca..431238a 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -295,6 +295,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): test_weights, self.get_optuna_params(dk.pair, "hp"), model_training_parameters, + self._optuna_config.get("expansion_factor"), ), direction=optuna.study.StudyDirection.MINIMIZE, ) @@ -1221,6 +1222,7 @@ def get_optuna_study_model_parameters( trial: optuna.trial.Trial, regressor: str, model_training_best_parameters: dict[str, Any], + expansion_factor: float, ) -> dict[str, Any]: if regressor not in regressors: raise ValueError( @@ -1241,7 +1243,6 @@ def get_optuna_study_model_parameters( } ranges = copy.deepcopy(default_ranges) - expansion_factor = self._optuna_config.get("expansion_factor") if model_training_best_parameters: for param, (default_min, default_max) in default_ranges.items(): center_value = model_training_best_parameters.get(param) @@ -1350,9 +1351,10 @@ def hp_objective( test_weights: np.ndarray, model_training_best_parameters: dict[str, Any], model_training_parameters: dict[str, Any], + expansion_factor: float, ) -> float: study_model_parameters = get_optuna_study_model_parameters( - trial, regressor, model_training_best_parameters + trial, regressor, model_training_best_parameters, expansion_factor ) model_training_parameters = {**model_training_parameters, **study_model_parameters} -- 2.43.0