From d82d495c85cf50364b9454e6a88cd557d34a5e87 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 13 Mar 2025 19:41:38 +0100 Subject: [PATCH] refactor(qav3): cleanup label_period_candles handling MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 13 +++++++------ .../XGBoostRegressorQuickAdapterV35.py | 13 +++++++------ .../user_data/strategies/QuickAdapterV3.py | 19 ++++--------------- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index 35f40b6..291f3fe 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -152,6 +152,11 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): return model + def get_label_period_candles(self, pair: str) -> int: + if self.__optuna_period_params.get(pair, {}).get("label_period_candles"): + return self.__optuna_period_params[pair]["label_period_candles"] + return self.ft_params["label_period_candles"] + def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None: warmed_up = True @@ -176,9 +181,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = -2 dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = 2 else: - label_period_candles = self.__optuna_period_params.get(pair, {}).get( - "label_period_candles", self.ft_params["label_period_candles"] - ) + label_period_candles = self.get_label_period_candles(pair) min_pred, max_pred = self.min_max_pred( pred_df_full, num_candles, @@ -217,9 +220,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff dk.data["extra_returns_per_train"]["label_period_candles"] = ( - self.__optuna_period_params.get(pair, {}).get( - "label_period_candles", self.ft_params["label_period_candles"] - ) + self.get_label_period_candles(pair) ) dk.data["extra_returns_per_train"]["hp_rmse"] = self.__optuna_hp_rmse.get( pair, -1 diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 29a5efa..3aa9abb 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -155,6 +155,11 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): return model + def get_label_period_candles(self, pair: str) -> int: + if self.__optuna_period_params.get(pair, {}).get("label_period_candles"): + return self.__optuna_period_params[pair]["label_period_candles"] + return self.ft_params["label_period_candles"] + def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None: warmed_up = True @@ -179,9 +184,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = -2 dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = 2 else: - label_period_candles = self.__optuna_period_params.get(pair, {}).get( - "label_period_candles", self.ft_params["label_period_candles"] - ) + label_period_candles = self.get_label_period_candles(pair) min_pred, max_pred = self.min_max_pred( pred_df_full, num_candles, @@ -220,9 +223,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff dk.data["extra_returns_per_train"]["label_period_candles"] = ( - self.__optuna_period_params.get(pair, {}).get( - "label_period_candles", self.ft_params["label_period_candles"] - ) + self.get_label_period_candles(pair) ) dk.data["extra_returns_per_train"]["hp_rmse"] = self.__optuna_hp_rmse.get( pair, -1 diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index eea84f8..3b2d02a 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -265,24 +265,13 @@ class QuickAdapterV3(IStrategy): dataframe["%-hour_of_day"] = (dataframe["date"].dt.hour + 1) / 25 return dataframe - def get_label_period_candles(self, metadata, **kwargs) -> int: - pair = str(metadata.get("pair")) + def get_label_period_candles(self, pair: str) -> int: if self.__period_params.get(pair, {}).get("label_period_candles"): - label_period_candles = self.__period_params.get(pair, {}).get( - "label_period_candles", - ) - else: - label_period_candles = self.freqai_info["feature_parameters"][ - "label_period_candles" - ] - if label_period_candles < 1: - raise ValueError( - f"label_period_candles must be greater than 0, got {label_period_candles}" - ) - return label_period_candles + return self.__period_params[pair]["label_period_candles"] + return self.freqai_info["feature_parameters"]["label_period_candles"] def set_freqai_targets(self, dataframe, metadata, **kwargs): - label_period_candles = self.get_label_period_candles(metadata, **kwargs) + label_period_candles = self.get_label_period_candles(str(metadata.get("pair"))) min_peaks = argrelmin( dataframe["low"].values, order=label_period_candles, -- 2.43.0