from enum import IntEnum
+import hashlib
import logging
import json
from statistics import median
DOWN = -1
+zigzag_cache: dict[str, tuple[list[int], list[float], list[int]]] = {}
+
+
+def zigzag_cached(
+ df: pd.DataFrame,
+ natr_period: int = 14,
+ natr_ratio: float = 6.0,
+ cache_size: int = 2048,
+) -> tuple[list[int], list[float], list[int]]:
+ def hash_df(df: pd.DataFrame) -> str:
+ hasher = hashlib.sha256()
+
+ arr = df.to_numpy()
+
+ hasher.update(str(arr.shape).encode())
+ hasher.update(str(arr.dtype).encode())
+
+ hasher.update(arr.tobytes())
+
+ return hasher.hexdigest()
+
+ cache_key = f"{hash_df(df)}-{natr_period}-{natr_ratio}"
+ if cache_key in zigzag_cache:
+ return zigzag_cache[cache_key]
+
+ pivots_indices, pivots_values, pivots_directions = zigzag(
+ df, natr_period=natr_period, natr_ratio=natr_ratio
+ )
+ if len(zigzag_cache) >= cache_size:
+ del zigzag_cache[next(iter(zigzag_cache))]
+ zigzag_cache[cache_key] = (
+ pivots_indices,
+ pivots_values,
+ pivots_directions,
+ )
+ return pivots_indices, pivots_values, pivots_directions
+
+
def zigzag(
df: pd.DataFrame,
natr_period: int = 14,
if df.empty or n < max(natr_period, 2 * max_confirmation_window + 1):
return [], [], []
- natr_values_cache: dict[int, np.ndarray] = {}
-
- def get_natr_values(period: int) -> np.ndarray:
- if period not in natr_values_cache:
- natr_values_cache[period] = (
- ta.NATR(df, timeperiod=period).bfill() / 100.0
- ).to_numpy()
- return natr_values_cache[period]
+ natr_values = (ta.NATR(df, timeperiod=natr_period).bfill() / 100.0).to_numpy()
indices = df.index.tolist()
- thresholds = get_natr_values(natr_period) * natr_ratio
+ thresholds = natr_values * natr_ratio
closes = df.get("close").to_numpy()
highs = df.get("high").to_numpy()
lows = df.get("low").to_numpy()
if start >= end:
volatility_quantile_cache[pos] = np.nan
else:
- natr_values = get_natr_values(natr_period)
volatility_quantile_cache[pos] = calculate_quantile(
natr_values[start:end], natr_values[pos]
)
if df.empty:
return -np.inf, -np.inf
- _, pivots_values, _ = zigzag(
+ _, pivots_values, _ = zigzag_cached(
df,
natr_period=label_period_candles,
natr_ratio=label_natr_ratio,
from enum import IntEnum
from functools import lru_cache
+import hashlib
from statistics import median
import numpy as np
import pandas as pd
DOWN = -1
+zigzag_cache: dict[str, tuple[list[int], list[float], list[int]]] = {}
+
+
+def zigzag_cached(
+ df: pd.DataFrame,
+ natr_period: int = 14,
+ natr_ratio: float = 6.0,
+ cache_size: int = 2048,
+) -> tuple[list[int], list[float], list[int]]:
+ def hash_df(df: pd.DataFrame) -> str:
+ hasher = hashlib.sha256()
+
+ arr = df.to_numpy()
+
+ hasher.update(str(arr.shape).encode())
+ hasher.update(str(arr.dtype).encode())
+
+ hasher.update(arr.tobytes())
+
+ return hasher.hexdigest()
+
+ cache_key = f"{hash_df(df)}-{natr_period}-{natr_ratio}"
+ if cache_key in zigzag_cache:
+ return zigzag_cache[cache_key]
+
+ pivots_indices, pivots_values, pivots_directions = zigzag(
+ df, natr_period=natr_period, natr_ratio=natr_ratio
+ )
+ if len(zigzag_cache) >= cache_size:
+ del zigzag_cache[next(iter(zigzag_cache))]
+ zigzag_cache[cache_key] = (
+ pivots_indices,
+ pivots_values,
+ pivots_directions,
+ )
+ return pivots_indices, pivots_values, pivots_directions
+
+
def zigzag(
df: pd.DataFrame,
natr_period: int = 14,
if df.empty or n < max(natr_period, 2 * max_confirmation_window + 1):
return [], [], []
- natr_values_cache: dict[int, np.ndarray] = {}
-
- def get_natr_values(period: int) -> np.ndarray:
- if period not in natr_values_cache:
- natr_values_cache[period] = (
- ta.NATR(df, timeperiod=period).bfill() / 100.0
- ).to_numpy()
- return natr_values_cache[period]
+ natr_values = (ta.NATR(df, timeperiod=natr_period).bfill() / 100.0).to_numpy()
indices = df.index.tolist()
- thresholds = get_natr_values(natr_period) * natr_ratio
+ thresholds = natr_values * natr_ratio
closes = df.get("close").to_numpy()
highs = df.get("high").to_numpy()
lows = df.get("low").to_numpy()
if start >= end:
volatility_quantile_cache[pos] = np.nan
else:
- natr_values = get_natr_values(natr_period)
volatility_quantile_cache[pos] = calculate_quantile(
natr_values[start:end], natr_values[pos]
)