]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
feat(qav3): add weighted interpolation to compute SL/TP price targets
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 14 Jun 2025 10:14:39 +0000 (12:14 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 14 Jun 2025 10:14:39 +0000 (12:14 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/strategies/QuickAdapterV3.py

index 09e104d14bfe88d6ed111bfdadf9987edcd41c09..60db753ea96c32b422d97005850489349ea1a69d 100644 (file)
@@ -515,7 +515,7 @@ class QuickAdapterV3(IStrategy):
             isna(trade_duration) or trade_duration <= 0
         )
 
-    def get_trade_interpolation_natr(
+    def get_trade_weighted_interpolation_natr(
         self, df: DataFrame, trade: Trade
     ) -> Optional[float]:
         label_natr = df.get("natr_label_period_candles")
@@ -537,7 +537,64 @@ class QuickAdapterV3(IStrategy):
         if isna(current_natr) or current_natr < 0:
             return None
         median_natr = trade_label_natr.median()
-        interpolation_values = [current_natr, median_natr, entry_natr]
+
+        entry_quantile = calculate_quantile(trade_label_natr.to_numpy(), entry_natr)
+        current_quantile = calculate_quantile(trade_label_natr.to_numpy(), current_natr)
+        median_quantile = calculate_quantile(trade_label_natr.to_numpy(), median_natr)
+
+        if isna(entry_quantile) or isna(current_quantile) or isna(median_quantile):
+            return None
+
+        def calculate_weight(
+            quantile: float,
+            min_weight: float = 0.0,
+            max_weight: float = 1.0,
+            steepness: float = 1.5,
+        ) -> float:
+            normalized_distance_from_center = abs(quantile - 0.5) * 2.0
+            return (
+                min_weight
+                + (max_weight - min_weight) * normalized_distance_from_center**steepness
+            )
+
+        entry_weight = calculate_weight(entry_quantile)
+        current_weight = calculate_weight(current_quantile)
+        median_weight = calculate_weight(median_quantile)
+
+        total_weight = entry_weight + current_weight + median_weight
+        if np.isclose(total_weight, 0):
+            return None
+        entry_weight /= total_weight
+        current_weight /= total_weight
+        median_weight /= total_weight
+
+        return (
+            entry_natr * entry_weight
+            + current_natr * current_weight
+            + median_natr * median_weight
+        )
+
+    def get_trade_interpolation_natr(
+        self, df: DataFrame, trade: Trade
+    ) -> Optional[float]:
+        label_natr = df.get("natr_label_period_candles")
+        if label_natr is None or label_natr.empty:
+            return None
+        dates = df.get("date")
+        if dates is None or dates.empty:
+            return None
+        entry_date = self.get_trade_entry_date(trade)
+        trade_label_natr = label_natr[dates >= entry_date]
+        if trade_label_natr.empty:
+            return None
+        entry_natr = trade_label_natr.iloc[0]
+        if isna(entry_natr) or entry_natr < 0:
+            return None
+        if len(trade_label_natr) == 1:
+            return entry_natr
+        current_natr = trade_label_natr.iloc[-1]
+        if isna(current_natr) or current_natr < 0:
+            return None
         trade_volatility_quantile = calculate_quantile(
             trade_label_natr.to_numpy(), entry_natr
         )
@@ -545,8 +602,8 @@ class QuickAdapterV3(IStrategy):
             return None
         return np.interp(
             trade_volatility_quantile,
-            np.linspace(0.0, 1.0, len(interpolation_values)),
-            interpolation_values,
+            [0.0, 1.0],
+            [current_natr, entry_natr],
         )
 
     def get_trade_moving_average_natr(
@@ -582,13 +639,15 @@ class QuickAdapterV3(IStrategy):
         )
         if trade_price_target == "interpolation":
             return self.get_trade_interpolation_natr(df, trade)
+        elif trade_price_target == "weighted_interpolation":
+            return self.get_trade_weighted_interpolation_natr(df, trade)
         elif trade_price_target == "moving_average":
             return self.get_trade_moving_average_natr(
                 df, trade.pair, trade_duration_candles
             )
         else:
             raise ValueError(
-                f"Invalid trade_price_target: {trade_price_target}. Expected 'interpolation' or 'moving_average'."
+                f"Invalid trade_price_target: {trade_price_target}. Expected 'interpolation', 'weighted_interpolation' or 'moving_average'."
             )
 
     def get_stoploss_distance(