]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(qav3): add pivots labeling cache
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 8 Jun 2025 20:56:41 +0000 (22:56 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 8 Jun 2025 20:56:41 +0000 (22:56 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/Utils.py

index f6f7a64fd40d4ad9c9f38cda09685e59dfa857b5..57169908d7d46dd6ac9a5e30f691b245bb439a6e 100644 (file)
@@ -1,4 +1,5 @@
 from enum import IntEnum
+import hashlib
 import logging
 import json
 from statistics import median
@@ -1111,6 +1112,44 @@ class TrendDirection(IntEnum):
     DOWN = -1
 
 
+zigzag_cache: dict[str, tuple[list[int], list[float], list[int]]] = {}
+
+
+def zigzag_cached(
+    df: pd.DataFrame,
+    natr_period: int = 14,
+    natr_ratio: float = 6.0,
+    cache_size: int = 2048,
+) -> tuple[list[int], list[float], list[int]]:
+    def hash_df(df: pd.DataFrame) -> str:
+        hasher = hashlib.sha256()
+
+        arr = df.to_numpy()
+
+        hasher.update(str(arr.shape).encode())
+        hasher.update(str(arr.dtype).encode())
+
+        hasher.update(arr.tobytes())
+
+        return hasher.hexdigest()
+
+    cache_key = f"{hash_df(df)}-{natr_period}-{natr_ratio}"
+    if cache_key in zigzag_cache:
+        return zigzag_cache[cache_key]
+
+    pivots_indices, pivots_values, pivots_directions = zigzag(
+        df, natr_period=natr_period, natr_ratio=natr_ratio
+    )
+    if len(zigzag_cache) >= cache_size:
+        del zigzag_cache[next(iter(zigzag_cache))]
+    zigzag_cache[cache_key] = (
+        pivots_indices,
+        pivots_values,
+        pivots_directions,
+    )
+    return pivots_indices, pivots_values, pivots_directions
+
+
 def zigzag(
     df: pd.DataFrame,
     natr_period: int = 14,
@@ -1122,17 +1161,10 @@ def zigzag(
     if df.empty or n < max(natr_period, 2 * max_confirmation_window + 1):
         return [], [], []
 
-    natr_values_cache: dict[int, np.ndarray] = {}
-
-    def get_natr_values(period: int) -> np.ndarray:
-        if period not in natr_values_cache:
-            natr_values_cache[period] = (
-                ta.NATR(df, timeperiod=period).bfill() / 100.0
-            ).to_numpy()
-        return natr_values_cache[period]
+    natr_values = (ta.NATR(df, timeperiod=natr_period).bfill() / 100.0).to_numpy()
 
     indices = df.index.tolist()
-    thresholds = get_natr_values(natr_period) * natr_ratio
+    thresholds = natr_values * natr_ratio
     closes = df.get("close").to_numpy()
     highs = df.get("high").to_numpy()
     lows = df.get("low").to_numpy()
@@ -1156,7 +1188,6 @@ def zigzag(
             if start >= end:
                 volatility_quantile_cache[pos] = np.nan
             else:
-                natr_values = get_natr_values(natr_period)
                 volatility_quantile_cache[pos] = calculate_quantile(
                     natr_values[start:end], natr_values[pos]
                 )
@@ -1422,7 +1453,7 @@ def label_objective(
     if df.empty:
         return -np.inf, -np.inf
 
-    _, pivots_values, _ = zigzag(
+    _, pivots_values, _ = zigzag_cached(
         df,
         natr_period=label_period_candles,
         natr_ratio=label_natr_ratio,
index 75043b60e7ea92c6db8e01321ffed000480d8945..f14a01f5b856967ac376b9fd4366214575e966ea 100644 (file)
@@ -1,5 +1,6 @@
 from enum import IntEnum
 from functools import lru_cache
+import hashlib
 from statistics import median
 import numpy as np
 import pandas as pd
@@ -377,6 +378,44 @@ class TrendDirection(IntEnum):
     DOWN = -1
 
 
+zigzag_cache: dict[str, tuple[list[int], list[float], list[int]]] = {}
+
+
+def zigzag_cached(
+    df: pd.DataFrame,
+    natr_period: int = 14,
+    natr_ratio: float = 6.0,
+    cache_size: int = 2048,
+) -> tuple[list[int], list[float], list[int]]:
+    def hash_df(df: pd.DataFrame) -> str:
+        hasher = hashlib.sha256()
+
+        arr = df.to_numpy()
+
+        hasher.update(str(arr.shape).encode())
+        hasher.update(str(arr.dtype).encode())
+
+        hasher.update(arr.tobytes())
+
+        return hasher.hexdigest()
+
+    cache_key = f"{hash_df(df)}-{natr_period}-{natr_ratio}"
+    if cache_key in zigzag_cache:
+        return zigzag_cache[cache_key]
+
+    pivots_indices, pivots_values, pivots_directions = zigzag(
+        df, natr_period=natr_period, natr_ratio=natr_ratio
+    )
+    if len(zigzag_cache) >= cache_size:
+        del zigzag_cache[next(iter(zigzag_cache))]
+    zigzag_cache[cache_key] = (
+        pivots_indices,
+        pivots_values,
+        pivots_directions,
+    )
+    return pivots_indices, pivots_values, pivots_directions
+
+
 def zigzag(
     df: pd.DataFrame,
     natr_period: int = 14,
@@ -388,17 +427,10 @@ def zigzag(
     if df.empty or n < max(natr_period, 2 * max_confirmation_window + 1):
         return [], [], []
 
-    natr_values_cache: dict[int, np.ndarray] = {}
-
-    def get_natr_values(period: int) -> np.ndarray:
-        if period not in natr_values_cache:
-            natr_values_cache[period] = (
-                ta.NATR(df, timeperiod=period).bfill() / 100.0
-            ).to_numpy()
-        return natr_values_cache[period]
+    natr_values = (ta.NATR(df, timeperiod=natr_period).bfill() / 100.0).to_numpy()
 
     indices = df.index.tolist()
-    thresholds = get_natr_values(natr_period) * natr_ratio
+    thresholds = natr_values * natr_ratio
     closes = df.get("close").to_numpy()
     highs = df.get("high").to_numpy()
     lows = df.get("low").to_numpy()
@@ -422,7 +454,6 @@ def zigzag(
             if start >= end:
                 volatility_quantile_cache[pos] = np.nan
             else:
-                natr_values = get_natr_values(natr_period)
                 volatility_quantile_cache[pos] = calculate_quantile(
                     natr_values[start:end], natr_values[pos]
                 )