From: Jérôme Benoit Date: Sat, 15 Nov 2025 11:39:41 +0000 (+0100) Subject: feat: add optuna auto sampler support X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=5e8e59bc234f10381b8d57b4e4ed465ff48c8f4e;p=freqai-strategies.git feat: add optuna auto sampler support Signed-off-by: Jérôme Benoit --- diff --git a/README.md b/README.md index 5e9f313..300c3ee 100644 --- a/README.md +++ b/README.md @@ -80,12 +80,13 @@ docker compose up -d --build | freqai.outlier_threshold | 0.999 | float (0,1) | Quantile threshold for predictions outlier filtering. | | _Optuna / HPO_ | | | | | freqai.optuna_hyperopt.enabled | true | bool | Enables HPO. | -| freqai.optuna_hyperopt.n_jobs | CPU threads / 4 | int >= 1 | Parallel HPO workers. | +| freqai.optuna_hyperopt.sampler | `tpe` | enum {`tpe`,`auto`} | HPO sampler algorithm. `tpe` uses TPESampler with multivariate and group, `auto` uses AutoSampler. | | freqai.optuna_hyperopt.storage | `file` | enum {`file`,`sqlite`} | HPO storage backend. | | freqai.optuna_hyperopt.continuous | true | bool | Continuous HPO. | | freqai.optuna_hyperopt.warm_start | true | bool | Warm start HPO with previous best value(s). | | freqai.optuna_hyperopt.n_startup_trials | 15 | int >= 0 | HPO startup trials. | | freqai.optuna_hyperopt.n_trials | 50 | int >= 1 | Maximum HPO trials. | +| freqai.optuna_hyperopt.n_jobs | CPU threads / 4 | int >= 1 | Parallel HPO workers. | | freqai.optuna_hyperopt.timeout | 7200 | int >= 0 | HPO wall-clock timeout in seconds. | | freqai.optuna_hyperopt.label_candles_step | 1 | int >= 1 | Step for Zigzag NATR horizon search space. | | freqai.optuna_hyperopt.train_candles_step | 10 | int >= 1 | Step for training sets size search space. | diff --git a/ReforceXY/Dockerfile.reforcexy b/ReforceXY/Dockerfile.reforcexy index fa44fd1..d1ddb55 100644 --- a/ReforceXY/Dockerfile.reforcexy +++ b/ReforceXY/Dockerfile.reforcexy @@ -2,9 +2,16 @@ FROM freqtradeorg/freqtrade:stable_freqairl ARG optuna_version=4.6.0 +USER root +RUN apt-get update \ + && apt-get install -y --no-install-recommends git \ + && rm -rf /var/lib/apt/lists/* +USER ftuser RUN pip install --user --no-cache-dir \ optuna==${optuna_version} \ - optuna-dashboard + optuna-dashboard \ + optunahub \ + -r https://hub.optuna.org/samplers/auto_sampler/requirements.txt LABEL org.opencontainers.image.source="freqai-strategies" \ org.opencontainers.image.title="freqtrade-reforcexy" \ diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 3f85b50..d808431 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -14,6 +14,7 @@ import matplotlib import matplotlib.pyplot as plt import matplotlib.transforms as mtransforms import numpy as np +import optunahub import torch as th from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions @@ -28,8 +29,8 @@ from matplotlib.lines import Line2D from numpy.typing import NDArray from optuna import Trial, TrialPruned, create_study, delete_study from optuna.exceptions import ExperimentalWarning -from optuna.pruners import HyperbandPruner -from optuna.samplers import TPESampler +from optuna.pruners import BasePruner, HyperbandPruner +from optuna.samplers import BaseSampler, TPESampler from optuna.storages import ( BaseStorage, JournalStorage, @@ -101,9 +102,9 @@ class ReforceXY(BaseReinforcementLearningModel): "timeout_hours": 0, // Maximum time in hours for hyperopt (0 = no timeout) "continuous": false, // If true, perform continuous optimization "warm_start": false, // If true, enqueue previous best params if exists - "seed": 42, // RNG seed + "sampler": "tpe", // Optuna sampler (tpe|auto) "storage": "sqlite", // Optuna storage backend (sqlite|file) - } + "seed": 42, // RNG seed } } Requirements: @@ -942,6 +943,34 @@ class ReforceXY(BaseReinforcementLearningModel): except (ValueError, KeyError): return False + def create_sampler(self) -> BaseSampler: + sampler = self.rl_config_optuna.get("sampler", "tpe") + if sampler == "auto": + return optunahub.load_module("samplers/auto_sampler").AutoSampler( + seed=self.rl_config_optuna.get("seed", 42) + ) + elif sampler == "tpe": + return TPESampler( + n_startup_trials=self.optuna_n_startup_trials, + multivariate=True, + group=True, + seed=self.rl_config_optuna.get("seed", 42), + ) + else: + raise ValueError( + f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'" + ) + + @staticmethod + def create_pruner( + min_resource: int, max_resource: int, reduction_factor: int + ) -> BasePruner: + return HyperbandPruner( + min_resource=min_resource, + max_resource=max_resource, + reduction_factor=reduction_factor, + ) + def optimize( self, dk: FreqaiDataKitchen, total_timesteps: int ) -> Optional[Dict[str, Any]]: @@ -963,18 +992,12 @@ class ReforceXY(BaseReinforcementLearningModel): reduction_factor * 2, (total_timesteps // self.n_envs) // resource_eval_freq ) min_resource = max(1, max_resource // reduction_factor) + study: Study = create_study( study_name=study_name, - sampler=TPESampler( - n_startup_trials=self.optuna_n_startup_trials, - multivariate=True, - group=True, - seed=self.rl_config_optuna.get("seed", 42), - ), - pruner=HyperbandPruner( - min_resource=min_resource, - max_resource=max_resource, - reduction_factor=reduction_factor, + sampler=self.create_sampler(), + pruner=ReforceXY.create_pruner( + min_resource, max_resource, reduction_factor ), direction=StudyDirection.MAXIMIZE, storage=storage, diff --git a/quickadapter/Dockerfile.quickadapter b/quickadapter/Dockerfile.quickadapter index 1e85327..2aaef0b 100644 --- a/quickadapter/Dockerfile.quickadapter +++ b/quickadapter/Dockerfile.quickadapter @@ -6,13 +6,15 @@ ARG scikit_image_version=0.25.2 USER root RUN apt-get update \ - && apt-get install -y --no-install-recommends build-essential \ + && apt-get install -y --no-install-recommends build-essential git \ && rm -rf /var/lib/apt/lists/* USER ftuser RUN pip install --user --no-cache-dir \ optuna==${optuna_version} \ optuna-integration==${optuna_version} \ optuna-dashboard \ + optunahub \ + -r https://hub.optuna.org/samplers/auto_sampler/requirements.txt \ scikit-learn-extra==${scikit_learn_extra_version} \ scikit-image==${scikit_image_version} diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 3d74a78..d05c1ee 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Optional import numpy as np import optuna +import optunahub import pandas as pd import scipy as sp import skimage @@ -74,6 +75,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): .get("n_jobs", 1), max(int(self.max_system_threads / 4), 1), ), + "sampler": "tpe", "storage": "file", "continuous": True, "warm_start": True, @@ -1531,6 +1533,34 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) return storage + def optuna_create_pruner( + self, is_single_objective: bool + ) -> optuna.pruners.BasePruner: + if is_single_objective: + return optuna.pruners.HyperbandPruner( + min_resource=self._optuna_config.get("min_resource") + ) + else: + return optuna.pruners.NopPruner() + + def optuna_create_sampler(self) -> optuna.samplers.BaseSampler: + sampler = self._optuna_config.get("sampler", "tpe") + if sampler == "auto": + return optunahub.load_module("samplers/auto_sampler").AutoSampler( + seed=self._optuna_config.get("seed") + ) + elif sampler == "tpe": + return optuna.samplers.TPESampler( + n_startup_trials=self._optuna_config.get("n_startup_trials"), + multivariate=True, + group=True, + seed=self._optuna_config.get("seed"), + ) + else: + raise ValueError( + f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'" + ) + def optuna_create_study( self, pair: str, @@ -1566,23 +1596,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel): if continuous: QuickAdapterRegressorV3.optuna_study_delete(study_name, storage) - if is_study_single_objective: - pruner = optuna.pruners.HyperbandPruner( - min_resource=self._optuna_config.get("min_resource") - ) - else: - pruner = optuna.pruners.NopPruner() - try: return optuna.create_study( study_name=study_name, - sampler=optuna.samplers.TPESampler( - n_startup_trials=self._optuna_config.get("n_startup_trials"), - multivariate=True, - group=True, - seed=self._optuna_config.get("seed"), - ), - pruner=pruner, + sampler=self.optuna_create_sampler(), + pruner=self.optuna_create_pruner(is_study_single_objective), direction=direction, directions=directions, storage=storage, diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index 3cde63a..0fcf370 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -75,7 +75,7 @@ class QuickAdapterV3(IStrategy): INTERFACE_VERSION = 3 def version(self) -> str: - return "3.3.169" + return "3.3.170" timeframe = "5m" @@ -266,7 +266,8 @@ class QuickAdapterV3(IStrategy): self._candle_threshold_cache: dict[CandleThresholdCacheKey, float] = {} self._cached_df_signature: dict[str, DfSignature] = {} - def _df_signature(self, df: DataFrame) -> DfSignature: + @staticmethod + def _df_signature(df: DataFrame) -> DfSignature: n = len(df) if n == 0: return (0, None) @@ -1159,7 +1160,7 @@ class QuickAdapterV3(IStrategy): interpolation_direction: Literal["direct", "inverse"] = "direct", quantile_exponent: float = 1.5, ) -> float: - df_signature = self._df_signature(df) + df_signature = QuickAdapterV3._df_signature(df) prev_df_signature = self._cached_df_signature.get(pair) if prev_df_signature != df_signature: self._candle_deviation_cache = { @@ -1229,7 +1230,7 @@ class QuickAdapterV3(IStrategy): max_natr_ratio_percent: float, candle_idx: int = -1, ) -> float: - df_signature = self._df_signature(df) + df_signature = QuickAdapterV3._df_signature(df) prev_df_signature = self._cached_df_signature.get(pair) if prev_df_signature != df_signature: self._candle_threshold_cache = {