From 82ddd80b807d63e91f64706184f8fb5077342fe3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 10 Feb 2025 22:56:56 +0100 Subject: [PATCH] fix(qav3): filter prediction dataframe MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 20 +++++++++++++------ .../XGBoostRegressorQuickAdapterV35.py | 20 +++++++++++++------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index a02cb20..fa7b436 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -242,11 +242,16 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): def min_max_pred( pred_df: pd.DataFrame, fit_live_predictions_candles: int, label_period_candles: int ): + local_pred_df = pd.DataFrame() + for label in pred_df: + if pred_df[label].dtype == object: + continue + local_pred_df[label] = pred_df[label] beta = 10.0 - min_pred = pred_df.tail(label_period_candles).apply( + min_pred = local_pred_df.tail(label_period_candles).apply( lambda col: smooth_min(col, beta=beta) ) - max_pred = pred_df.tail(label_period_candles).apply( + max_pred = local_pred_df.tail(label_period_candles).apply( lambda col: smooth_max(col, beta=beta) ) @@ -256,10 +261,13 @@ def min_max_pred( def __min_max_pred( pred_df: pd.DataFrame, fit_live_predictions_candles: int, label_period_candles: int ): - pred_df_sorted = ( - pred_df.select_dtypes(exclude=["object"]) - .copy() - .apply(lambda col: col.sort_values(ascending=False, ignore_index=True)) + pred_df_sorted = pd.DataFrame() + for label in pred_df: + if pred_df[label].dtype == object: + continue + pred_df_sorted[label] = pred_df[label] + pred_df_sorted = pred_df_sorted.apply( + lambda col: col.sort_values(ascending=False, ignore_index=True) ) frequency = fit_live_predictions_candles / label_period_candles diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 22c7776..bfe6062 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -242,11 +242,16 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): def min_max_pred( pred_df: pd.DataFrame, fit_live_predictions_candles: int, label_period_candles: int ): + local_pred_df = pd.DataFrame() + for label in pred_df: + if pred_df[label].dtype == object: + continue + local_pred_df[label] = pred_df[label] beta = 10.0 - min_pred = pred_df.tail(label_period_candles).apply( + min_pred = local_pred_df.tail(label_period_candles).apply( lambda col: smooth_min(col, beta=beta) ) - max_pred = pred_df.tail(label_period_candles).apply( + max_pred = local_pred_df.tail(label_period_candles).apply( lambda col: smooth_max(col, beta=beta) ) @@ -256,10 +261,13 @@ def min_max_pred( def __min_max_pred( pred_df: pd.DataFrame, fit_live_predictions_candles: int, label_period_candles: int ): - pred_df_sorted = ( - pred_df.select_dtypes(exclude=["object"]) - .copy() - .apply(lambda col: col.sort_values(ascending=False, ignore_index=True)) + pred_df_sorted = pd.DataFrame() + for label in pred_df: + if pred_df[label].dtype == object: + continue + pred_df_sorted[label] = pred_df[label] + pred_df_sorted = pred_df_sorted.apply( + lambda col: col.sort_values(ascending=False, ignore_index=True) ) frequency = fit_live_predictions_candles / label_period_candles -- 2.43.0