import json
import random
from statistics import median
+import threading
import time
import numpy as np
import pandas as pd
and self._optuna_config.get("enabled")
and self.data_split_parameters.get("test_size", TEST_SIZE) > 0
)
+ self._optuna_locks = {
+ "label": threading.RLock(),
+ "throttle": threading.RLock(),
+ }
self._optuna_hp_value: dict[str, float] = {}
self._optuna_train_value: dict[str, float] = {}
self._optuna_label_values: dict[str, list] = {}
raise RuntimeError("Failed to initialize optuna label candle pool")
def set_optuna_label_candle(self, pair: str) -> None:
- if len(self._optuna_label_candle_pool) == 0:
- self.init_optuna_label_candle_pool()
- self._optuna_label_candle[pair] = self._optuna_label_candle_pool.pop()
- optuna_label_available_candles = (
- set(self.get_optuna_label_all_candles())
- - set(self._optuna_label_candle_pool)
- - set(self._optuna_label_candle.values())
- )
- if len(optuna_label_available_candles) > 0:
- self._optuna_label_candle_pool.extend(optuna_label_available_candles)
- random.shuffle(self._optuna_label_candle_pool)
+ with self._optuna_locks.get("label"):
+ if len(self._optuna_label_candle_pool) == 0:
+ self.init_optuna_label_candle_pool()
+ self._optuna_label_candle[pair] = self._optuna_label_candle_pool.pop()
+ optuna_label_available_candles = (
+ set(self.get_optuna_label_all_candles())
+ - set(self._optuna_label_candle_pool)
+ - set(self._optuna_label_candle.values())
+ )
+ if len(optuna_label_available_candles) > 0:
+ self._optuna_label_candle_pool.extend(optuna_label_available_candles)
+ random.shuffle(self._optuna_label_candle_pool)
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
"""
) -> None:
if namespace != "label":
raise ValueError(f"Invalid namespace: {namespace}")
- self._optuna_label_candles[pair] += 1
- if self._optuna_label_candles[pair] >= self._optuna_label_candle[pair]:
- try:
- callback()
- except Exception as e:
- logger.error(
- f"Error executing optuna {pair} {namespace} callback: {str(e)}",
- exc_info=True,
+ with self._optuna_locks.get("throttle"):
+ self._optuna_label_candles[pair] += 1
+ if self._optuna_label_candles[pair] >= self._optuna_label_candle[pair]:
+ try:
+ callback()
+ except Exception as e:
+ logger.error(
+ f"Error executing optuna {pair} {namespace} callback: {str(e)}",
+ exc_info=True,
+ )
+ finally:
+ self._optuna_label_candles[pair] = 0
+ self.set_optuna_label_candle(pair)
+ else:
+ logger.info(
+ f"Optuna {pair} {namespace} callback throttled, still {self._optuna_label_candle[pair] - self._optuna_label_candles[pair]} candles to go"
)
- finally:
- self._optuna_label_candles[pair] = 0
- self.set_optuna_label_candle(pair)
- else:
- logger.info(
- f"Optuna {pair} {namespace} callback throttled, still {self._optuna_label_candle[pair] - self._optuna_label_candles[pair]} candles to go"
- )
def fit_live_predictions(self, dk: FreqaiDataKitchen, pair: str) -> None:
warmed_up = True