# Pivots labeling cache (freqaimodels/QuickAdapterRegressorV3.py).
# Keys have the form "<dataframe digest>-<natr_period>-<natr_ratio>"; values are
# the (indices, values, directions) triple produced by zigzag().
zigzag_cache: dict[str, tuple[list[int], list[float], list[int]]] = {}


def zigzag_cached(
    df: pd.DataFrame,
    natr_period: int = 14,
    natr_ratio: float = 6.0,
    cache_size: int = 2048,
) -> tuple[list[int], list[float], list[int]]:
    """Return the zigzag() pivots for *df*, memoized on content and parameters.

    Args:
        df: Dataframe forwarded unchanged to zigzag().
        natr_period: Forwarded to zigzag() and part of the cache key.
        natr_ratio: Forwarded to zigzag() and part of the cache key.
        cache_size: Maximum number of entries kept in ``zigzag_cache``. The
            oldest inserted entries are evicted first. A value <= 0 disables
            storing results (the computation still runs).

    Returns:
        ``(pivots_indices, pivots_values, pivots_directions)`` as returned by
        zigzag(). On a cache hit, fresh list copies are returned so callers
        cannot mutate the cached entry.
    """

    def hash_df(frame: pd.DataFrame) -> str:
        hasher = hashlib.sha256()
        # zigzag() reads df.index and looks columns up by name ("close",
        # "high", "low"), so labels, dtypes and the index must be part of the
        # digest, not only the raw values.
        hasher.update(str(frame.shape).encode())
        hasher.update(repr([str(label) for label in frame.columns]).encode())
        hasher.update(repr([str(dtype) for dtype in frame.dtypes]).encode())
        # Value-based per-row hashes. Unlike frame.to_numpy().tobytes(), this
        # stays content-based for mixed-dtype frames, where to_numpy() yields
        # an object array whose raw bytes are object pointers, not values.
        row_hashes = pd.util.hash_pandas_object(frame, index=True)
        hasher.update(row_hashes.to_numpy().tobytes())
        return hasher.hexdigest()

    cache_key = f"{hash_df(df)}-{natr_period}-{natr_ratio}"
    cached = zigzag_cache.get(cache_key)
    if cached is not None:
        cached_indices, cached_values, cached_directions = cached
        return list(cached_indices), list(cached_values), list(cached_directions)

    pivots_indices, pivots_values, pivots_directions = zigzag(
        df, natr_period=natr_period, natr_ratio=natr_ratio
    )
    if cache_size > 0:
        # dicts preserve insertion order, so the first key is the oldest entry.
        # Looping (rather than a single delete) also honors a cache_size that
        # is smaller than on previous calls.
        while len(zigzag_cache) >= cache_size:
            del zigzag_cache[next(iter(zigzag_cache))]
        # Store copies so later mutation of the returned lists by the caller
        # does not alter the cached entry.
        zigzag_cache[cache_key] = (
            list(pivots_indices),
            list(pivots_values),
            list(pivots_directions),
        )
    return pivots_indices, pivots_values, pivots_directions
# Pivots labeling cache (strategies/Utils.py).
# Keys have the form "<dataframe digest>-<natr_period>-<natr_ratio>"; values are
# the (indices, values, directions) triple produced by zigzag().
zigzag_cache: dict[str, tuple[list[int], list[float], list[int]]] = {}


def zigzag_cached(
    df: pd.DataFrame,
    natr_period: int = 14,
    natr_ratio: float = 6.0,
    cache_size: int = 2048,
) -> tuple[list[int], list[float], list[int]]:
    """Return the zigzag() pivots for *df*, memoized on content and parameters.

    Args:
        df: Dataframe forwarded unchanged to zigzag().
        natr_period: Forwarded to zigzag() and part of the cache key.
        natr_ratio: Forwarded to zigzag() and part of the cache key.
        cache_size: Maximum number of entries kept in ``zigzag_cache``. The
            oldest inserted entries are evicted first. A value <= 0 disables
            storing results (the computation still runs).

    Returns:
        ``(pivots_indices, pivots_values, pivots_directions)`` as returned by
        zigzag(). On a cache hit, fresh list copies are returned so callers
        cannot mutate the cached entry.
    """

    def hash_df(frame: pd.DataFrame) -> str:
        hasher = hashlib.sha256()
        # zigzag() reads df.index and looks columns up by name ("close",
        # "high", "low"), so labels, dtypes and the index must be part of the
        # digest, not only the raw values.
        hasher.update(str(frame.shape).encode())
        hasher.update(repr([str(label) for label in frame.columns]).encode())
        hasher.update(repr([str(dtype) for dtype in frame.dtypes]).encode())
        # Value-based per-row hashes. Unlike frame.to_numpy().tobytes(), this
        # stays content-based for mixed-dtype frames, where to_numpy() yields
        # an object array whose raw bytes are object pointers, not values.
        row_hashes = pd.util.hash_pandas_object(frame, index=True)
        hasher.update(row_hashes.to_numpy().tobytes())
        return hasher.hexdigest()

    cache_key = f"{hash_df(df)}-{natr_period}-{natr_ratio}"
    cached = zigzag_cache.get(cache_key)
    if cached is not None:
        cached_indices, cached_values, cached_directions = cached
        return list(cached_indices), list(cached_values), list(cached_directions)

    pivots_indices, pivots_values, pivots_directions = zigzag(
        df, natr_period=natr_period, natr_ratio=natr_ratio
    )
    if cache_size > 0:
        # dicts preserve insertion order, so the first key is the oldest entry.
        # Looping (rather than a single delete) also honors a cache_size that
        # is smaller than on previous calls.
        while len(zigzag_cache) >= cache_size:
            del zigzag_cache[next(iter(zigzag_cache))]
        # Store copies so later mutation of the returned lists by the caller
        # does not alter the cached entry.
        zigzag_cache[cache_key] = (
            list(pivots_indices),
            list(pivots_values),
            list(pivots_directions),
        )
    return pivots_indices, pivots_values, pivots_directions