]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
feat(qav3): add extrema selection methods
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 18 Nov 2025 22:35:50 +0000 (23:35 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 18 Nov 2025 22:35:50 +0000 (23:35 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
README.md
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 417748fb8eca2ee67ba65a551d4bee62ee2706f0..b9d6d397c7bd6fb74d80564fabd7deff957caef0 100644 (file)
--- a/README.md
+++ b/README.md
@@ -75,6 +75,7 @@ docker compose up -d --build
 | freqai.feature_parameters.label_knn_p_order          | `None`            | float                                                                                                                            | p-order for KNN Minkowski metric distance. (optional)                                                                                                                                                      |
 | freqai.feature_parameters.label_knn_n_neighbors      | 5                 | int >= 1                                                                                                                         | Number of neighbors for KNN.                                                                                                                                                                               |
 | _Prediction thresholds_                              |                   |                                                                                                                                  |                                                                                                                                                                                                            |
+| freqai.prediction_extrema_selection                  | `peak_values`     | enum {`peak_values`,`extrema_rank`}                                                                                              | Extrema selection method. `peak_values` uses detected peaks, `extrema_rank` uses ranked extrema values.                                                                                                    |
 | freqai.prediction_thresholds_smoothing               | `mean`            | enum {`mean`,`isodata`,`li`,`minimum`,`otsu`,`triangle`,`yen`,`soft_extremum`}                                                   | Thresholding method for prediction thresholds smoothing.                                                                                                                                                   |
 | freqai.prediction_thresholds_alpha                   | 12.0              | float > 0                                                                                                                        | Alpha for `soft_extremum`.                                                                                                                                                                                 |
 | freqai.outlier_threshold                             | 0.999             | float (0,1)                                                                                                                      | Quantile threshold for predictions outlier filtering.                                                                                                                                                      |
index c020b77712f3a5b9ebed60c61711107c48b91b02..e88ee9e0e982131d0c5e20b333a6f87e32a88431 100644 (file)
@@ -6,7 +6,7 @@ import time
 import warnings
 from functools import cached_property
 from pathlib import Path
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Literal, Optional
 
 import numpy as np
 import optuna
@@ -45,6 +45,9 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
 
 logger = logging.getLogger(__name__)
 
+ExtremaSelectionMethod = Literal["peak_values", "extrema_rank"]
+OptunaNamespace = Literal["hp", "train", "label"]
+
 
 class QuickAdapterRegressorV3(BaseRegressionModel):
     """
@@ -67,6 +70,22 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
 
     _SQRT_2 = np.sqrt(2.0)
 
+    _EXTREMA_SELECTION_METHODS: tuple[ExtremaSelectionMethod, ...] = (
+        "peak_values",
+        "extrema_rank",
+    )
+    _OPTUNA_STORAGE_BACKENDS: tuple[str, ...] = ("sqlite", "file")
+    _OPTUNA_SAMPLERS: tuple[str, ...] = ("tpe", "auto")
+    _OPTUNA_NAMESPACES: tuple[OptunaNamespace, ...] = ("hp", "train", "label")
+
+    @staticmethod
+    def _extrema_selection_methods_set() -> set[ExtremaSelectionMethod]:
+        return set(QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS)
+
+    @staticmethod
+    def _optuna_namespaces_set() -> set[OptunaNamespace]:
+        return set(QuickAdapterRegressorV3._OPTUNA_NAMESPACES)
+
     @cached_property
     def _optuna_config(self) -> dict[str, Any]:
         optuna_default_config = {
@@ -77,8 +96,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 .get("n_jobs", 1),
                 max(int(self.max_system_threads / 4), 1),
             ),
-            "sampler": "tpe",
-            "storage": "file",
+            "sampler": self._OPTUNA_SAMPLERS[0],  # "tpe"
+            "storage": self._OPTUNA_STORAGE_BACKENDS[1],  # "file"
             "continuous": True,
             "warm_start": True,
             "n_startup_trials": 15,
@@ -206,18 +225,22 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             self._optuna_train_value[pair] = -1
             self._optuna_label_values[pair] = [-1, -1]
             self._optuna_hp_params[pair] = (
-                self.optuna_load_best_params(pair, "hp")
-                if self.optuna_load_best_params(pair, "hp")
+                self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0])  # "hp"
+                if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0])
                 else {}
             )
             self._optuna_train_params[pair] = (
-                self.optuna_load_best_params(pair, "train")
-                if self.optuna_load_best_params(pair, "train")
+                self.optuna_load_best_params(
+                    pair, self._OPTUNA_NAMESPACES[1]
+                )  # "train"
+                if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[1])
                 else {}
             )
             self._optuna_label_params[pair] = (
-                self.optuna_load_best_params(pair, "label")
-                if self.optuna_load_best_params(pair, "label")
+                self.optuna_load_best_params(
+                    pair, self._OPTUNA_NAMESPACES[2]
+                )  # "label"
+                if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[2])
                 else {
                     "label_period_candles": self.ft_params.get(
                         "label_period_candles",
@@ -239,59 +262,77 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         )
 
     def get_optuna_params(self, pair: str, namespace: str) -> dict[str, Any]:
-        if namespace == "hp":
+        if namespace == self._OPTUNA_NAMESPACES[0]:  # "hp"
             params = self._optuna_hp_params.get(pair)
-        elif namespace == "train":
+        elif namespace == self._OPTUNA_NAMESPACES[1]:  # "train"
             params = self._optuna_train_params.get(pair)
-        elif namespace == "label":
+        elif namespace == self._OPTUNA_NAMESPACES[2]:  # "label"
             params = self._optuna_label_params.get(pair)
         else:
-            raise ValueError(f"Invalid namespace: {namespace}")
+            raise ValueError(
+                f"Invalid namespace: {namespace}. "
+                f"Expected {', '.join(self._OPTUNA_NAMESPACES)}"
+            )
         return params
 
     def set_optuna_params(
         self, pair: str, namespace: str, params: dict[str, Any]
     ) -> None:
-        if namespace == "hp":
+        if namespace == self._OPTUNA_NAMESPACES[0]:  # "hp"
             self._optuna_hp_params[pair] = params
-        elif namespace == "train":
+        elif namespace == self._OPTUNA_NAMESPACES[1]:  # "train"
             self._optuna_train_params[pair] = params
-        elif namespace == "label":
+        elif namespace == self._OPTUNA_NAMESPACES[2]:  # "label"
             self._optuna_label_params[pair] = params
         else:
-            raise ValueError(f"Invalid namespace: {namespace}")
+            raise ValueError(
+                f"Invalid namespace: {namespace}. "
+                f"Expected {', '.join(self._OPTUNA_NAMESPACES)}"
+            )
 
     def get_optuna_value(self, pair: str, namespace: str) -> float:
-        if namespace == "hp":
+        if namespace == self._OPTUNA_NAMESPACES[0]:  # "hp"
             value = self._optuna_hp_value.get(pair)
-        elif namespace == "train":
+        elif namespace == self._OPTUNA_NAMESPACES[1]:  # "train"
             value = self._optuna_train_value.get(pair)
         else:
-            raise ValueError(f"Invalid namespace: {namespace}")
+            raise ValueError(
+                f"Invalid namespace: {namespace}. "
+                f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}"  # Only hp and train
+            )
         return value
 
     def set_optuna_value(self, pair: str, namespace: str, value: float) -> None:
-        if namespace == "hp":
+        if namespace == self._OPTUNA_NAMESPACES[0]:  # "hp"
             self._optuna_hp_value[pair] = value
-        elif namespace == "train":
+        elif namespace == self._OPTUNA_NAMESPACES[1]:  # "train"
             self._optuna_train_value[pair] = value
         else:
-            raise ValueError(f"Invalid namespace: {namespace}")
+            raise ValueError(
+                f"Invalid namespace: {namespace}. "
+                f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}"  # Only hp and train
+            )
 
     def get_optuna_values(self, pair: str, namespace: str) -> list[float | int]:
-        if namespace == "label":
+        if namespace == self._OPTUNA_NAMESPACES[2]:  # "label"
             values = self._optuna_label_values.get(pair)
         else:
-            raise ValueError(f"Invalid namespace: {namespace}")
+            raise ValueError(
+                f"Invalid namespace: {namespace}. "
+                f"Expected {self._OPTUNA_NAMESPACES[2]}"  # Only label
+            )
         return values
 
     def set_optuna_values(
         self, pair: str, namespace: str, values: list[float | int]
     ) -> None:
-        if namespace == "label":
+        if namespace == self._OPTUNA_NAMESPACES[2]:  # "label"
             self._optuna_label_values[pair] = values
         else:
-            raise ValueError(f"Invalid namespace: {namespace}")
+            raise ValueError(
+                f"Invalid namespace: {namespace}. "
+                f"Expected {self._OPTUNA_NAMESPACES[2]}"  # Only label
+            )
 
     def init_optuna_label_candle_pool(self) -> None:
         optuna_label_candle_pool_full = self._optuna_label_candle_pool_full
@@ -364,7 +405,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         if self._optuna_hyperopt:
             self.optuna_optimize(
                 pair=dk.pair,
-                namespace="hp",
+                namespace=self._OPTUNA_NAMESPACES[0],  # "hp"
                 objective=lambda trial: hp_objective(
                     trial,
                     str(self.freqai_info.get("regressor", "xgboost")),
@@ -374,7 +415,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                     X_test,
                     y_test,
                     test_weights,
-                    self.get_optuna_params(dk.pair, "hp"),
+                    self.get_optuna_params(dk.pair, self._OPTUNA_NAMESPACES[0]),  # "hp"
                     model_training_parameters,
                     self._optuna_config.get("space_reduction"),
                     self._optuna_config.get("expansion_ratio"),
@@ -382,7 +423,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 direction=optuna.study.StudyDirection.MINIMIZE,
             )
 
-            optuna_hp_params = self.get_optuna_params(dk.pair, "hp")
+            optuna_hp_params = self.get_optuna_params(
+                dk.pair, self._OPTUNA_NAMESPACES[0]
+            )  # "hp"
             if optuna_hp_params:
                 model_training_parameters = {
                     **model_training_parameters,
@@ -391,7 +434,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
 
             train_study = self.optuna_optimize(
                 pair=dk.pair,
-                namespace="train",
+                namespace=self._OPTUNA_NAMESPACES[1],  # "train"
                 objective=lambda trial: train_objective(
                     trial,
                     str(self.freqai_info.get("regressor", "xgboost")),
@@ -409,12 +452,20 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 direction=optuna.study.StudyDirection.MINIMIZE,
             )
 
-            optuna_hp_value = self.get_optuna_value(dk.pair, "hp")
-            optuna_train_params = self.get_optuna_params(dk.pair, "train")
-            optuna_train_value = self.get_optuna_value(dk.pair, "train")
+            optuna_hp_value = self.get_optuna_value(
+                dk.pair, self._OPTUNA_NAMESPACES[0]
+            )  # "hp"
+            optuna_train_params = self.get_optuna_params(
+                dk.pair, self._OPTUNA_NAMESPACES[1]
+            )  # "train"
+            optuna_train_value = self.get_optuna_value(
+                dk.pair, self._OPTUNA_NAMESPACES[1]
+            )  # "train"
             if (
                 optuna_train_params
-                and self.optuna_validate_params(dk.pair, "train", train_study)
+                and self.optuna_validate_params(
+                    dk.pair, self._OPTUNA_NAMESPACES[1], train_study
+                )  # "train"
                 and optuna_train_value < optuna_hp_value
             ):
                 train_period_candles = optuna_train_params.get("train_period_candles")
@@ -461,8 +512,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         namespace: str,
         callback: Callable[[], None],
     ) -> None:
-        if namespace != "label":
-            raise ValueError(f"Invalid namespace: {namespace}")
+        if namespace not in {self._OPTUNA_NAMESPACES[2]}:  # Only "label"
+            raise ValueError(
+                f"Invalid namespace: {namespace}. "
+                f"Expected {self._OPTUNA_NAMESPACES[2]}"  # Only label
+            )
         if not callable(callback):
             raise ValueError("callback must be callable")
         self._optuna_label_candles[pair] += 1
@@ -499,10 +553,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         if self._optuna_hyperopt:
             self.optuna_throttle_callback(
                 pair=pair,
-                namespace="label",
+                namespace=self._OPTUNA_NAMESPACES[2],  # "label"
                 callback=lambda: self.optuna_optimize(
                     pair=pair,
-                    namespace="label",
+                    namespace=self._OPTUNA_NAMESPACES[2],  # "label"
                     objective=lambda trial: label_objective(
                         trial,
                         self.data_provider.get_pair_dataframe(
@@ -555,7 +609,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             min_pred, max_pred = self.min_max_pred(
                 pred_df,
                 fit_live_predictions_candles,
-                self.get_optuna_params(pair, "label").get("label_period_candles"),
+                self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get(
+                    "label_period_candles"
+                ),  # "label"
             )
             dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = min_pred
             dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = max_pred
@@ -593,17 +649,24 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff
 
         dk.data["extra_returns_per_train"]["label_period_candles"] = (
-            self.get_optuna_params(pair, "label").get("label_period_candles")
+            self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get(
+                "label_period_candles"
+            )  # "label"
         )
         dk.data["extra_returns_per_train"]["label_natr_ratio"] = self.get_optuna_params(
-            pair, "label"
+            pair,
+            self._OPTUNA_NAMESPACES[2],  # "label"
         ).get("label_natr_ratio")
 
-        hp_rmse = self.optuna_validate_value(self.get_optuna_value(pair, "hp"))
+        hp_rmse = self.optuna_validate_value(
+            self.get_optuna_value(pair, self._OPTUNA_NAMESPACES[0])
+        )  # "hp"
         dk.data["extra_returns_per_train"]["hp_rmse"] = (
             hp_rmse if hp_rmse is not None else np.inf
         )
-        train_rmse = self.optuna_validate_value(self.get_optuna_value(pair, "train"))
+        train_rmse = self.optuna_validate_value(
+            self.get_optuna_value(pair, self._OPTUNA_NAMESPACES[1])
+        )  # "train"
         dk.data["extra_returns_per_train"]["train_rmse"] = (
             train_rmse
             if (train_rmse is not None and hp_rmse is not None and train_rmse < hp_rmse)
@@ -643,6 +706,19 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         thresholds_candles = max(2, int(label_period_cycles)) * label_period_candles
 
         pred_extrema = pred_df.get(EXTREMA_COLUMN).iloc[-thresholds_candles:].copy()
+
+        extrema_selection = str(
+            self.freqai_info.get(
+                "prediction_extrema_selection",
+                self._EXTREMA_SELECTION_METHODS[0],
+            )
+        )
+        if extrema_selection not in self._extrema_selection_methods_set():
+            raise ValueError(
+                f"Unsupported extrema selection method: {extrema_selection}. "
+                f"Supported methods are {', '.join(self._EXTREMA_SELECTION_METHODS)}"
+            )
+        extrema_selection: ExtremaSelectionMethod = extrema_selection  # type: ignore[assignment]
         thresholds_smoothing = str(
             self.freqai_info.get("prediction_thresholds_smoothing", "mean")
         )
@@ -663,11 +739,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 self.freqai_info.get("prediction_thresholds_alpha", 12.0)
             )
             return QuickAdapterRegressorV3.soft_extremum_min_max(
-                pred_extrema, thresholds_alpha
+                pred_extrema, thresholds_alpha, extrema_selection
             )
         elif thresholds_smoothing in skimage_thresholds_smoothing_methods:
             return QuickAdapterRegressorV3.skimage_min_max(
-                pred_extrema, thresholds_smoothing
+                pred_extrema, thresholds_smoothing, extrema_selection
             )
         else:
             raise ValueError(
@@ -675,7 +751,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             )
 
     @staticmethod
-    def get_pred_min_max(pred_extrema: pd.Series) -> tuple[pd.Series, pd.Series]:
+    def get_pred_min_max(
+        pred_extrema: pd.Series,
+        extrema_selection: ExtremaSelectionMethod,
+    ) -> tuple[pd.Series, pd.Series]:
         pred_extrema = (
             pd.to_numeric(pred_extrema, errors="coerce")
             .where(np.isfinite, np.nan)
@@ -683,13 +762,41 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         )
         if pred_extrema.empty:
             return pd.Series(dtype=float), pd.Series(dtype=float)
-        n_pred_minima = max(1, sp.signal.find_peaks(-pred_extrema)[0].size)
-        n_pred_maxima = max(1, sp.signal.find_peaks(pred_extrema)[0].size)
 
-        sorted_pred_extrema = pred_extrema.sort_values(ascending=True)
-        return sorted_pred_extrema.iloc[:n_pred_minima], sorted_pred_extrema.iloc[
-            -n_pred_maxima:
-        ]
+        minima_indices = sp.signal.find_peaks(-pred_extrema)[0]
+        maxima_indices = sp.signal.find_peaks(pred_extrema)[0]
+
+        if extrema_selection == QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS[0]:
+            pred_minima = (
+                pred_extrema.iloc[minima_indices]
+                if minima_indices.size > 0
+                else pd.Series(dtype=float)
+            )
+            pred_maxima = (
+                pred_extrema.iloc[maxima_indices]
+                if maxima_indices.size > 0
+                else pd.Series(dtype=float)
+            )
+        elif extrema_selection == QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS[1]:
+            n_minima = minima_indices.size
+            n_maxima = maxima_indices.size
+
+            if n_minima > 0:
+                pred_minima = pred_extrema.nsmallest(n_minima)
+            else:
+                pred_minima = pd.Series(dtype=float)
+
+            if n_maxima > 0:
+                pred_maxima = pred_extrema.nlargest(n_maxima)
+            else:
+                pred_maxima = pd.Series(dtype=float)
+        else:
+            raise ValueError(
+                f"Unsupported extrema selection method: {extrema_selection}. "
+                f"Supported methods are {', '.join(QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS)}"
+            )
+
+        return pred_minima, pred_maxima
 
     @staticmethod
     def safe_min_pred(pred_extrema: pd.Series) -> float:
@@ -721,12 +828,14 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
 
     @staticmethod
     def soft_extremum_min_max(
-        pred_extrema: pd.Series, alpha: float
+        pred_extrema: pd.Series,
+        alpha: float,
+        extrema_selection: ExtremaSelectionMethod,
     ) -> tuple[float, float]:
         if alpha < 0:
             raise ValueError("alpha must be non-negative")
         pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max(
-            pred_extrema
+            pred_extrema, extrema_selection
         )
         soft_minimum = soft_extremum(pred_minima, alpha=-alpha)
         if not np.isfinite(soft_minimum):
@@ -737,9 +846,13 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         return soft_minimum, soft_maximum
 
     @staticmethod
-    def skimage_min_max(pred_extrema: pd.Series, method: str) -> tuple[float, float]:
+    def skimage_min_max(
+        pred_extrema: pd.Series,
+        method: str,
+        extrema_selection: ExtremaSelectionMethod,
+    ) -> tuple[float, float]:
         pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max(
-            pred_extrema
+            pred_extrema, extrema_selection
         )
 
         method_functions = {
@@ -882,8 +995,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     def get_multi_objective_study_best_trial(
         self, namespace: str, study: optuna.study.Study
     ) -> Optional[optuna.trial.FrozenTrial]:
-        if namespace != "label":
-            raise ValueError(f"Invalid namespace: {namespace}")
+        if namespace not in {self._OPTUNA_NAMESPACES[2]}:  # Only "label"
+            raise ValueError(
+                f"Invalid namespace: {namespace}. "
+                f"Expected {self._OPTUNA_NAMESPACES[2]}"  # Only label
+            )
         n_objectives = len(study.directions)
         if n_objectives < 2:
             raise ValueError(
@@ -1518,7 +1634,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         storage_dir = self.full_path
         storage_filename = f"optuna-{pair.split('/')[0]}"
         storage_backend = self._optuna_config.get("storage")
-        if storage_backend == "sqlite":
+        if storage_backend == self._OPTUNA_STORAGE_BACKENDS[0]:  # "sqlite"
             storage = optuna.storages.RDBStorage(
                 url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite",
                 heartbeat_interval=60,
@@ -1526,7 +1642,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                     max_retry=3
                 ),
             )
-        elif storage_backend == "file":
+        elif storage_backend == self._OPTUNA_STORAGE_BACKENDS[1]:  # "file"
             storage = optuna.storages.JournalStorage(
                 optuna.storages.journal.JournalFileBackend(
                     f"{storage_dir}/{storage_filename}.log"
@@ -1534,7 +1650,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             )
         else:
             raise ValueError(
-                f"Unsupported optuna storage backend: {storage_backend}. Supported backends are 'sqlite' and 'file'"
+                f"Unsupported optuna storage backend: {storage_backend}. "
+                f"Supported backends are {', '.join(self._OPTUNA_STORAGE_BACKENDS)}"
             )
         return storage
 
@@ -1549,12 +1666,12 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             return optuna.pruners.NopPruner()
 
     def optuna_create_sampler(self) -> optuna.samplers.BaseSampler:
-        sampler = self._optuna_config.get("sampler", "tpe")
-        if sampler == "auto":
+        sampler = self._optuna_config.get("sampler", self._OPTUNA_SAMPLERS[0])
+        if sampler == self._OPTUNA_SAMPLERS[1]:  # "auto"
             return optunahub.load_module("samplers/auto_sampler").AutoSampler(
                 seed=self._optuna_config.get("seed")
             )
-        elif sampler == "tpe":
+        elif sampler == self._OPTUNA_SAMPLERS[0]:  # "tpe"
             return optuna.samplers.TPESampler(
                 n_startup_trials=self._optuna_config.get("n_startup_trials"),
                 multivariate=True,
@@ -1563,7 +1680,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             )
         else:
             raise ValueError(
-                f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'"
+                f"Unsupported sampler: {sampler}. "
+                f"Supported samplers are {', '.join(self._OPTUNA_SAMPLERS)}"
             )
 
     def optuna_create_study(
index 9e1c48614a74c46e8b10fd9956c98187bda8c420..8f53359296836a106b2517db2f9cdcb1ecf9d735 100644 (file)
@@ -54,6 +54,12 @@ EXTREMA_COLUMN = "&s-extrema"
 MAXIMA_THRESHOLD_COLUMN = "&s-maxima_threshold"
 MINIMA_THRESHOLD_COLUMN = "&s-minima_threshold"
 
+# Type aliases
+TradeDirection = Literal["long", "short"]
+InterpolationDirection = Literal["direct", "inverse"]
+OrderType = Literal["entry", "exit"]
+TradingMode = Literal["spot", "margin", "futures"]
+
 
 class QuickAdapterV3(IStrategy):
     """
@@ -74,6 +80,14 @@ class QuickAdapterV3(IStrategy):
 
     INTERFACE_VERSION = 3
 
+    _TRADE_DIRECTIONS: tuple[TradeDirection, ...] = ("long", "short")
+    _INTERPOLATION_DIRECTIONS: tuple[InterpolationDirection, ...] = (
+        "direct",
+        "inverse",
+    )
+    _ORDER_TYPES: tuple[OrderType, ...] = ("entry", "exit")
+    _TRADING_MODES: tuple[TradingMode, ...] = ("spot", "margin", "futures")
+
     def version(self) -> str:
         return "3.3.170"
 
@@ -129,6 +143,17 @@ class QuickAdapterV3(IStrategy):
     def can_short(self) -> bool:
         return self.is_short_allowed()
 
+    @staticmethod
+    def _trade_directions_set() -> set[TradeDirection]:
+        return {
+            QuickAdapterV3._TRADE_DIRECTIONS[0],
+            QuickAdapterV3._TRADE_DIRECTIONS[1],
+        }
+
+    @staticmethod
+    def _order_types_set() -> set[OrderType]:
+        return {QuickAdapterV3._ORDER_TYPES[0], QuickAdapterV3._ORDER_TYPES[1]}
+
     @cached_property
     def plot_config(self) -> dict[str, Any]:
         return {
@@ -1199,13 +1224,13 @@ class QuickAdapterV3(IStrategy):
         if isna(candle_label_natr_value_quantile):
             return np.nan
 
-        if interpolation_direction == "direct":
+        if interpolation_direction == self._INTERPOLATION_DIRECTIONS[0]:  # "direct"
             natr_ratio_percent = (
                 min_natr_ratio_percent
                 + (max_natr_ratio_percent - min_natr_ratio_percent)
                 * candle_label_natr_value_quantile**quantile_exponent
             )
-        elif interpolation_direction == "inverse":
+        elif interpolation_direction == self._INTERPOLATION_DIRECTIONS[1]:  # "inverse"
             natr_ratio_percent = (
                 max_natr_ratio_percent
                 - (max_natr_ratio_percent - min_natr_ratio_percent)
@@ -1213,7 +1238,8 @@ class QuickAdapterV3(IStrategy):
             )
         else:
             raise ValueError(
-                f"Invalid interpolation_direction: {interpolation_direction}. Expected 'direct' or 'inverse'"
+                f"Invalid interpolation_direction: {interpolation_direction}. "
+                f"Expected {', '.join(self._INTERPOLATION_DIRECTIONS)}"
             )
         candle_deviation = (
             candle_label_natr_value / 100.0
@@ -1253,7 +1279,7 @@ class QuickAdapterV3(IStrategy):
             min_natr_ratio_percent=min_natr_ratio_percent,
             max_natr_ratio_percent=max_natr_ratio_percent,
             candle_idx=candle_idx,
-            interpolation_direction="direct",
+            interpolation_direction=self._INTERPOLATION_DIRECTIONS[0],  # "direct"
         )
         if isna(current_deviation) or current_deviation <= 0:
             return np.nan
@@ -1268,14 +1294,14 @@ class QuickAdapterV3(IStrategy):
         is_candle_bullish: bool = candle_close > candle_open
         is_candle_bearish: bool = candle_close < candle_open
 
-        if side == "long":
+        if side == self._TRADE_DIRECTIONS[0]:  # "long"
             base_price = (
                 QuickAdapterV3.weighted_close(candle)
                 if is_candle_bearish
                 else candle_close
             )
             candle_threshold = base_price * (1 + current_deviation)
-        elif side == "short":
+        elif side == self._TRADE_DIRECTIONS[1]:  # "short"
             base_price = (
                 QuickAdapterV3.weighted_close(candle)
                 if is_candle_bullish
@@ -1283,7 +1309,9 @@ class QuickAdapterV3(IStrategy):
             )
             candle_threshold = base_price * (1 - current_deviation)
         else:
-            raise ValueError(f"Invalid side: {side}. Expected 'long' or 'short'")
+            raise ValueError(
+                f"Invalid side: {side}. Expected {', '.join(self._TRADE_DIRECTIONS)}"
+            )
         self._candle_threshold_cache[cache_key] = candle_threshold
         return self._candle_threshold_cache[cache_key]
 
@@ -1367,9 +1395,9 @@ class QuickAdapterV3(IStrategy):
         """
         if df.empty:
             return False
-        if side not in {"long", "short"}:
+        if side not in self._sides_set():
             return False
-        if order not in {"entry", "exit"}:
+        if order not in self._order_types_set():
             return False
 
         trade_direction = side
@@ -1397,14 +1425,16 @@ class QuickAdapterV3(IStrategy):
             candle_idx=-1,
         )
         current_ok = np.isfinite(current_threshold) and (
-            (side == "long" and rate > current_threshold)
-            or (side == "short" and rate < current_threshold)
-        )
-        if order == "exit":
-            if side == "long":
-                trade_direction = "short"
-            if side == "short":
-                trade_direction = "long"
+            (side == self._TRADE_DIRECTIONS[0] and rate > current_threshold)  # "long"
+            or (
+                side == self._TRADE_DIRECTIONS[1] and rate < current_threshold
+            )  # "short"
+        )
+        if order == self._ORDER_TYPES[1]:  # "exit"
+            if side == self._TRADE_DIRECTIONS[0]:  # "long"
+                trade_direction = self._TRADE_DIRECTIONS[1]  # "short"
+            if side == self._TRADE_DIRECTIONS[1]:  # "short"
+                trade_direction = self._TRADE_DIRECTIONS[0]  # "long"
         if not current_ok:
             logger.info(
                 f"User denied {trade_direction} {order} for {pair}: rate {format_number(rate)} did not break threshold {format_number(current_threshold)}"
@@ -1441,8 +1471,11 @@ class QuickAdapterV3(IStrategy):
             ):
                 return current_ok
 
-            if (side == "long" and not (close_k > threshold_k)) or (
-                side == "short" and not (close_k < threshold_k)
+            if (
+                side == self._TRADE_DIRECTIONS[0] and not (close_k > threshold_k)
+            ) or (  # "long"
+                side == self._TRADE_DIRECTIONS[1]
+                and not (close_k < threshold_k)  # "short"
             ):
                 logger.info(
                     f"User denied {trade_direction} {order} for {pair}: "
@@ -1670,15 +1703,15 @@ class QuickAdapterV3(IStrategy):
                 trade.set_custom_data("last_outlier_date", last_candle_date.isoformat())
 
         if (
-            trade.trade_direction == "short"
+            trade.trade_direction == self._TRADE_DIRECTIONS[1]  # "short"
             and last_candle.get("do_predict") == 1
             and last_candle.get("DI_catch") == 1
             and last_candle.get(EXTREMA_COLUMN) < last_candle.get("minima_threshold")
             and self.reversal_confirmed(
                 df,
                 pair,
-                "long",
-                "exit",
+                self._TRADE_DIRECTIONS[0],  # "long"
+                self._ORDER_TYPES[1],  # "exit"
                 current_rate,
                 self._reversal_lookback_period,
                 self._reversal_decay_ratio,
@@ -1688,15 +1721,15 @@ class QuickAdapterV3(IStrategy):
         ):
             return "minima_detected_short"
         if (
-            trade.trade_direction == "long"
+            trade.trade_direction == self._TRADE_DIRECTIONS[0]  # "long"
             and last_candle.get("do_predict") == 1
             and last_candle.get("DI_catch") == 1
             and last_candle.get(EXTREMA_COLUMN) > last_candle.get("maxima_threshold")
             and self.reversal_confirmed(
                 df,
                 pair,
-                "short",
-                "exit",
+                self._TRADE_DIRECTIONS[1],  # "short"
+                self._ORDER_TYPES[1],  # "exit"
                 current_rate,
                 self._reversal_lookback_period,
                 self._reversal_decay_ratio,
@@ -1805,9 +1838,9 @@ class QuickAdapterV3(IStrategy):
         side: str,
         **kwargs,
     ) -> bool:
-        if side not in {"long", "short"}:
+        if side not in self._sides_set():
             return False
-        if side == "short" and not self.can_short:
+        if side == self._TRADE_DIRECTIONS[1] and not self.can_short:  # "short"
             logger.info(f"User denied short entry for {pair}: shorting not allowed")
             return False
         if Trade.get_open_trade_count() >= self.config.get("max_open_trades"):
@@ -1831,7 +1864,7 @@ class QuickAdapterV3(IStrategy):
             df,
             pair,
             side,
-            "entry",
+            self._ORDER_TYPES[0],  # "entry"
             rate,
             self._reversal_lookback_period,
             self._reversal_decay_ratio,
@@ -1843,12 +1876,18 @@ class QuickAdapterV3(IStrategy):
 
     def is_short_allowed(self) -> bool:
         trading_mode = self.config.get("trading_mode")
-        if trading_mode in {"margin", "futures"}:
+        if trading_mode in {
+            self._TRADING_MODES[1],
+            self._TRADING_MODES[2],
+        }:  # margin, futures
             return True
-        elif trading_mode == "spot":
+        elif trading_mode == self._TRADING_MODES[0]:  # "spot"
             return False
         else:
-            raise ValueError(f"Invalid trading_mode: {trading_mode}")
+            raise ValueError(
+                f"Invalid trading_mode: {trading_mode}. "
+                f"Expected {', '.join(self._TRADING_MODES)}"
+            )
 
     def leverage(
         self,