From cfef6c6565a6fc52dd08ced7ee7ac281494c6b00 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Tue, 18 Nov 2025 23:35:50 +0100 Subject: [PATCH] feat(qav3): add extrema selection methods MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- README.md | 1 + .../freqaimodels/QuickAdapterRegressorV3.py | 250 +++++++++++++----- .../user_data/strategies/QuickAdapterV3.py | 101 ++++--- 3 files changed, 255 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 417748f..b9d6d39 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ docker compose up -d --build | freqai.feature_parameters.label_knn_p_order | `None` | float | p-order for KNN Minkowski metric distance. (optional) | | freqai.feature_parameters.label_knn_n_neighbors | 5 | int >= 1 | Number of neighbors for KNN. | | _Prediction thresholds_ | | | | +| freqai.prediction_extrema_selection | `peak_values` | enum {`peak_values`,`extrema_rank`} | Extrema selection method. `peak_values` uses detected peaks, `extrema_rank` uses ranked extrema values. | | freqai.prediction_thresholds_smoothing | `mean` | enum {`mean`,`isodata`,`li`,`minimum`,`otsu`,`triangle`,`yen`,`soft_extremum`} | Thresholding method for prediction thresholds smoothing. | | freqai.prediction_thresholds_alpha | 12.0 | float > 0 | Alpha for `soft_extremum`. | | freqai.outlier_threshold | 0.999 | float (0,1) | Quantile threshold for predictions outlier filtering. | diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index c020b77..e88ee9e 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -6,7 +6,7 @@ import time import warnings from functools import cached_property from pathlib import Path -from typing import Any, Callable, Optional +from typing import Any, Callable, Literal, Optional import numpy as np import optuna @@ -45,6 +45,9 @@ warnings.simplefilter(action="ignore", category=FutureWarning) logger = logging.getLogger(__name__) +ExtremaSelectionMethod = Literal["peak_values", "extrema_rank"] +OptunaNamespace = Literal["hp", "train", "label"] + class QuickAdapterRegressorV3(BaseRegressionModel): """ @@ -67,6 +70,22 @@ class QuickAdapterRegressorV3(BaseRegressionModel): _SQRT_2 = np.sqrt(2.0) + _EXTREMA_SELECTION_METHODS: tuple[ExtremaSelectionMethod, ...] = ( + "peak_values", + "extrema_rank", + ) + _OPTUNA_STORAGE_BACKENDS: tuple[str, ...] = ("sqlite", "file") + _OPTUNA_SAMPLERS: tuple[str, ...] = ("tpe", "auto") + _OPTUNA_NAMESPACES: tuple[OptunaNamespace, ...] = ("hp", "train", "label") + + @staticmethod + def _extrema_selection_methods_set() -> set[ExtremaSelectionMethod]: + return set(QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS) + + @staticmethod + def _optuna_namespaces_set() -> set[OptunaNamespace]: + return set(QuickAdapterRegressorV3._OPTUNA_NAMESPACES) + @cached_property def _optuna_config(self) -> dict[str, Any]: optuna_default_config = { @@ -77,8 +96,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel): .get("n_jobs", 1), max(int(self.max_system_threads / 4), 1), ), - "sampler": "tpe", - "storage": "file", + "sampler": self._OPTUNA_SAMPLERS[0], # "tpe" + "storage": self._OPTUNA_STORAGE_BACKENDS[1], # "file" "continuous": True, "warm_start": True, "n_startup_trials": 15, @@ -206,18 +225,22 @@ class QuickAdapterRegressorV3(BaseRegressionModel): self._optuna_train_value[pair] = -1 self._optuna_label_values[pair] = [-1, -1] self._optuna_hp_params[pair] = ( - self.optuna_load_best_params(pair, "hp") - if self.optuna_load_best_params(pair, "hp") + self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0]) # "hp" + if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0]) else {} ) self._optuna_train_params[pair] = ( - self.optuna_load_best_params(pair, "train") - if self.optuna_load_best_params(pair, "train") + self.optuna_load_best_params( + pair, self._OPTUNA_NAMESPACES[1] + ) # "train" + if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[1]) else {} ) self._optuna_label_params[pair] = ( - self.optuna_load_best_params(pair, "label") - if self.optuna_load_best_params(pair, "label") + self.optuna_load_best_params( + pair, self._OPTUNA_NAMESPACES[2] + ) # "label" + if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[2]) else { "label_period_candles": self.ft_params.get( "label_period_candles", @@ -239,59 +262,77 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) def get_optuna_params(self, pair: str, namespace: str) -> dict[str, Any]: - if namespace == "hp": + if namespace == self._OPTUNA_NAMESPACES[0]: # "hp" params = self._optuna_hp_params.get(pair) - elif namespace == "train": + elif namespace == self._OPTUNA_NAMESPACES[1]: # "train" params = self._optuna_train_params.get(pair) - elif namespace == "label": + elif namespace == self._OPTUNA_NAMESPACES[2]: # "label" params = self._optuna_label_params.get(pair) else: - raise ValueError(f"Invalid namespace: {namespace}") + raise ValueError( + f"Invalid namespace: {namespace}. " + f"Expected {', '.join(self._OPTUNA_NAMESPACES)}" + ) return params def set_optuna_params( self, pair: str, namespace: str, params: dict[str, Any] ) -> None: - if namespace == "hp": + if namespace == self._OPTUNA_NAMESPACES[0]: # "hp" self._optuna_hp_params[pair] = params - elif namespace == "train": + elif namespace == self._OPTUNA_NAMESPACES[1]: # "train" self._optuna_train_params[pair] = params - elif namespace == "label": + elif namespace == self._OPTUNA_NAMESPACES[2]: # "label" self._optuna_label_params[pair] = params else: - raise ValueError(f"Invalid namespace: {namespace}") + raise ValueError( + f"Invalid namespace: {namespace}. " + f"Expected {', '.join(self._OPTUNA_NAMESPACES)}" + ) def get_optuna_value(self, pair: str, namespace: str) -> float: - if namespace == "hp": + if namespace == self._OPTUNA_NAMESPACES[0]: # "hp" value = self._optuna_hp_value.get(pair) - elif namespace == "train": + elif namespace == self._OPTUNA_NAMESPACES[1]: # "train" value = self._optuna_train_value.get(pair) else: - raise ValueError(f"Invalid namespace: {namespace}") + raise ValueError( + f"Invalid namespace: {namespace}. " + f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}" # Only hp and train + ) return value def set_optuna_value(self, pair: str, namespace: str, value: float) -> None: - if namespace == "hp": + if namespace == self._OPTUNA_NAMESPACES[0]: # "hp" self._optuna_hp_value[pair] = value - elif namespace == "train": + elif namespace == self._OPTUNA_NAMESPACES[1]: # "train" self._optuna_train_value[pair] = value else: - raise ValueError(f"Invalid namespace: {namespace}") + raise ValueError( + f"Invalid namespace: {namespace}. " + f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}" # Only hp and train + ) def get_optuna_values(self, pair: str, namespace: str) -> list[float | int]: - if namespace == "label": + if namespace == self._OPTUNA_NAMESPACES[2]: # "label" values = self._optuna_label_values.get(pair) else: - raise ValueError(f"Invalid namespace: {namespace}") + raise ValueError( + f"Invalid namespace: {namespace}. " + f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label + ) return values def set_optuna_values( self, pair: str, namespace: str, values: list[float | int] ) -> None: - if namespace == "label": + if namespace == self._OPTUNA_NAMESPACES[2]: # "label" self._optuna_label_values[pair] = values else: - raise ValueError(f"Invalid namespace: {namespace}") + raise ValueError( + f"Invalid namespace: {namespace}. " + f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label + ) def init_optuna_label_candle_pool(self) -> None: optuna_label_candle_pool_full = self._optuna_label_candle_pool_full @@ -364,7 +405,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): if self._optuna_hyperopt: self.optuna_optimize( pair=dk.pair, - namespace="hp", + namespace=self._OPTUNA_NAMESPACES[0], # "hp" objective=lambda trial: hp_objective( trial, str(self.freqai_info.get("regressor", "xgboost")), @@ -374,7 +415,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): X_test, y_test, test_weights, - self.get_optuna_params(dk.pair, "hp"), + self.get_optuna_params(dk.pair, self._OPTUNA_NAMESPACES[0]), # "hp" model_training_parameters, self._optuna_config.get("space_reduction"), self._optuna_config.get("expansion_ratio"), @@ -382,7 +423,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): direction=optuna.study.StudyDirection.MINIMIZE, ) - optuna_hp_params = self.get_optuna_params(dk.pair, "hp") + optuna_hp_params = self.get_optuna_params( + dk.pair, self._OPTUNA_NAMESPACES[0] + ) # "hp" if optuna_hp_params: model_training_parameters = { **model_training_parameters, @@ -391,7 +434,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): train_study = self.optuna_optimize( pair=dk.pair, - namespace="train", + namespace=self._OPTUNA_NAMESPACES[1], # "train" objective=lambda trial: train_objective( trial, str(self.freqai_info.get("regressor", "xgboost")), @@ -409,12 +452,20 @@ class QuickAdapterRegressorV3(BaseRegressionModel): direction=optuna.study.StudyDirection.MINIMIZE, ) - optuna_hp_value = self.get_optuna_value(dk.pair, "hp") - optuna_train_params = self.get_optuna_params(dk.pair, "train") - optuna_train_value = self.get_optuna_value(dk.pair, "train") + optuna_hp_value = self.get_optuna_value( + dk.pair, self._OPTUNA_NAMESPACES[0] + ) # "hp" + optuna_train_params = self.get_optuna_params( + dk.pair, self._OPTUNA_NAMESPACES[1] + ) # "train" + optuna_train_value = self.get_optuna_value( + dk.pair, self._OPTUNA_NAMESPACES[1] + ) # "train" if ( optuna_train_params - and self.optuna_validate_params(dk.pair, "train", train_study) + and self.optuna_validate_params( + dk.pair, self._OPTUNA_NAMESPACES[1], train_study + ) # "train" and optuna_train_value < optuna_hp_value ): train_period_candles = optuna_train_params.get("train_period_candles") @@ -461,8 +512,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel): namespace: str, callback: Callable[[], None], ) -> None: - if namespace != "label": - raise ValueError(f"Invalid namespace: {namespace}") + if namespace not in {self._OPTUNA_NAMESPACES[2]}: # Only "label" + raise ValueError( + f"Invalid namespace: {namespace}. " + f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label + ) if not callable(callback): raise ValueError("callback must be callable") self._optuna_label_candles[pair] += 1 @@ -499,10 +553,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel): if self._optuna_hyperopt: self.optuna_throttle_callback( pair=pair, - namespace="label", + namespace=self._OPTUNA_NAMESPACES[2], # "label" callback=lambda: self.optuna_optimize( pair=pair, - namespace="label", + namespace=self._OPTUNA_NAMESPACES[2], # "label" objective=lambda trial: label_objective( trial, self.data_provider.get_pair_dataframe( @@ -555,7 +609,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): min_pred, max_pred = self.min_max_pred( pred_df, fit_live_predictions_candles, - self.get_optuna_params(pair, "label").get("label_period_candles"), + self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get( + "label_period_candles" + ), # "label" ) dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = min_pred dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = max_pred @@ -593,17 +649,24 @@ class QuickAdapterRegressorV3(BaseRegressionModel): dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff dk.data["extra_returns_per_train"]["label_period_candles"] = ( - self.get_optuna_params(pair, "label").get("label_period_candles") + self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get( + "label_period_candles" + ) # "label" ) dk.data["extra_returns_per_train"]["label_natr_ratio"] = self.get_optuna_params( - pair, "label" + pair, + self._OPTUNA_NAMESPACES[2], # "label" ).get("label_natr_ratio") - hp_rmse = self.optuna_validate_value(self.get_optuna_value(pair, "hp")) + hp_rmse = self.optuna_validate_value( + self.get_optuna_value(pair, self._OPTUNA_NAMESPACES[0]) + ) # "hp" dk.data["extra_returns_per_train"]["hp_rmse"] = ( hp_rmse if hp_rmse is not None else np.inf ) - train_rmse = self.optuna_validate_value(self.get_optuna_value(pair, "train")) + train_rmse = self.optuna_validate_value( + self.get_optuna_value(pair, self._OPTUNA_NAMESPACES[1]) + ) # "train" dk.data["extra_returns_per_train"]["train_rmse"] = ( train_rmse if (train_rmse is not None and hp_rmse is not None and train_rmse < hp_rmse) @@ -643,6 +706,19 @@ class QuickAdapterRegressorV3(BaseRegressionModel): thresholds_candles = max(2, int(label_period_cycles)) * label_period_candles pred_extrema = pred_df.get(EXTREMA_COLUMN).iloc[-thresholds_candles:].copy() + + extrema_selection = str( + self.freqai_info.get( + "prediction_extrema_selection", + self._EXTREMA_SELECTION_METHODS[0], + ) + ) + if extrema_selection not in self._extrema_selection_methods_set(): + raise ValueError( + f"Unsupported extrema selection method: {extrema_selection}. " + f"Supported methods are {', '.join(self._EXTREMA_SELECTION_METHODS)}" + ) + extrema_selection: ExtremaSelectionMethod = extrema_selection # type: ignore[assignment] thresholds_smoothing = str( self.freqai_info.get("prediction_thresholds_smoothing", "mean") ) @@ -663,11 +739,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel): self.freqai_info.get("prediction_thresholds_alpha", 12.0) ) return QuickAdapterRegressorV3.soft_extremum_min_max( - pred_extrema, thresholds_alpha + pred_extrema, thresholds_alpha, extrema_selection ) elif thresholds_smoothing in skimage_thresholds_smoothing_methods: return QuickAdapterRegressorV3.skimage_min_max( - pred_extrema, thresholds_smoothing + pred_extrema, thresholds_smoothing, extrema_selection ) else: raise ValueError( @@ -675,7 +751,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) @staticmethod - def get_pred_min_max(pred_extrema: pd.Series) -> tuple[pd.Series, pd.Series]: + def get_pred_min_max( + pred_extrema: pd.Series, + extrema_selection: ExtremaSelectionMethod, + ) -> tuple[pd.Series, pd.Series]: pred_extrema = ( pd.to_numeric(pred_extrema, errors="coerce") .where(np.isfinite, np.nan) @@ -683,13 +762,41 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) if pred_extrema.empty: return pd.Series(dtype=float), pd.Series(dtype=float) - n_pred_minima = max(1, sp.signal.find_peaks(-pred_extrema)[0].size) - n_pred_maxima = max(1, sp.signal.find_peaks(pred_extrema)[0].size) - sorted_pred_extrema = pred_extrema.sort_values(ascending=True) - return sorted_pred_extrema.iloc[:n_pred_minima], sorted_pred_extrema.iloc[ - -n_pred_maxima: - ] + minima_indices = sp.signal.find_peaks(-pred_extrema)[0] + maxima_indices = sp.signal.find_peaks(pred_extrema)[0] + + if extrema_selection == QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS[0]: + pred_minima = ( + pred_extrema.iloc[minima_indices] + if minima_indices.size > 0 + else pd.Series(dtype=float) + ) + pred_maxima = ( + pred_extrema.iloc[maxima_indices] + if maxima_indices.size > 0 + else pd.Series(dtype=float) + ) + elif extrema_selection == QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS[1]: + n_minima = minima_indices.size + n_maxima = maxima_indices.size + + if n_minima > 0: + pred_minima = pred_extrema.nsmallest(n_minima) + else: + pred_minima = pd.Series(dtype=float) + + if n_maxima > 0: + pred_maxima = pred_extrema.nlargest(n_maxima) + else: + pred_maxima = pd.Series(dtype=float) + else: + raise ValueError( + f"Unsupported extrema selection method: {extrema_selection}. " + f"Supported methods are {', '.join(QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS)}" + ) + + return pred_minima, pred_maxima @staticmethod def safe_min_pred(pred_extrema: pd.Series) -> float: @@ -721,12 +828,14 @@ class QuickAdapterRegressorV3(BaseRegressionModel): @staticmethod def soft_extremum_min_max( - pred_extrema: pd.Series, alpha: float + pred_extrema: pd.Series, + alpha: float, + extrema_selection: ExtremaSelectionMethod, ) -> tuple[float, float]: if alpha < 0: raise ValueError("alpha must be non-negative") pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max( - pred_extrema + pred_extrema, extrema_selection ) soft_minimum = soft_extremum(pred_minima, alpha=-alpha) if not np.isfinite(soft_minimum): @@ -737,9 +846,13 @@ class QuickAdapterRegressorV3(BaseRegressionModel): return soft_minimum, soft_maximum @staticmethod - def skimage_min_max(pred_extrema: pd.Series, method: str) -> tuple[float, float]: + def skimage_min_max( + pred_extrema: pd.Series, + method: str, + extrema_selection: ExtremaSelectionMethod, + ) -> tuple[float, float]: pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max( - pred_extrema + pred_extrema, extrema_selection ) method_functions = { @@ -882,8 +995,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel): def get_multi_objective_study_best_trial( self, namespace: str, study: optuna.study.Study ) -> Optional[optuna.trial.FrozenTrial]: - if namespace != "label": - raise ValueError(f"Invalid namespace: {namespace}") + if namespace not in {self._OPTUNA_NAMESPACES[2]}: # Only "label" + raise ValueError( + f"Invalid namespace: {namespace}. " + f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label + ) n_objectives = len(study.directions) if n_objectives < 2: raise ValueError( @@ -1518,7 +1634,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): storage_dir = self.full_path storage_filename = f"optuna-{pair.split('/')[0]}" storage_backend = self._optuna_config.get("storage") - if storage_backend == "sqlite": + if storage_backend == self._OPTUNA_STORAGE_BACKENDS[0]: # "sqlite" storage = optuna.storages.RDBStorage( url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite", heartbeat_interval=60, @@ -1526,7 +1642,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): max_retry=3 ), ) - elif storage_backend == "file": + elif storage_backend == self._OPTUNA_STORAGE_BACKENDS[1]: # "file" storage = optuna.storages.JournalStorage( optuna.storages.journal.JournalFileBackend( f"{storage_dir}/{storage_filename}.log" @@ -1534,7 +1650,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) else: raise ValueError( - f"Unsupported optuna storage backend: {storage_backend}. Supported backends are 'sqlite' and 'file'" + f"Unsupported optuna storage backend: {storage_backend}. " + f"Supported backends are {', '.join(self._OPTUNA_STORAGE_BACKENDS)}" ) return storage @@ -1549,12 +1666,12 @@ class QuickAdapterRegressorV3(BaseRegressionModel): return optuna.pruners.NopPruner() def optuna_create_sampler(self) -> optuna.samplers.BaseSampler: - sampler = self._optuna_config.get("sampler", "tpe") - if sampler == "auto": + sampler = self._optuna_config.get("sampler", self._OPTUNA_SAMPLERS[0]) + if sampler == self._OPTUNA_SAMPLERS[1]: # "auto" return optunahub.load_module("samplers/auto_sampler").AutoSampler( seed=self._optuna_config.get("seed") ) - elif sampler == "tpe": + elif sampler == self._OPTUNA_SAMPLERS[0]: # "tpe" return optuna.samplers.TPESampler( n_startup_trials=self._optuna_config.get("n_startup_trials"), multivariate=True, @@ -1563,7 +1680,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) else: raise ValueError( - f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'" + f"Unsupported sampler: {sampler}. " + f"Supported samplers are {', '.join(self._OPTUNA_SAMPLERS)}" ) def optuna_create_study( diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index 9e1c486..8f53359 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -54,6 +54,12 @@ EXTREMA_COLUMN = "&s-extrema" MAXIMA_THRESHOLD_COLUMN = "&s-maxima_threshold" MINIMA_THRESHOLD_COLUMN = "&s-minima_threshold" +# Type aliases +TradeDirection = Literal["long", "short"] +InterpolationDirection = Literal["direct", "inverse"] +OrderType = Literal["entry", "exit"] +TradingMode = Literal["spot", "margin", "futures"] + class QuickAdapterV3(IStrategy): """ @@ -74,6 +80,14 @@ class QuickAdapterV3(IStrategy): INTERFACE_VERSION = 3 + _TRADE_DIRECTIONS: tuple[TradeDirection, ...] = ("long", "short") + _INTERPOLATION_DIRECTIONS: tuple[InterpolationDirection, ...] = ( + "direct", + "inverse", + ) + _ORDER_TYPES: tuple[OrderType, ...] = ("entry", "exit") + _TRADING_MODES: tuple[TradingMode, ...] = ("spot", "margin", "futures") + def version(self) -> str: return "3.3.170" @@ -129,6 +143,17 @@ class QuickAdapterV3(IStrategy): def can_short(self) -> bool: return self.is_short_allowed() + @staticmethod + def _trade_directions_set() -> set[TradeDirection]: + return { + QuickAdapterV3._TRADE_DIRECTIONS[0], + QuickAdapterV3._TRADE_DIRECTIONS[1], + } + + @staticmethod + def _order_types_set() -> set[OrderType]: + return {QuickAdapterV3._ORDER_TYPES[0], QuickAdapterV3._ORDER_TYPES[1]} + @cached_property def plot_config(self) -> dict[str, Any]: return { @@ -1199,13 +1224,13 @@ class QuickAdapterV3(IStrategy): if isna(candle_label_natr_value_quantile): return np.nan - if interpolation_direction == "direct": + if interpolation_direction == self._INTERPOLATION_DIRECTIONS[0]: # "direct" natr_ratio_percent = ( min_natr_ratio_percent + (max_natr_ratio_percent - min_natr_ratio_percent) * candle_label_natr_value_quantile**quantile_exponent ) - elif interpolation_direction == "inverse": + elif interpolation_direction == self._INTERPOLATION_DIRECTIONS[1]: # "inverse" natr_ratio_percent = ( max_natr_ratio_percent - (max_natr_ratio_percent - min_natr_ratio_percent) @@ -1213,7 +1238,8 @@ class QuickAdapterV3(IStrategy): ) else: raise ValueError( - f"Invalid interpolation_direction: {interpolation_direction}. Expected 'direct' or 'inverse'" + f"Invalid interpolation_direction: {interpolation_direction}. " + f"Expected {', '.join(self._INTERPOLATION_DIRECTIONS)}" ) candle_deviation = ( candle_label_natr_value / 100.0 @@ -1253,7 +1279,7 @@ class QuickAdapterV3(IStrategy): min_natr_ratio_percent=min_natr_ratio_percent, max_natr_ratio_percent=max_natr_ratio_percent, candle_idx=candle_idx, - interpolation_direction="direct", + interpolation_direction=self._INTERPOLATION_DIRECTIONS[0], # "direct" ) if isna(current_deviation) or current_deviation <= 0: return np.nan @@ -1268,14 +1294,14 @@ class QuickAdapterV3(IStrategy): is_candle_bullish: bool = candle_close > candle_open is_candle_bearish: bool = candle_close < candle_open - if side == "long": + if side == self._TRADE_DIRECTIONS[0]: # "long" base_price = ( QuickAdapterV3.weighted_close(candle) if is_candle_bearish else candle_close ) candle_threshold = base_price * (1 + current_deviation) - elif side == "short": + elif side == self._TRADE_DIRECTIONS[1]: # "short" base_price = ( QuickAdapterV3.weighted_close(candle) if is_candle_bullish @@ -1283,7 +1309,9 @@ class QuickAdapterV3(IStrategy): ) candle_threshold = base_price * (1 - current_deviation) else: - raise ValueError(f"Invalid side: {side}. Expected 'long' or 'short'") + raise ValueError( + f"Invalid side: {side}. Expected {', '.join(self._TRADE_DIRECTIONS)}" + ) self._candle_threshold_cache[cache_key] = candle_threshold return self._candle_threshold_cache[cache_key] @@ -1367,9 +1395,9 @@ class QuickAdapterV3(IStrategy): """ if df.empty: return False - if side not in {"long", "short"}: + if side not in self._sides_set(): return False - if order not in {"entry", "exit"}: + if order not in self._order_types_set(): return False trade_direction = side @@ -1397,14 +1425,16 @@ class QuickAdapterV3(IStrategy): candle_idx=-1, ) current_ok = np.isfinite(current_threshold) and ( - (side == "long" and rate > current_threshold) - or (side == "short" and rate < current_threshold) - ) - if order == "exit": - if side == "long": - trade_direction = "short" - if side == "short": - trade_direction = "long" + (side == self._TRADE_DIRECTIONS[0] and rate > current_threshold) # "long" + or ( + side == self._TRADE_DIRECTIONS[1] and rate < current_threshold + ) # "short" + ) + if order == self._ORDER_TYPES[1]: # "exit" + if side == self._TRADE_DIRECTIONS[0]: # "long" + trade_direction = self._TRADE_DIRECTIONS[1] # "short" + if side == self._TRADE_DIRECTIONS[1]: # "short" + trade_direction = self._TRADE_DIRECTIONS[0] # "long" if not current_ok: logger.info( f"User denied {trade_direction} {order} for {pair}: rate {format_number(rate)} did not break threshold {format_number(current_threshold)}" @@ -1441,8 +1471,11 @@ class QuickAdapterV3(IStrategy): ): return current_ok - if (side == "long" and not (close_k > threshold_k)) or ( - side == "short" and not (close_k < threshold_k) + if ( + side == self._TRADE_DIRECTIONS[0] and not (close_k > threshold_k) + ) or ( # "long" + side == self._TRADE_DIRECTIONS[1] + and not (close_k < threshold_k) # "short" ): logger.info( f"User denied {trade_direction} {order} for {pair}: " @@ -1670,15 +1703,15 @@ class QuickAdapterV3(IStrategy): trade.set_custom_data("last_outlier_date", last_candle_date.isoformat()) if ( - trade.trade_direction == "short" + trade.trade_direction == self._TRADE_DIRECTIONS[1] # "short" and last_candle.get("do_predict") == 1 and last_candle.get("DI_catch") == 1 and last_candle.get(EXTREMA_COLUMN) < last_candle.get("minima_threshold") and self.reversal_confirmed( df, pair, - "long", - "exit", + self._TRADE_DIRECTIONS[0], # "long" + self._ORDER_TYPES[1], # "exit" current_rate, self._reversal_lookback_period, self._reversal_decay_ratio, @@ -1688,15 +1721,15 @@ class QuickAdapterV3(IStrategy): ): return "minima_detected_short" if ( - trade.trade_direction == "long" + trade.trade_direction == self._TRADE_DIRECTIONS[0] # "long" and last_candle.get("do_predict") == 1 and last_candle.get("DI_catch") == 1 and last_candle.get(EXTREMA_COLUMN) > last_candle.get("maxima_threshold") and self.reversal_confirmed( df, pair, - "short", - "exit", + self._TRADE_DIRECTIONS[1], # "short" + self._ORDER_TYPES[1], # "exit" current_rate, self._reversal_lookback_period, self._reversal_decay_ratio, @@ -1805,9 +1838,9 @@ class QuickAdapterV3(IStrategy): side: str, **kwargs, ) -> bool: - if side not in {"long", "short"}: + if side not in self._sides_set(): return False - if side == "short" and not self.can_short: + if side == self._TRADE_DIRECTIONS[1] and not self.can_short: # "short" logger.info(f"User denied short entry for {pair}: shorting not allowed") return False if Trade.get_open_trade_count() >= self.config.get("max_open_trades"): @@ -1831,7 +1864,7 @@ class QuickAdapterV3(IStrategy): df, pair, side, - "entry", + self._ORDER_TYPES[0], # "entry" rate, self._reversal_lookback_period, self._reversal_decay_ratio, @@ -1843,12 +1876,18 @@ class QuickAdapterV3(IStrategy): def is_short_allowed(self) -> bool: trading_mode = self.config.get("trading_mode") - if trading_mode in {"margin", "futures"}: + if trading_mode in { + self._TRADING_MODES[1], + self._TRADING_MODES[2], + }: # margin, futures return True - elif trading_mode == "spot": + elif trading_mode == self._TRADING_MODES[0]: # "spot" return False else: - raise ValueError(f"Invalid trading_mode: {trading_mode}") + raise ValueError( + f"Invalid trading_mode: {trading_mode}. " + f"Expected {', '.join(self._TRADING_MODES)}" + ) def leverage( self, -- 2.43.0