https://github.com/sponsors/robcaulk
"""
- version = "3.7.82"
+ version = "3.7.83"
@cached_property
def _optuna_config(self) -> dict:
self._optuna_hp_params: dict[str, dict] = {}
self._optuna_train_params: dict[str, dict] = {}
self._optuna_label_params: dict[str, dict] = {}
+ self.init_optuna_label_candle_pool()
self._optuna_label_candles: dict[str, int] = {}
self._optuna_label_candle: dict[str, int] = {}
for pair in self.pairs:
}
)
self._optuna_label_candles[pair] = 0
- self._optuna_label_candle[pair] = self.get_optuna_label_candle()
+ self.set_optuna_label_candle(pair)
logger.info(
f"Initialized {self.__class__.__name__} {self.freqai_info.get('regressor', 'xgboost')} regressor model version {self.version}"
else:
raise ValueError(f"Invalid namespace: {namespace}")
- def get_optuna_label_candle(self) -> int:
+ def get_optuna_label_all_candles(self) -> list[int]:
+ n_pairs = len(self.pairs)
label_frequency_candles = max(
- 2, int(self.ft_params.get("label_frequency_candles", 12))
+ 2, n_pairs, int(self.ft_params.get("label_frequency_candles", 12))
)
- random_offset = random.randint(
- -label_frequency_candles // 2, label_frequency_candles // 2
+ min_offset = -int(label_frequency_candles / 2)
+ max_offset = int(label_frequency_candles / 2)
+ return [
+ max(1, label_frequency_candles + offset)
+ for offset in range(min_offset, max_offset + 1)
+ ]
+
+ def init_optuna_label_candle_pool(self) -> None:
+ self._optuna_label_candle_pool = self.get_optuna_label_all_candles()
+ random.shuffle(self._optuna_label_candle_pool)
+ if len(self._optuna_label_candle_pool) == 0:
+ raise RuntimeError("Failed to initialize optuna label candle pool")
+
+ def set_optuna_label_candle(self, pair: str) -> None:
+ if len(self._optuna_label_candle_pool) == 0:
+ self.init_optuna_label_candle_pool()
+ self._optuna_label_candle[pair] = self._optuna_label_candle_pool.pop()
+ optuna_label_available_candles = (
+ set(self.get_optuna_label_all_candles())
+ - set(self._optuna_label_candle_pool)
+ - set(self._optuna_label_candle.values())
)
- return max(1, label_frequency_candles + random_offset)
+ if len(optuna_label_available_candles) > 0:
+ self._optuna_label_candle_pool.extend(optuna_label_available_candles)
+ random.shuffle(self._optuna_label_candle_pool)
def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
"""
)
finally:
self._optuna_label_candles[pair] = 0
- self._optuna_label_candle[pair] = self.get_optuna_label_candle()
+ self.set_optuna_label_candle(pair)
else:
logger.info(
f"Optuna {pair} {namespace} callback throttled, still {self._optuna_label_candle[pair] - self._optuna_label_candles[pair]} candles to go"
candles_step: int,
model_training_parameters: dict,
) -> float:
- min_test_window: int = fit_live_predictions_candles * 2
+ min_test_window: int = fit_live_predictions_candles * 4
if len(X_test) < min_test_window:
logger.warning(f"Insufficient test data: {len(X_test)} < {min_test_window}")
return np.inf