From: Jérôme Benoit Date: Thu, 13 Mar 2025 18:41:38 +0000 (+0100) Subject: refactor(qav3): cleanup label_period_candles handling X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=d82d495c85cf50364b9454e6a88cd557d34a5e87;p=freqai-strategies.git refactor(qav3): cleanup label_period_candles handling Signed-off-by: Jérôme Benoit --- diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index 35f40b6..291f3fe 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -152,6 +152,11 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): return model + def get_label_period_candles(self, pair: str) -> int: + if self.__optuna_period_params.get(pair, {}).get("label_period_candles"): + return self.__optuna_period_params[pair]["label_period_candles"] + return self.ft_params["label_period_candles"] + def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None: warmed_up = True @@ -176,9 +181,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = -2 dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = 2 else: - label_period_candles = self.__optuna_period_params.get(pair, {}).get( - "label_period_candles", self.ft_params["label_period_candles"] - ) + label_period_candles = self.get_label_period_candles(pair) min_pred, max_pred = self.min_max_pred( pred_df_full, num_candles, @@ -217,9 +220,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff dk.data["extra_returns_per_train"]["label_period_candles"] = ( - self.__optuna_period_params.get(pair, {}).get( - "label_period_candles", self.ft_params["label_period_candles"] - ) + self.get_label_period_candles(pair) ) dk.data["extra_returns_per_train"]["hp_rmse"] = self.__optuna_hp_rmse.get( pair, -1 diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 29a5efa..3aa9abb 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -155,6 +155,11 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): return model + def get_label_period_candles(self, pair: str) -> int: + if self.__optuna_period_params.get(pair, {}).get("label_period_candles"): + return self.__optuna_period_params[pair]["label_period_candles"] + return self.ft_params["label_period_candles"] + def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None: warmed_up = True @@ -179,9 +184,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = -2 dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = 2 else: - label_period_candles = self.__optuna_period_params.get(pair, {}).get( - "label_period_candles", self.ft_params["label_period_candles"] - ) + label_period_candles = self.get_label_period_candles(pair) min_pred, max_pred = self.min_max_pred( pred_df_full, num_candles, @@ -220,9 +223,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff dk.data["extra_returns_per_train"]["label_period_candles"] = ( - self.__optuna_period_params.get(pair, {}).get( - "label_period_candles", self.ft_params["label_period_candles"] - ) + self.get_label_period_candles(pair) ) dk.data["extra_returns_per_train"]["hp_rmse"] = self.__optuna_hp_rmse.get( pair, -1 diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index eea84f8..3b2d02a 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -265,24 +265,13 @@ class QuickAdapterV3(IStrategy): dataframe["%-hour_of_day"] = (dataframe["date"].dt.hour + 1) / 25 return dataframe - def get_label_period_candles(self, metadata, **kwargs) -> int: - pair = str(metadata.get("pair")) + def get_label_period_candles(self, pair: str) -> int: if self.__period_params.get(pair, {}).get("label_period_candles"): - label_period_candles = self.__period_params.get(pair, {}).get( - "label_period_candles", - ) - else: - label_period_candles = self.freqai_info["feature_parameters"][ - "label_period_candles" - ] - if label_period_candles < 1: - raise ValueError( - f"label_period_candles must be greater than 0, got {label_period_candles}" - ) - return label_period_candles + return self.__period_params[pair]["label_period_candles"] + return self.freqai_info["feature_parameters"]["label_period_candles"] def set_freqai_targets(self, dataframe, metadata, **kwargs): - label_period_candles = self.get_label_period_candles(metadata, **kwargs) + label_period_candles = self.get_label_period_candles(str(metadata.get("pair"))) min_peaks = argrelmin( dataframe["low"].values, order=label_period_candles,