]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(qav3): add caching at pivot confirmation
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 24 May 2025 13:05:11 +0000 (15:05 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 24 May 2025 13:05:11 +0000 (15:05 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py
quickadapter/user_data/strategies/Utils.py

index 0c7a6b0a16e6b4b205174a12d783e9746752a79b..1aea149411d215301b768ab91716b6fcc4354bbd 100644 (file)
@@ -45,7 +45,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     https://github.com/sponsors/robcaulk
     """
 
-    version = "3.7.58"
+    version = "3.7.59"
 
     @cached_property
     def _optuna_config(self) -> dict:
@@ -913,29 +913,33 @@ def zigzag(
     candidate_pivot_value = np.nan
     candidate_pivot_direction: TrendDirection = TrendDirection.NEUTRAL
 
-    def volatility_quantile(pos: int) -> float:
-        start = max(0, pos + 1 - natr_period)
-        end = min(pos + 1, n)
-        if start >= end:
-            return np.nan
+    volatility_quantile_cache: dict[int, float] = {}
 
-        natr_values = get_natr_values(natr_period)
-        lookback_natr_values = natr_values[start:end]
-        quantile = calculate_quantile(lookback_natr_values, natr_values[pos])
+    def calculate_volatility_quantile(pos: int) -> float:
+        if pos not in volatility_quantile_cache:
+            start = max(0, pos + 1 - natr_period)
+            end = min(pos + 1, n)
+            if start >= end:
+                volatility_quantile_cache[pos] = np.nan
+            else:
+                natr_values = get_natr_values(natr_period)
+                volatility_quantile_cache[pos] = calculate_quantile(
+                    natr_values[start:end], natr_values[pos]
+                )
 
-        return quantile
+        return volatility_quantile_cache[pos]
 
     def calculate_confirmation_window(
         pos: int,
         min_window: int = min_confirmation_window,
         max_window: int = max_confirmation_window,
     ) -> int:
-        quantile = volatility_quantile(pos)
-        if np.isnan(quantile):
+        volatility_quantile = calculate_volatility_quantile(pos)
+        if np.isnan(volatility_quantile):
             return int(round(np.median([min_window, max_window])))
 
         return np.clip(
-            round(max_window - (max_window - min_window) * quantile),
+            round(max_window - (max_window - min_window) * volatility_quantile),
             min_window,
             max_window,
         ).astype(int)
@@ -945,12 +949,12 @@ def zigzag(
         min_depth: int = 6,
         max_depth: int = 24,
     ) -> int:
-        quantile = volatility_quantile(pos)
-        if np.isnan(quantile):
+        volatility_quantile = calculate_volatility_quantile(pos)
+        if np.isnan(volatility_quantile):
             return int(round(np.median([min_depth, max_depth])))
 
         return np.clip(
-            round(max_depth - (max_depth - min_depth) * quantile),
+            round(max_depth - (max_depth - min_depth) * volatility_quantile),
             min_depth,
             max_depth,
         ).astype(int)
@@ -960,11 +964,11 @@ def zigzag(
         min_strength: float = 1.0,
         max_strength: float = 1.5,
     ) -> float:
-        quantile = volatility_quantile(pos)
-        if np.isnan(quantile):
+        volatility_quantile = calculate_volatility_quantile(pos)
+        if np.isnan(volatility_quantile):
             return np.median([min_strength, max_strength])
 
-        return min_strength + (max_strength - min_strength) * quantile
+        return min_strength + (max_strength - min_strength) * volatility_quantile
 
     def update_candidate_pivot(pos: int, value: float, direction: TrendDirection):
         nonlocal candidate_pivot_pos, candidate_pivot_value, candidate_pivot_direction
index b2d992193b0bcdfb4cda7e755e97782bc50c78db..0c50b4396df9af39c51281911b16bccbd85e6de3 100644 (file)
@@ -19,7 +19,7 @@ from Utils import (
     alligator,
     bottom_change_percent,
     get_ma_fn,
-    zero_lag_series,
+    zero_lag,
     zigzag,
     ewo,
     non_zero_diff,
@@ -60,7 +60,7 @@ class QuickAdapterV3(IStrategy):
     INTERFACE_VERSION = 3
 
     def version(self) -> str:
-        return "3.3.60"
+        return "3.3.61"
 
     timeframe = "5m"
 
@@ -474,32 +474,16 @@ class QuickAdapterV3(IStrategy):
     def is_trade_duration_valid(trade_duration: float) -> bool:
         return not (isna(trade_duration) or trade_duration <= 0)
 
-    def get_stoploss_distance(
-        self, df: DataFrame, trade: Trade, current_rate: float
-    ) -> Optional[float]:
-        trade_duration_candles = QuickAdapterV3.get_trade_duration_candles(df, trade)
-        if not QuickAdapterV3.is_trade_duration_valid(trade_duration_candles):
-            return None
-        current_natr = df["natr_label_period_candles"].iloc[-1]
-        if isna(current_natr) or current_natr < 0:
-            return None
-        return (
-            current_rate
-            * (current_natr / 100.0)
-            * self.get_stoploss_natr_ratio(trade.pair)
-            * (1 / math.log10(3.75 + 0.25 * trade_duration_candles))
-        )
-
-    def get_take_profit_distance(self, df: DataFrame, trade: Trade) -> Optional[float]:
-        trade_duration_candles = QuickAdapterV3.get_trade_duration_candles(df, trade)
+    @staticmethod
+    def get_trade_natr(df: DataFrame, trade_duration_candles: int) -> Optional[float]:
         if not QuickAdapterV3.is_trade_duration_valid(trade_duration_candles):
             return None
-        trade_zl_natr = zero_lag_series(
+        trade_zl_natr = zero_lag(
             df["natr_label_period_candles"], period=trade_duration_candles
         )
         if trade_zl_natr.empty:
             return None
-        take_profit_natr = np.nan
+        trade_natr = np.nan
         if trade_duration_candles >= 2:
             kama = get_ma_fn("kama")
             try:
@@ -510,21 +494,39 @@ class QuickAdapterV3(IStrategy):
                     ~np.isnan(trade_kama_natr_values)
                 ]
                 if trade_kama_natr_values.size > 0:
-                    take_profit_natr = trade_kama_natr_values[-1]
+                    trade_natr = trade_kama_natr_values[-1]
             except Exception as e:
-                logger.error(
-                    f"Failed to calculate KAMA at take profit price computation: {str(e)}",
-                    exc_info=True,
-                )
-        if isna(take_profit_natr):
-            take_profit_natr = (
-                trade_zl_natr.ewm(span=trade_duration_candles).mean().iloc[-1]
-            )
-        if isna(take_profit_natr) or take_profit_natr < 0:
+                logger.error(f"Failed to calculate KAMA: {str(e)}", exc_info=True)
+        if isna(trade_natr):
+            trade_natr = trade_zl_natr.ewm(span=trade_duration_candles).mean().iloc[-1]
+        return trade_natr
+
+    def get_stoploss_distance(
+        self, df: DataFrame, trade: Trade, current_rate: float
+    ) -> Optional[float]:
+        trade_duration_candles = QuickAdapterV3.get_trade_duration_candles(df, trade)
+        if not QuickAdapterV3.is_trade_duration_valid(trade_duration_candles):
+            return None
+        trade_natr = QuickAdapterV3.get_trade_natr(df, trade_duration_candles)
+        if isna(trade_natr) or trade_natr < 0:
+            return None
+        return (
+            current_rate
+            * (trade_natr / 100.0)
+            * self.get_stoploss_natr_ratio(trade.pair)
+            * (1 / math.log10(3.75 + 0.25 * trade_duration_candles))
+        )
+
+    def get_take_profit_distance(self, df: DataFrame, trade: Trade) -> Optional[float]:
+        trade_duration_candles = QuickAdapterV3.get_trade_duration_candles(df, trade)
+        if not QuickAdapterV3.is_trade_duration_valid(trade_duration_candles):
+            return None
+        trade_natr = QuickAdapterV3.get_trade_natr(df, trade_duration_candles)
+        if isna(trade_natr) or trade_natr < 0:
             return None
         return (
             trade.open_rate
-            * (take_profit_natr / 100.0)
+            * (trade_natr / 100.0)
             * self.get_take_profit_natr_ratio(trade.pair)
             * math.log10(9.75 + 0.25 * trade_duration_candles)
         )
index 685339a8ef0a41c80cf5a3fb124bb3265dcd3c06..481ed9a03d3abbe62e89d68f95b93d372dc4eec3 100644 (file)
@@ -127,7 +127,7 @@ def vwapb(dataframe: pd.DataFrame, window=20, num_of_std=1) -> tuple:
     return vwap_low, vwap, vwap_high
 
 
-def zero_lag_series(series: pd.Series, period: int) -> pd.Series:
+def zero_lag(series: pd.Series, period: int) -> pd.Series:
     """Applies a zero lag filter to reduce MA lag."""
     lag = max(int(0.5 * (period - 1)), 0)
     if lag == 0:
@@ -135,7 +135,7 @@ def zero_lag_series(series: pd.Series, period: int) -> pd.Series:
     return 2 * series - series.shift(lag)
 
 
-def get_ma_fn(mamode: str) -> Callable[[pd.Series, int], pd.Series]:
+def get_ma_fn(mamode: str) -> Callable[[pd.Series, int], np.ndarray]:
     mamodes: dict = {
         "sma": ta.SMA,
         "ema": ta.EMA,
@@ -190,9 +190,9 @@ def frama(df: pd.DataFrame, period: int = 16, zero_lag=False) -> pd.Series:
     closes = df["close"]
 
     if zero_lag:
-        highs = zero_lag_series(highs, period=period)
-        lows = zero_lag_series(lows, period=period)
-        closes = zero_lag_series(closes, period=period)
+        highs = zero_lag(highs, period=period)
+        lows = zero_lag(lows, period=period)
+        closes = zero_lag(closes, period=period)
 
     fd = pd.Series(np.nan, index=closes.index)
     for i in range(period, n):
@@ -227,7 +227,7 @@ def smma(series: pd.Series, period: int, zero_lag=False, offset=0) -> pd.Series:
         return pd.Series(index=series.index, dtype=float)
 
     if zero_lag:
-        series = zero_lag_series(series, period=period)
+        series = zero_lag(series, period=period)
     smma = pd.Series(np.nan, index=series.index)
     smma.iloc[period - 1] = series.iloc[:period].mean()
 
@@ -263,21 +263,21 @@ def ewo(
     """
     Calculate the Elliott Wave Oscillator (EWO) using two moving averages.
     """
-    price_series = get_price_fn(pricemode)(dataframe)
+    prices = get_price_fn(pricemode)(dataframe)
 
     if zero_lag:
-        price_series_ma1 = zero_lag_series(price_series, period=ma1_length)
-        price_series_ma2 = zero_lag_series(price_series, period=ma2_length)
+        prices_ma1 = zero_lag(prices, period=ma1_length)
+        prices_ma2 = zero_lag(prices, period=ma2_length)
     else:
-        price_series_ma1 = price_series
-        price_series_ma2 = price_series
+        prices_ma1 = prices
+        prices_ma2 = prices
 
     ma_fn = get_ma_fn(mamode)
-    ma1 = ma_fn(price_series_ma1, timeperiod=ma1_length)
-    ma2 = ma_fn(price_series_ma2, timeperiod=ma2_length)
+    ma1 = ma_fn(prices_ma1, timeperiod=ma1_length)
+    ma2 = ma_fn(prices_ma2, timeperiod=ma2_length)
     madiff = ma1 - ma2
     if normalize:
-        madiff = (madiff / price_series) * 100.0
+        madiff = (madiff / prices) * 100.0
     return madiff
 
 
@@ -295,34 +295,30 @@ def alligator(
     """
     Calculate Bill Williams' Alligator indicator lines.
     """
-    price_series = get_price_fn(pricemode)(df)
+    prices = get_price_fn(pricemode)(df)
 
-    jaw = smma(price_series, period=jaw_period, zero_lag=zero_lag, offset=jaw_shift)
-    teeth = smma(
-        price_series, period=teeth_period, zero_lag=zero_lag, offset=teeth_shift
-    )
-    lips = smma(price_series, period=lips_period, zero_lag=zero_lag, offset=lips_shift)
+    jaw = smma(prices, period=jaw_period, zero_lag=zero_lag, offset=jaw_shift)
+    teeth = smma(prices, period=teeth_period, zero_lag=zero_lag, offset=teeth_shift)
+    lips = smma(prices, period=lips_period, zero_lag=zero_lag, offset=lips_shift)
 
     return jaw, teeth, lips
 
 
-def find_fractals(
-    df: pd.DataFrame, fractal_period: int = 2
-) -> tuple[list[int], list[int]]:
+def find_fractals(df: pd.DataFrame, period: int = 2) -> tuple[list[int], list[int]]:
     n = len(df)
-    if n < 2 * fractal_period + 1:
+    if n < 2 * period + 1:
         return [], []
 
     highs = df["high"].values
     lows = df["low"].values
 
-    fractal_candidate_indices = np.arange(fractal_period, n - fractal_period)
+    fractal_candidate_indices = np.arange(period, n - period)
 
     fractal_candidate_indices_length = len(fractal_candidate_indices)
     is_fractal_high = np.ones(fractal_candidate_indices_length, dtype=bool)
     is_fractal_low = np.ones(fractal_candidate_indices_length, dtype=bool)
 
-    for i in range(1, fractal_period + 1):
+    for i in range(1, period + 1):
         is_fractal_high &= (
             highs[fractal_candidate_indices] > highs[fractal_candidate_indices - i]
         ) & (highs[fractal_candidate_indices] > highs[fractal_candidate_indices + i])
@@ -397,29 +393,33 @@ def zigzag(
     candidate_pivot_value = np.nan
     candidate_pivot_direction: TrendDirection = TrendDirection.NEUTRAL
 
-    def volatility_quantile(pos: int) -> float:
-        start = max(0, pos + 1 - natr_period)
-        end = min(pos + 1, n)
-        if start >= end:
-            return np.nan
+    volatility_quantile_cache: dict[int, float] = {}
 
-        natr_values = get_natr_values(natr_period)
-        lookback_natr_values = natr_values[start:end]
-        quantile = calculate_quantile(lookback_natr_values, natr_values[pos])
+    def calculate_volatility_quantile(pos: int) -> float:
+        if pos not in volatility_quantile_cache:
+            start = max(0, pos + 1 - natr_period)
+            end = min(pos + 1, n)
+            if start >= end:
+                volatility_quantile_cache[pos] = np.nan
+            else:
+                natr_values = get_natr_values(natr_period)
+                volatility_quantile_cache[pos] = calculate_quantile(
+                    natr_values[start:end], natr_values[pos]
+                )
 
-        return quantile
+        return volatility_quantile_cache[pos]
 
     def calculate_confirmation_window(
         pos: int,
         min_window: int = min_confirmation_window,
         max_window: int = max_confirmation_window,
     ) -> int:
-        quantile = volatility_quantile(pos)
-        if np.isnan(quantile):
+        volatility_quantile = calculate_volatility_quantile(pos)
+        if np.isnan(volatility_quantile):
             return int(round(np.median([min_window, max_window])))
 
         return np.clip(
-            round(max_window - (max_window - min_window) * quantile),
+            round(max_window - (max_window - min_window) * volatility_quantile),
             min_window,
             max_window,
         ).astype(int)
@@ -429,12 +429,12 @@ def zigzag(
         min_depth: int = 6,
         max_depth: int = 24,
     ) -> int:
-        quantile = volatility_quantile(pos)
-        if np.isnan(quantile):
+        volatility_quantile = calculate_volatility_quantile(pos)
+        if np.isnan(volatility_quantile):
             return int(round(np.median([min_depth, max_depth])))
 
         return np.clip(
-            round(max_depth - (max_depth - min_depth) * quantile),
+            round(max_depth - (max_depth - min_depth) * volatility_quantile),
             min_depth,
             max_depth,
         ).astype(int)
@@ -444,11 +444,11 @@ def zigzag(
         min_strength: float = 1.0,
         max_strength: float = 1.5,
     ) -> float:
-        quantile = volatility_quantile(pos)
-        if np.isnan(quantile):
+        volatility_quantile = calculate_volatility_quantile(pos)
+        if np.isnan(volatility_quantile):
             return np.median([min_strength, max_strength])
 
-        return min_strength + (max_strength - min_strength) * quantile
+        return min_strength + (max_strength - min_strength) * volatility_quantile
 
     def update_candidate_pivot(pos: int, value: float, direction: TrendDirection):
         nonlocal candidate_pivot_pos, candidate_pivot_value, candidate_pivot_direction