From 84f4eabcd091fb138abd83635ce9ba52b3779f8b Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Fri, 11 Apr 2025 17:40:00 +0200 Subject: [PATCH] refactor(qav3): align optuna results handling namespace MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../freqaimodels/QuickAdapterRegressorV3.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index f6878e5..414610d 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -86,15 +86,15 @@ class QuickAdapterRegressorV3(BaseRegressionModel): and self._optuna_config.get("enabled") and self.data_split_parameters.get("test_size", TEST_SIZE) > 0 ) - self._optuna_hp_rmse: dict[str, float] = {} - self._optuna_train_rmse: dict[str, float] = {} + self._optuna_hp_value: dict[str, float] = {} + self._optuna_train_value: dict[str, float] = {} self._optuna_label_values: dict[str, dict] = {} self._optuna_hp_params: dict[str, dict] = {} self._optuna_train_params: dict[str, dict] = {} self._optuna_label_params: dict[str, dict] = {} for pair in self.pairs: - self._optuna_hp_rmse[pair] = -1 - self._optuna_train_rmse[pair] = -1 + self._optuna_hp_value[pair] = -1 + self._optuna_train_value[pair] = -1 self._optuna_label_values[pair] = [-1, -1] self._optuna_hp_params[pair] = ( self.optuna_load_best_params(pair, "hp") @@ -141,20 +141,20 @@ class QuickAdapterRegressorV3(BaseRegressionModel): else: raise ValueError(f"Invalid namespace: {namespace}") - def get_optuna_rmse(self, pair: str, namespace: str) -> float: + def get_optuna_value(self, pair: str, namespace: str) -> float: if namespace == "hp": - rmse = self._optuna_hp_rmse.get(pair) + rmse = self._optuna_hp_value.get(pair) elif namespace == "train": - rmse = self._optuna_train_rmse.get(pair) + rmse = self._optuna_train_value.get(pair) else: raise ValueError(f"Invalid namespace: {namespace}") return rmse - def set_optuna_rmse(self, pair: str, namespace: str, rmse: float) -> None: + def set_optuna_value(self, pair: str, namespace: str, value: float) -> None: if namespace == "hp": - self._optuna_hp_rmse[pair] = rmse + self._optuna_hp_value[pair] = value elif namespace == "train": - self._optuna_train_rmse[pair] = rmse + self._optuna_train_value[pair] = value else: raise ValueError(f"Invalid namespace: {namespace}") @@ -351,8 +351,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel): pair, "label" ).get("label_natr_ratio") - dk.data["extra_returns_per_train"]["hp_rmse"] = self.get_optuna_rmse(pair, "hp") - dk.data["extra_returns_per_train"]["train_rmse"] = self.get_optuna_rmse( + dk.data["extra_returns_per_train"]["hp_rmse"] = self.get_optuna_value( + pair, "hp" + ) + dk.data["extra_returns_per_train"]["train_rmse"] = self.get_optuna_value( pair, "train" ) @@ -488,10 +490,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel): f"Optuna {pair} {namespace} {objective_type} hyperopt failed ({time_spent:.2f} secs): no study best trial found" ) return - self.set_optuna_rmse(pair, namespace, study.best_value) + self.set_optuna_value(pair, namespace, study.best_value) self.set_optuna_params(pair, namespace, study.best_params) study_results = { - "rmse": self.get_optuna_rmse(pair, namespace), + "value": self.get_optuna_value(pair, namespace), **self.get_optuna_params(pair, namespace), } else: -- 2.43.0