]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
feat: add optuna auto sampler support
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 15 Nov 2025 11:39:41 +0000 (12:39 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 15 Nov 2025 11:39:41 +0000 (12:39 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
README.md
ReforceXY/Dockerfile.reforcexy
ReforceXY/user_data/freqaimodels/ReforceXY.py
quickadapter/Dockerfile.quickadapter
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 5e9f3134b1956759806dbbc71b214e1b0462dd28..300c3eee0531beb9a355070636fd8c1dfea9045a 100644 (file)
--- a/README.md
+++ b/README.md
@@ -80,12 +80,13 @@ docker compose up -d --build
 | freqai.outlier_threshold                             | 0.999             | float (0,1)                                                                                                                      | Quantile threshold for predictions outlier filtering.                                                                                                                                                      |
 | _Optuna / HPO_                                       |                   |                                                                                                                                  |                                                                                                                                                                                                            |
 | freqai.optuna_hyperopt.enabled                       | true              | bool                                                                                                                             | Enables HPO.                                                                                                                                                                                               |
-| freqai.optuna_hyperopt.n_jobs                        | CPU threads / 4   | int >= 1                                                                                                                         | Parallel HPO workers.                                                                                                                                                                                      |
+| freqai.optuna_hyperopt.sampler                       | `tpe`             | enum {`tpe`,`auto`}                                                                                                              | HPO sampler algorithm. `tpe` uses TPESampler with multivariate and group, `auto` uses AutoSampler.                                                                                                         |
 | freqai.optuna_hyperopt.storage                       | `file`            | enum {`file`,`sqlite`}                                                                                                           | HPO storage backend.                                                                                                                                                                                       |
 | freqai.optuna_hyperopt.continuous                    | true              | bool                                                                                                                             | Continuous HPO.                                                                                                                                                                                            |
 | freqai.optuna_hyperopt.warm_start                    | true              | bool                                                                                                                             | Warm start HPO with previous best value(s).                                                                                                                                                                |
 | freqai.optuna_hyperopt.n_startup_trials              | 15                | int >= 0                                                                                                                         | HPO startup trials.                                                                                                                                                                                        |
 | freqai.optuna_hyperopt.n_trials                      | 50                | int >= 1                                                                                                                         | Maximum HPO trials.                                                                                                                                                                                        |
+| freqai.optuna_hyperopt.n_jobs                        | CPU threads / 4   | int >= 1                                                                                                                         | Parallel HPO workers.                                                                                                                                                                                      |
 | freqai.optuna_hyperopt.timeout                       | 7200              | int >= 0                                                                                                                         | HPO wall-clock timeout in seconds.                                                                                                                                                                         |
 | freqai.optuna_hyperopt.label_candles_step            | 1                 | int >= 1                                                                                                                         | Step for Zigzag NATR horizon search space.                                                                                                                                                                 |
 | freqai.optuna_hyperopt.train_candles_step            | 10                | int >= 1                                                                                                                         | Step for training sets size search space.                                                                                                                                                                  |
index fa44fd18c63f04217e245c965a6bc51cbea0907f..d1ddb554826efb002c054c4905f3feac9f3a769b 100644 (file)
@@ -2,9 +2,16 @@ FROM freqtradeorg/freqtrade:stable_freqairl
 
 ARG optuna_version=4.6.0
 
+USER root
+RUN apt-get update \
+ && apt-get install -y --no-install-recommends git \
+ && rm -rf /var/lib/apt/lists/*
+USER ftuser
 RUN pip install --user --no-cache-dir \
     optuna==${optuna_version} \
-    optuna-dashboard
+    optuna-dashboard \
+    optunahub \
+    -r https://hub.optuna.org/samplers/auto_sampler/requirements.txt
 
 LABEL org.opencontainers.image.source="freqai-strategies" \
       org.opencontainers.image.title="freqtrade-reforcexy" \
index 3f85b50e7a1d53a5cf45332d45caf0ef99ed2214..d80843179994a7c2b07eab0073eea1162990470d 100644 (file)
@@ -14,6 +14,7 @@ import matplotlib
 import matplotlib.pyplot as plt
 import matplotlib.transforms as mtransforms
 import numpy as np
+import optunahub
 import torch as th
 from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
 from freqtrade.freqai.RL.Base5ActionRLEnv import Actions, Base5ActionRLEnv, Positions
@@ -28,8 +29,8 @@ from matplotlib.lines import Line2D
 from numpy.typing import NDArray
 from optuna import Trial, TrialPruned, create_study, delete_study
 from optuna.exceptions import ExperimentalWarning
-from optuna.pruners import HyperbandPruner
-from optuna.samplers import TPESampler
+from optuna.pruners import BasePruner, HyperbandPruner
+from optuna.samplers import BaseSampler, TPESampler
 from optuna.storages import (
     BaseStorage,
     JournalStorage,
@@ -101,9 +102,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 "timeout_hours": 0,                 // Maximum time in hours for hyperopt (0 = no timeout)
                 "continuous": false,                // If true, perform continuous optimization
                 "warm_start": false,                // If true, enqueue previous best params if exists
-                "seed": 42,                         // RNG seed
+                "sampler": "tpe",                   // Optuna sampler (tpe|auto)
                 "storage": "sqlite",                // Optuna storage backend (sqlite|file)
-            }
+                "seed": 42,                         // RNG seed
         }
     }
     Requirements:
@@ -942,6 +943,34 @@ class ReforceXY(BaseReinforcementLearningModel):
         except (ValueError, KeyError):
             return False
 
+    def create_sampler(self) -> BaseSampler:
+        sampler = self.rl_config_optuna.get("sampler", "tpe")
+        if sampler == "auto":
+            return optunahub.load_module("samplers/auto_sampler").AutoSampler(
+                seed=self.rl_config_optuna.get("seed", 42)
+            )
+        elif sampler == "tpe":
+            return TPESampler(
+                n_startup_trials=self.optuna_n_startup_trials,
+                multivariate=True,
+                group=True,
+                seed=self.rl_config_optuna.get("seed", 42),
+            )
+        else:
+            raise ValueError(
+                f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'"
+            )
+
+    @staticmethod
+    def create_pruner(
+        min_resource: int, max_resource: int, reduction_factor: int
+    ) -> BasePruner:
+        return HyperbandPruner(
+            min_resource=min_resource,
+            max_resource=max_resource,
+            reduction_factor=reduction_factor,
+        )
+
     def optimize(
         self, dk: FreqaiDataKitchen, total_timesteps: int
     ) -> Optional[Dict[str, Any]]:
@@ -963,18 +992,12 @@ class ReforceXY(BaseReinforcementLearningModel):
             reduction_factor * 2, (total_timesteps // self.n_envs) // resource_eval_freq
         )
         min_resource = max(1, max_resource // reduction_factor)
+
         study: Study = create_study(
             study_name=study_name,
-            sampler=TPESampler(
-                n_startup_trials=self.optuna_n_startup_trials,
-                multivariate=True,
-                group=True,
-                seed=self.rl_config_optuna.get("seed", 42),
-            ),
-            pruner=HyperbandPruner(
-                min_resource=min_resource,
-                max_resource=max_resource,
-                reduction_factor=reduction_factor,
+            sampler=self.create_sampler(),
+            pruner=ReforceXY.create_pruner(
+                min_resource, max_resource, reduction_factor
             ),
             direction=StudyDirection.MAXIMIZE,
             storage=storage,
index 1e85327839a9c00ecd801c7890dbf78313e7c559..2aaef0b46d000fbb7a5cf8afaa93ca54dd30187b 100644 (file)
@@ -6,13 +6,15 @@ ARG scikit_image_version=0.25.2
 
 USER root
 RUN apt-get update \
- && apt-get install -y --no-install-recommends build-essential \
+ && apt-get install -y --no-install-recommends build-essential git \
  && rm -rf /var/lib/apt/lists/*
 USER ftuser
 RUN pip install --user --no-cache-dir \
     optuna==${optuna_version} \
     optuna-integration==${optuna_version} \
     optuna-dashboard \
+    optunahub \
+    -r https://hub.optuna.org/samplers/auto_sampler/requirements.txt \
     scikit-learn-extra==${scikit_learn_extra_version} \
     scikit-image==${scikit_image_version}
 
index 3d74a789dfa2c4a4d8168327426398dd4d0b3120..d05c1eec703e8d48a0ef4f16646d2b541514623b 100644 (file)
@@ -10,6 +10,7 @@ from typing import Any, Callable, Optional
 
 import numpy as np
 import optuna
+import optunahub
 import pandas as pd
 import scipy as sp
 import skimage
@@ -74,6 +75,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 .get("n_jobs", 1),
                 max(int(self.max_system_threads / 4), 1),
             ),
+            "sampler": "tpe",
             "storage": "file",
             "continuous": True,
             "warm_start": True,
@@ -1531,6 +1533,34 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             )
         return storage
 
+    def optuna_create_pruner(
+        self, is_single_objective: bool
+    ) -> optuna.pruners.BasePruner:
+        if is_single_objective:
+            return optuna.pruners.HyperbandPruner(
+                min_resource=self._optuna_config.get("min_resource")
+            )
+        else:
+            return optuna.pruners.NopPruner()
+
+    def optuna_create_sampler(self) -> optuna.samplers.BaseSampler:
+        sampler = self._optuna_config.get("sampler", "tpe")
+        if sampler == "auto":
+            return optunahub.load_module("samplers/auto_sampler").AutoSampler(
+                seed=self._optuna_config.get("seed")
+            )
+        elif sampler == "tpe":
+            return optuna.samplers.TPESampler(
+                n_startup_trials=self._optuna_config.get("n_startup_trials"),
+                multivariate=True,
+                group=True,
+                seed=self._optuna_config.get("seed"),
+            )
+        else:
+            raise ValueError(
+                f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'"
+            )
+
     def optuna_create_study(
         self,
         pair: str,
@@ -1566,23 +1596,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         if continuous:
             QuickAdapterRegressorV3.optuna_study_delete(study_name, storage)
 
-        if is_study_single_objective:
-            pruner = optuna.pruners.HyperbandPruner(
-                min_resource=self._optuna_config.get("min_resource")
-            )
-        else:
-            pruner = optuna.pruners.NopPruner()
-
         try:
             return optuna.create_study(
                 study_name=study_name,
-                sampler=optuna.samplers.TPESampler(
-                    n_startup_trials=self._optuna_config.get("n_startup_trials"),
-                    multivariate=True,
-                    group=True,
-                    seed=self._optuna_config.get("seed"),
-                ),
-                pruner=pruner,
+                sampler=self.optuna_create_sampler(),
+                pruner=self.optuna_create_pruner(is_study_single_objective),
                 direction=direction,
                 directions=directions,
                 storage=storage,
index 3cde63aa7b55591425d6e086c27432b5487974cd..0fcf370b8813291d48991c86af0e80f943f317fa 100644 (file)
@@ -75,7 +75,7 @@ class QuickAdapterV3(IStrategy):
     INTERFACE_VERSION = 3
 
     def version(self) -> str:
-        return "3.3.169"
+        return "3.3.170"
 
     timeframe = "5m"
 
@@ -266,7 +266,8 @@ class QuickAdapterV3(IStrategy):
         self._candle_threshold_cache: dict[CandleThresholdCacheKey, float] = {}
         self._cached_df_signature: dict[str, DfSignature] = {}
 
-    def _df_signature(self, df: DataFrame) -> DfSignature:
+    @staticmethod
+    def _df_signature(df: DataFrame) -> DfSignature:
         n = len(df)
         if n == 0:
             return (0, None)
@@ -1159,7 +1160,7 @@ class QuickAdapterV3(IStrategy):
         interpolation_direction: Literal["direct", "inverse"] = "direct",
         quantile_exponent: float = 1.5,
     ) -> float:
-        df_signature = self._df_signature(df)
+        df_signature = QuickAdapterV3._df_signature(df)
         prev_df_signature = self._cached_df_signature.get(pair)
         if prev_df_signature != df_signature:
             self._candle_deviation_cache = {
@@ -1229,7 +1230,7 @@ class QuickAdapterV3(IStrategy):
         max_natr_ratio_percent: float,
         candle_idx: int = -1,
     ) -> float:
-        df_signature = self._df_signature(df)
+        df_signature = QuickAdapterV3._df_signature(df)
         prev_df_signature = self._cached_df_signature.get(pair)
         if prev_df_signature != df_signature:
             self._candle_threshold_cache = {