]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(qav3): make thresholding computation API more generic
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 15 Aug 2025 20:15:24 +0000 (22:15 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 15 Aug 2025 20:15:24 +0000 (22:15 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/strategies/QuickAdapterV3.py

index 7e7a213e3779dcf22e39e7577f1cf60d11a9d75a..af173955e548347b9ebd48b26fe8ba8f26376b44 100644 (file)
@@ -1004,72 +1004,97 @@ class QuickAdapterV3(IStrategy):
     def weighted_close(series: Series) -> float:
         return (series.get("high") + series.get("low") + 2 * series.get("close")) / 4.0
 
-    def _calculate_current_deviation(
+    def _calculate_candle_deviation(
         self,
         df: DataFrame,
         pair: str,
         min_natr_ratio_percent: float,
         max_natr_ratio_percent: float,
+        candle_idx: int = -1,
         interpolation_direction: Literal["direct", "inverse"] = "direct",
         quantile_exponent: float = 1.5,
     ) -> Optional[float]:
-        label_natr_values = df.get("natr_label_period_candles").to_numpy()
+        label_natr_series = df.get("natr_label_period_candles")
+        if label_natr_series is None or label_natr_series.empty:
+            return None
+
+        n = len(label_natr_series)
+        if candle_idx < 0:
+            candle_idx = n + candle_idx
+        candle_idx = max(0, min(candle_idx, n - 1))
+
+        label_natr_values = label_natr_series.iloc[: candle_idx + 1].to_numpy()
+        if label_natr_values.size == 0:
+            return None
+        candle_label_natr_value = label_natr_values[-1]
+        if isna(candle_label_natr_value) or candle_label_natr_value < 0:
+            return None
         label_period_candles = self.get_label_period_candles(pair)
-        last_label_natr_value = label_natr_values[-1]
-        last_label_natr_value_quantile = calculate_quantile(
-            label_natr_values[-label_period_candles:], last_label_natr_value
+        candle_label_natr_value_quantile = calculate_quantile(
+            label_natr_values[-label_period_candles:], candle_label_natr_value
         )
-        if isna(last_label_natr_value_quantile):
-            last_label_natr_value_quantile = 0.5
+        if isna(candle_label_natr_value_quantile):
+            return None
+
         if interpolation_direction == "direct":
             natr_ratio_percent = (
                 min_natr_ratio_percent
                 + (max_natr_ratio_percent - min_natr_ratio_percent)
-                * last_label_natr_value_quantile**quantile_exponent
+                * candle_label_natr_value_quantile**quantile_exponent
             )
         elif interpolation_direction == "inverse":
             natr_ratio_percent = (
                 max_natr_ratio_percent
                 - (max_natr_ratio_percent - min_natr_ratio_percent)
-                * last_label_natr_value_quantile**quantile_exponent
+                * candle_label_natr_value_quantile**quantile_exponent
             )
         else:
             raise ValueError(
                 f"Invalid interpolation_direction: {interpolation_direction}. Expected 'direct' or 'inverse'"
             )
-        return (last_label_natr_value / 100.0) * self.get_label_natr_ratio_percent(
+        return (candle_label_natr_value / 100.0) * self.get_label_natr_ratio_percent(
             pair, natr_ratio_percent
         )
 
-    def calculate_current_threshold(self, df: DataFrame, pair: str, side: str) -> float:
-        current_deviation = self._calculate_current_deviation(
+    def calculate_candle_threshold(
+        self, df: DataFrame, pair: str, side: str, candle_idx: int = -1
+    ) -> float:
+        current_deviation = self._calculate_candle_deviation(
             df,
             pair,
             min_natr_ratio_percent=0.00999,
             max_natr_ratio_percent=0.099,
+            candle_idx=candle_idx,
             interpolation_direction="direct",
         )
-        if isna(current_deviation):
-            return np.inf if side == "short" else -np.inf
+        if isna(current_deviation) or current_deviation <= 0:
+            return np.nan
 
-        last_candle = df.iloc[-1]
-        last_candle_close = last_candle.get("close")
-        last_candle_open = last_candle.get("open")
-        is_last_candle_bullish = last_candle_close > last_candle_open
-        is_last_candle_bearish = last_candle_close < last_candle_open
+        n = len(df)
+        if candle_idx < 0:
+            candle_idx = n + candle_idx
+        candle_idx = max(0, min(candle_idx, n - 1))
+
+        candle = df.iloc[candle_idx]
+        candle_close = candle.get("close")
+        candle_open = candle.get("open")
+        if isna(candle_close) or isna(candle_open):
+            return np.nan
+        is_candle_bullish: bool = candle_close > candle_open
+        is_candle_bearish: bool = candle_close < candle_open
 
         if side == "long":
             base_price = (
-                QuickAdapterV3.weighted_close(last_candle)
-                if is_last_candle_bearish
-                else last_candle_close
+                QuickAdapterV3.weighted_close(candle)
+                if is_candle_bearish
+                else candle_close
             )
             return base_price * (1 + current_deviation)
         elif side == "short":
             base_price = (
-                QuickAdapterV3.weighted_close(last_candle)
-                if is_last_candle_bullish
-                else last_candle_close
+                QuickAdapterV3.weighted_close(candle)
+                if is_candle_bullish
+                else candle_close
             )
             return base_price * (1 - current_deviation)
 
@@ -1248,7 +1273,7 @@ class QuickAdapterV3(IStrategy):
             and last_candle.get("do_predict") == 1
             and last_candle.get("DI_catch") == 1
             and last_candle.get(EXTREMA_COLUMN) < last_candle.get("minima_threshold")
-            and current_rate > self.calculate_current_threshold(df, pair, "long")
+            and current_rate > self.calculate_candle_threshold(df, pair, "long")
         ):
             return "minima_detected_short"
         if (
@@ -1256,7 +1281,7 @@ class QuickAdapterV3(IStrategy):
             and last_candle.get("do_predict") == 1
             and last_candle.get("DI_catch") == 1
             and last_candle.get(EXTREMA_COLUMN) > last_candle.get("maxima_threshold")
-            and current_rate < self.calculate_current_threshold(df, pair, "short")
+            and current_rate < self.calculate_candle_threshold(df, pair, "short")
         ):
             return "maxima_detected_long"
 
@@ -1367,7 +1392,7 @@ class QuickAdapterV3(IStrategy):
         )
         if df.empty:
             return False
-        current_threshold = self.calculate_current_threshold(df, pair, side)
+        current_threshold = self.calculate_candle_threshold(df, pair, side)
         if (side == "long" and rate > current_threshold) or (
             side == "short" and rate < current_threshold
         ):