]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(qav3): fix incorrect lru caching usage
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 20 Jun 2025 17:44:49 +0000 (19:44 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 20 Jun 2025 17:44:49 +0000 (19:44 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 3cee27fd54a4ac33c6eeae1c3298ce591f892c7d..e2994e0a85a71755456a75d96a8aa1ba01f91b10 100644 (file)
@@ -88,7 +88,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             raise ValueError(
                 "FreqAI model requires 'identifier' defined in the freqai section configuration"
             )
-        self._optuna_hyperopt: bool = (
+        self._optuna_hyperopt: bool | None = (
             self.freqai_info.get("enabled", False)
             and self._optuna_config.get("enabled")
             and self.data_split_parameters.get("test_size", TEST_SIZE) > 0
@@ -99,6 +99,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         self._optuna_hp_params: dict[str, dict[str, Any]] = {}
         self._optuna_train_params: dict[str, dict[str, Any]] = {}
         self._optuna_label_params: dict[str, dict[str, Any]] = {}
+        self._optuna_label_candle_pool_cache: dict[tuple[int, int], list[int]] = {}
         self.init_optuna_label_candle_pool()
         self._optuna_label_candle: dict[str, int] = {}
         self._optuna_label_candles: dict[str, int] = {}
@@ -189,18 +190,21 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         else:
             raise ValueError(f"Invalid namespace: {namespace}")
 
-    @lru_cache(maxsize=8)
     def build_optuna_label_candle_pool(self) -> list[int]:
         n_pairs = len(self.pairs)
         label_frequency_candles = max(
             2, 2 * n_pairs, int(self.ft_params.get("label_frequency_candles", 12))
         )
+        cache_key = (n_pairs, label_frequency_candles)
+        if cache_key in self._optuna_label_candle_pool_cache:
+            return self._optuna_label_candle_pool_cache[cache_key]
         min_offset = -int(label_frequency_candles / 2)
         max_offset = int(label_frequency_candles / 2)
-        return [
+        self._optuna_label_candle_pool_cache[cache_key] = [
             max(1, label_frequency_candles + offset)
             for offset in range(min_offset, max_offset + 1)
         ]
+        return self._optuna_label_candle_pool_cache[cache_key]
 
     def init_optuna_label_candle_pool(self) -> None:
         self._optuna_label_candle_pool = self.build_optuna_label_candle_pool()
@@ -1087,6 +1091,13 @@ def fit_regressor(
     return model
 
 
+@lru_cache(maxsize=128)
+def calculate_min_extrema(
+    length: int, fit_live_predictions_candles: int, min_extrema: int = 2
+) -> int:
+    return int(round((length / fit_live_predictions_candles) * min_extrema))
+
+
 def train_objective(
     trial: optuna.trial.Trial,
     regressor: str,
@@ -1101,12 +1112,6 @@ def train_objective(
     candles_step: int,
     model_training_parameters: dict[str, Any],
 ) -> float:
-    @lru_cache(maxsize=128)
-    def calculate_min_extrema(
-        length: int, fit_live_predictions_candles: int, min_extrema: int = 2
-    ) -> int:
-        return int(round((length / fit_live_predictions_candles) * min_extrema))
-
     test_ok = True
     test_length = len(X_test)
     if debug:
index 79364dbbd33016bfce28d02c9aac0e1d76713454..5a9ff18fe7efeaae3bdf3dffef7d4a2ed1986a51 100644 (file)
@@ -666,8 +666,9 @@ class QuickAdapterV3(IStrategy):
                 f"Invalid trade_price_target: {trade_price_target}. Expected 'interpolation', 'weighted_interpolation' or 'moving_average'."
             )
 
+    @staticmethod
     @lru_cache(maxsize=128)
-    def get_stoploss_log_factor(self, trade_duration_candles: int) -> float:
+    def get_stoploss_log_factor(trade_duration_candles: int) -> float:
         return 1 / math.log10(3.75 + 0.25 * trade_duration_candles)
 
     def get_stoploss_distance(
@@ -683,11 +684,12 @@ class QuickAdapterV3(IStrategy):
             current_rate
             * (trade_natr / 100.0)
             * self.get_stoploss_natr_ratio(trade.pair)
-            * self.get_stoploss_log_factor(trade_duration_candles)
+            * QuickAdapterV3.get_stoploss_log_factor(trade_duration_candles)
         )
 
+    @staticmethod
     @lru_cache(maxsize=128)
-    def get_take_profit_log_factor(self, trade_duration_candles: int) -> float:
+    def get_take_profit_log_factor(trade_duration_candles: int) -> float:
         return math.log10(9.75 + 0.25 * trade_duration_candles)
 
     def get_take_profit_distance(self, df: DataFrame, trade: Trade) -> Optional[float]:
@@ -701,7 +703,7 @@ class QuickAdapterV3(IStrategy):
             trade.open_rate
             * (trade_natr / 100.0)
             * self.get_take_profit_natr_ratio(trade.pair)
-            * self.get_take_profit_log_factor(trade_duration_candles)
+            * QuickAdapterV3.get_take_profit_log_factor(trade_duration_candles)
         )
 
     def throttle_callback(
@@ -858,7 +860,6 @@ class QuickAdapterV3(IStrategy):
             )
         return False
 
-    @lru_cache(maxsize=8)
     def max_open_trades_per_side(self) -> int:
         max_open_trades = self.config.get("max_open_trades")
         if max_open_trades < 0: