]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(qav3): cleanup data smoothing code
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 14 Mar 2025 10:04:38 +0000 (11:04 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 14 Mar 2025 10:04:38 +0000 (11:04 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 291f3fefe03b58fdda30c5355984c7680a8a0904..b02c148ba4c99dceaf89060a7437c863bc7e4bc8 100644 (file)
@@ -262,13 +262,14 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
         prediction_thresholds_smoothing = self.freqai_info.get(
             "prediction_thresholds_smoothing", "mean"
         )
-        return {
+        smoothing_methods: dict = {
             "quantile": self.quantile_min_max_pred,
             "mean": mean_min_max_pred,
             "median": median_min_max_pred,
-        }.get(prediction_thresholds_smoothing, mean_min_max_pred)(
-            pred_df, fit_live_predictions_candles, label_period_candles
-        )
+        }
+        return smoothing_methods.get(
+            prediction_thresholds_smoothing, smoothing_methods["mean"]
+        )(pred_df, fit_live_predictions_candles, label_period_candles)
 
     def optuna_hp_enqueue_previous_best_trial(
         self,
index 3aa9abb8b87ed8630a943f564ec9d2010f8e23d1..54f5c4f3aa9a5fe4b3130cfb29445ad24f233e46 100644 (file)
@@ -265,13 +265,14 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
         prediction_thresholds_smoothing = self.freqai_info.get(
             "prediction_thresholds_smoothing", "mean"
         )
-        return {
+        smoothing_methods: dict = {
             "quantile": self.quantile_min_max_pred,
             "mean": mean_min_max_pred,
             "median": median_min_max_pred,
-        }.get(prediction_thresholds_smoothing, mean_min_max_pred)(
-            pred_df, fit_live_predictions_candles, label_period_candles
-        )
+        }
+        return smoothing_methods.get(
+            prediction_thresholds_smoothing, smoothing_methods["mean"]
+        )(pred_df, fit_live_predictions_candles, label_period_candles)
 
     def optuna_hp_enqueue_previous_best_trial(
         self,
index 3b2d02afed49900a2e05809bb79f6df9df9f7b09..8235af639f786bdb61f201225e09ff611e20b688 100644 (file)
@@ -97,7 +97,7 @@ class QuickAdapterV3(IStrategy):
         }
 
     @property
-    def protections(self) -> list:
+    def protections(self) -> list[dict]:
         fit_live_predictions_candles = self.freqai_info.get(
             "fit_live_predictions_candles", 100
         )
@@ -166,7 +166,7 @@ class QuickAdapterV3(IStrategy):
         dataframe["%-pct-change"] = dataframe["close"].pct_change()
         dataframe["%-raw_volume"] = dataframe["volume"]
         dataframe["%-obv"] = ta.OBV(dataframe)
-        dataframe["%-ewo"] = EWO(dataframe=dataframe, mode="zlewma", normalize=True)
+        dataframe["%-ewo"] = EWO(dataframe=dataframe, ma_mode="zlewma", normalize=True)
         psar = ta.SAR(
             dataframe["high"], dataframe["low"], acceleration=0.02, maximum=0.2
         )
@@ -299,11 +299,10 @@ class QuickAdapterV3(IStrategy):
             1,
         )
 
-        if "label_period_candles" in dataframe.columns:
-            pair = str(metadata.get("pair"))
-            self.__period_params[pair]["label_period_candles"] = dataframe[
-                "label_period_candles"
-            ].iloc[-1]
+        pair = str(metadata.get("pair"))
+        self.__period_params[pair]["label_period_candles"] = dataframe[
+            "label_period_candles"
+        ].iloc[-1]
 
         dataframe["minima_threshold"] = dataframe[MINIMA_THRESHOLD_COLUMN]
         dataframe["maxima_threshold"] = dataframe[MAXIMA_THRESHOLD_COLUMN]
@@ -433,32 +432,27 @@ class QuickAdapterV3(IStrategy):
         std: float = 0.5,
     ) -> Series:
         extrema_smoothing = self.freqai_info.get("extrema_smoothing", "gaussian")
-        return {
-            "gaussian": (
-                series.rolling(
-                    window=get_gaussian_window(std, True),
-                    win_type="gaussian",
-                    center=True,
-                ).mean(std=std)
-            ),
+        smoothing_methods: dict = {
+            "gaussian": series.rolling(
+                window=get_gaussian_window(std, True),
+                win_type="gaussian",
+                center=True,
+            ).mean(std=std),
             "zero_phase_gaussian": zero_phase_gaussian(series=series, std=std),
             "boxcar": series.rolling(
                 window=get_odd_window(window), win_type="boxcar", center=True
             ).mean(),
-            "triang": (
-                series.rolling(
-                    window=get_odd_window(window), win_type="triang", center=True
-                ).mean()
-            ),
+            "triang": series.rolling(
+                window=get_odd_window(window), win_type="triang", center=True
+            ).mean(),
             "smm": series.rolling(window=get_odd_window(window), center=True).median(),
             "sma": series.rolling(window=get_odd_window(window), center=True).mean(),
             "ewma": series.ewm(span=window).mean(),
             "zlewma": zlewma(series=series, timeperiod=window),
-        }.get(
+        }
+        return smoothing_methods.get(
             extrema_smoothing,
-            series.rolling(
-                window=get_gaussian_window(std, True), win_type="gaussian", center=True
-            ).mean(std=std),
+            smoothing_methods["gaussian"],
         )
 
     def load_period_best_params(self, pair: str) -> dict | None:
@@ -498,9 +492,9 @@ def VWAPB(dataframe: DataFrame, window=20, num_of_std=1) -> tuple:
 
 
 def EWO(
-    dataframe: DataFrame, ma1_length=5, ma2_length=34, mode="sma", normalize=False
+    dataframe: DataFrame, ma1_length=5, ma2_length=34, ma_mode="sma", normalize=False
 ) -> Series:
-    ma_fn = {
+    ma_modes: dict = {
         "sma": ta.SMA,
         "ema": ta.EMA,
         "wma": ta.WMA,
@@ -510,7 +504,8 @@ def EWO(
         "trima": ta.TRIMA,
         "kama": ta.KAMA,
         "t3": ta.T3,
-    }.get(mode, ta.SMA)
+    }
+    ma_fn = ma_modes.get(ma_mode, ma_modes["sma"])
     ma1 = ma_fn(dataframe, timeperiod=ma1_length)
     ma2 = ma_fn(dataframe, timeperiod=ma2_length)
     madiff = ma1 - ma2