import logging
import warnings
from functools import cached_property
from pathlib import Path
from typing import Any, Callable, Literal, Optional

import numpy as np
import optuna
logger = logging.getLogger(__name__)
# Closed string sets used by the regressor, expressed as Literal aliases so
# static checkers can validate call sites.
ExtremaSelectionMethod = Literal["peak_values", "extrema_rank"]
OptunaNamespace = Literal["hp", "train", "label"]

class QuickAdapterRegressorV3(BaseRegressionModel):
"""
_SQRT_2 = np.sqrt(2.0)
+ _EXTREMA_SELECTION_METHODS: tuple[ExtremaSelectionMethod, ...] = (
+ "peak_values",
+ "extrema_rank",
+ )
+ _OPTUNA_STORAGE_BACKENDS: tuple[str, ...] = ("sqlite", "file")
+ _OPTUNA_SAMPLERS: tuple[str, ...] = ("tpe", "auto")
+ _OPTUNA_NAMESPACES: tuple[OptunaNamespace, ...] = ("hp", "train", "label")
+
+ @staticmethod
+ def _extrema_selection_methods_set() -> set[ExtremaSelectionMethod]:
+ return set(QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS)
+
+ @staticmethod
+ def _optuna_namespaces_set() -> set[OptunaNamespace]:
+ return set(QuickAdapterRegressorV3._OPTUNA_NAMESPACES)
+
@cached_property
def _optuna_config(self) -> dict[str, Any]:
optuna_default_config = {
.get("n_jobs", 1),
max(int(self.max_system_threads / 4), 1),
),
- "sampler": "tpe",
- "storage": "file",
+ "sampler": self._OPTUNA_SAMPLERS[0], # "tpe"
+ "storage": self._OPTUNA_STORAGE_BACKENDS[1], # "file"
"continuous": True,
"warm_start": True,
"n_startup_trials": 15,
self._optuna_train_value[pair] = -1
self._optuna_label_values[pair] = [-1, -1]
self._optuna_hp_params[pair] = (
- self.optuna_load_best_params(pair, "hp")
- if self.optuna_load_best_params(pair, "hp")
+ self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0]) # "hp"
+ if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0])
else {}
)
self._optuna_train_params[pair] = (
- self.optuna_load_best_params(pair, "train")
- if self.optuna_load_best_params(pair, "train")
+ self.optuna_load_best_params(
+ pair, self._OPTUNA_NAMESPACES[1]
+ ) # "train"
+ if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[1])
else {}
)
self._optuna_label_params[pair] = (
- self.optuna_load_best_params(pair, "label")
- if self.optuna_load_best_params(pair, "label")
+ self.optuna_load_best_params(
+ pair, self._OPTUNA_NAMESPACES[2]
+ ) # "label"
+ if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[2])
else {
"label_period_candles": self.ft_params.get(
"label_period_candles",
)
def get_optuna_params(self, pair: str, namespace: str) -> dict[str, Any]:
- if namespace == "hp":
+ if namespace == self._OPTUNA_NAMESPACES[0]: # "hp"
params = self._optuna_hp_params.get(pair)
- elif namespace == "train":
+ elif namespace == self._OPTUNA_NAMESPACES[1]: # "train"
params = self._optuna_train_params.get(pair)
- elif namespace == "label":
+ elif namespace == self._OPTUNA_NAMESPACES[2]: # "label"
params = self._optuna_label_params.get(pair)
else:
- raise ValueError(f"Invalid namespace: {namespace}")
+ raise ValueError(
+ f"Invalid namespace: {namespace}. "
+ f"Expected {', '.join(self._OPTUNA_NAMESPACES)}"
+ )
return params
def set_optuna_params(
self, pair: str, namespace: str, params: dict[str, Any]
) -> None:
- if namespace == "hp":
+ if namespace == self._OPTUNA_NAMESPACES[0]: # "hp"
self._optuna_hp_params[pair] = params
- elif namespace == "train":
+ elif namespace == self._OPTUNA_NAMESPACES[1]: # "train"
self._optuna_train_params[pair] = params
- elif namespace == "label":
+ elif namespace == self._OPTUNA_NAMESPACES[2]: # "label"
self._optuna_label_params[pair] = params
else:
- raise ValueError(f"Invalid namespace: {namespace}")
+ raise ValueError(
+ f"Invalid namespace: {namespace}. "
+ f"Expected {', '.join(self._OPTUNA_NAMESPACES)}"
+ )
def get_optuna_value(self, pair: str, namespace: str) -> float:
- if namespace == "hp":
+ if namespace == self._OPTUNA_NAMESPACES[0]: # "hp"
value = self._optuna_hp_value.get(pair)
- elif namespace == "train":
+ elif namespace == self._OPTUNA_NAMESPACES[1]: # "train"
value = self._optuna_train_value.get(pair)
else:
- raise ValueError(f"Invalid namespace: {namespace}")
+ raise ValueError(
+ f"Invalid namespace: {namespace}. "
+ f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}" # Only hp and train
+ )
return value
def set_optuna_value(self, pair: str, namespace: str, value: float) -> None:
- if namespace == "hp":
+ if namespace == self._OPTUNA_NAMESPACES[0]: # "hp"
self._optuna_hp_value[pair] = value
- elif namespace == "train":
+ elif namespace == self._OPTUNA_NAMESPACES[1]: # "train"
self._optuna_train_value[pair] = value
else:
- raise ValueError(f"Invalid namespace: {namespace}")
+ raise ValueError(
+ f"Invalid namespace: {namespace}. "
+ f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}" # Only hp and train
+ )
def get_optuna_values(self, pair: str, namespace: str) -> list[float | int]:
- if namespace == "label":
+ if namespace == self._OPTUNA_NAMESPACES[2]: # "label"
values = self._optuna_label_values.get(pair)
else:
- raise ValueError(f"Invalid namespace: {namespace}")
+ raise ValueError(
+ f"Invalid namespace: {namespace}. "
+ f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label
+ )
return values
def set_optuna_values(
self, pair: str, namespace: str, values: list[float | int]
) -> None:
- if namespace == "label":
+ if namespace == self._OPTUNA_NAMESPACES[2]: # "label"
self._optuna_label_values[pair] = values
else:
- raise ValueError(f"Invalid namespace: {namespace}")
+ raise ValueError(
+ f"Invalid namespace: {namespace}. "
+ f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label
+ )
def init_optuna_label_candle_pool(self) -> None:
optuna_label_candle_pool_full = self._optuna_label_candle_pool_full
if self._optuna_hyperopt:
self.optuna_optimize(
pair=dk.pair,
- namespace="hp",
+ namespace=self._OPTUNA_NAMESPACES[0], # "hp"
objective=lambda trial: hp_objective(
trial,
str(self.freqai_info.get("regressor", "xgboost")),
X_test,
y_test,
test_weights,
- self.get_optuna_params(dk.pair, "hp"),
+ self.get_optuna_params(dk.pair, self._OPTUNA_NAMESPACES[0]), # "hp"
model_training_parameters,
self._optuna_config.get("space_reduction"),
self._optuna_config.get("expansion_ratio"),
direction=optuna.study.StudyDirection.MINIMIZE,
)
- optuna_hp_params = self.get_optuna_params(dk.pair, "hp")
+ optuna_hp_params = self.get_optuna_params(
+ dk.pair, self._OPTUNA_NAMESPACES[0]
+ ) # "hp"
if optuna_hp_params:
model_training_parameters = {
**model_training_parameters,
train_study = self.optuna_optimize(
pair=dk.pair,
- namespace="train",
+ namespace=self._OPTUNA_NAMESPACES[1], # "train"
objective=lambda trial: train_objective(
trial,
str(self.freqai_info.get("regressor", "xgboost")),
direction=optuna.study.StudyDirection.MINIMIZE,
)
- optuna_hp_value = self.get_optuna_value(dk.pair, "hp")
- optuna_train_params = self.get_optuna_params(dk.pair, "train")
- optuna_train_value = self.get_optuna_value(dk.pair, "train")
+ optuna_hp_value = self.get_optuna_value(
+ dk.pair, self._OPTUNA_NAMESPACES[0]
+ ) # "hp"
+ optuna_train_params = self.get_optuna_params(
+ dk.pair, self._OPTUNA_NAMESPACES[1]
+ ) # "train"
+ optuna_train_value = self.get_optuna_value(
+ dk.pair, self._OPTUNA_NAMESPACES[1]
+ ) # "train"
if (
optuna_train_params
- and self.optuna_validate_params(dk.pair, "train", train_study)
+ and self.optuna_validate_params(
+ dk.pair, self._OPTUNA_NAMESPACES[1], train_study
+ ) # "train"
and optuna_train_value < optuna_hp_value
):
train_period_candles = optuna_train_params.get("train_period_candles")
namespace: str,
callback: Callable[[], None],
) -> None:
- if namespace != "label":
- raise ValueError(f"Invalid namespace: {namespace}")
+ if namespace not in {self._OPTUNA_NAMESPACES[2]}: # Only "label"
+ raise ValueError(
+ f"Invalid namespace: {namespace}. "
+ f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label
+ )
if not callable(callback):
raise ValueError("callback must be callable")
self._optuna_label_candles[pair] += 1
if self._optuna_hyperopt:
self.optuna_throttle_callback(
pair=pair,
- namespace="label",
+ namespace=self._OPTUNA_NAMESPACES[2], # "label"
callback=lambda: self.optuna_optimize(
pair=pair,
- namespace="label",
+ namespace=self._OPTUNA_NAMESPACES[2], # "label"
objective=lambda trial: label_objective(
trial,
self.data_provider.get_pair_dataframe(
min_pred, max_pred = self.min_max_pred(
pred_df,
fit_live_predictions_candles,
- self.get_optuna_params(pair, "label").get("label_period_candles"),
+ self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get(
+ "label_period_candles"
+ ), # "label"
)
dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = min_pred
dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = max_pred
dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff
dk.data["extra_returns_per_train"]["label_period_candles"] = (
- self.get_optuna_params(pair, "label").get("label_period_candles")
+ self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get(
+ "label_period_candles"
+ ) # "label"
)
dk.data["extra_returns_per_train"]["label_natr_ratio"] = self.get_optuna_params(
- pair, "label"
+ pair,
+ self._OPTUNA_NAMESPACES[2], # "label"
).get("label_natr_ratio")
- hp_rmse = self.optuna_validate_value(self.get_optuna_value(pair, "hp"))
+ hp_rmse = self.optuna_validate_value(
+ self.get_optuna_value(pair, self._OPTUNA_NAMESPACES[0])
+ ) # "hp"
dk.data["extra_returns_per_train"]["hp_rmse"] = (
hp_rmse if hp_rmse is not None else np.inf
)
- train_rmse = self.optuna_validate_value(self.get_optuna_value(pair, "train"))
+ train_rmse = self.optuna_validate_value(
+ self.get_optuna_value(pair, self._OPTUNA_NAMESPACES[1])
+ ) # "train"
dk.data["extra_returns_per_train"]["train_rmse"] = (
train_rmse
if (train_rmse is not None and hp_rmse is not None and train_rmse < hp_rmse)
thresholds_candles = max(2, int(label_period_cycles)) * label_period_candles
pred_extrema = pred_df.get(EXTREMA_COLUMN).iloc[-thresholds_candles:].copy()
+
+ extrema_selection = str(
+ self.freqai_info.get(
+ "prediction_extrema_selection",
+ self._EXTREMA_SELECTION_METHODS[0],
+ )
+ )
+ if extrema_selection not in self._extrema_selection_methods_set():
+ raise ValueError(
+ f"Unsupported extrema selection method: {extrema_selection}. "
+ f"Supported methods are {', '.join(self._EXTREMA_SELECTION_METHODS)}"
+ )
+ extrema_selection: ExtremaSelectionMethod = extrema_selection # type: ignore[assignment]
thresholds_smoothing = str(
self.freqai_info.get("prediction_thresholds_smoothing", "mean")
)
self.freqai_info.get("prediction_thresholds_alpha", 12.0)
)
return QuickAdapterRegressorV3.soft_extremum_min_max(
- pred_extrema, thresholds_alpha
+ pred_extrema, thresholds_alpha, extrema_selection
)
elif thresholds_smoothing in skimage_thresholds_smoothing_methods:
return QuickAdapterRegressorV3.skimage_min_max(
- pred_extrema, thresholds_smoothing
+ pred_extrema, thresholds_smoothing, extrema_selection
)
else:
raise ValueError(
)
@staticmethod
# Splits predicted extrema into candidate minima/maxima series.
# The patch adds an `extrema_selection` parameter choosing between two
# strategies defined in _EXTREMA_SELECTION_METHODS:
#   [0] "peak_values"  — take the values at detected peak positions;
#   [1] "extrema_rank" — take the k smallest/largest values overall,
#       where k is the number of detected peaks.
- def get_pred_min_max(pred_extrema: pd.Series) -> tuple[pd.Series, pd.Series]:
+ def get_pred_min_max(
+ pred_extrema: pd.Series,
+ extrema_selection: ExtremaSelectionMethod,
+ ) -> tuple[pd.Series, pd.Series]:
# Coerce to numeric and mask non-finite entries (inf -> NaN).
# NOTE(review): no explicit `.dropna()` is visible here — confirm against the
# full source whether NaN entries are meant to reach find_peaks below.
pred_extrema = (
pd.to_numeric(pred_extrema, errors="coerce")
.where(np.isfinite, np.nan)
)
if pred_extrema.empty:
return pd.Series(dtype=float), pd.Series(dtype=float)
# Old behavior (removed): always returned at least one minimum and one
# maximum via max(1, ...) over the sorted series.
- n_pred_minima = max(1, sp.signal.find_peaks(-pred_extrema)[0].size)
- n_pred_maxima = max(1, sp.signal.find_peaks(pred_extrema)[0].size)
- sorted_pred_extrema = pred_extrema.sort_values(ascending=True)
- return sorted_pred_extrema.iloc[:n_pred_minima], sorted_pred_extrema.iloc[
- -n_pred_maxima:
- ]
# New behavior: positional peak indices from scipy; minima are peaks of the
# negated series.
+ minima_indices = sp.signal.find_peaks(-pred_extrema)[0]
+ maxima_indices = sp.signal.find_peaks(pred_extrema)[0]
+
+ if extrema_selection == QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS[0]:
# "peak_values": values at the peak positions themselves.
# NOTE(review): unlike the old code, this can return an EMPTY series when no
# peaks are found (the max(1, ...) floor is gone) — confirm downstream
# consumers (soft_extremum_min_max, skimage_min_max) tolerate empty input.
+ pred_minima = (
+ pred_extrema.iloc[minima_indices]
+ if minima_indices.size > 0
+ else pd.Series(dtype=float)
+ )
+ pred_maxima = (
+ pred_extrema.iloc[maxima_indices]
+ if maxima_indices.size > 0
+ else pd.Series(dtype=float)
+ )
+ elif extrema_selection == QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS[1]:
# "extrema_rank": rank-based selection — the peak COUNT sets how many of the
# globally smallest/largest values are taken (nsmallest/nlargest ignore NaN).
+ n_minima = minima_indices.size
+ n_maxima = maxima_indices.size
+
+ if n_minima > 0:
+ pred_minima = pred_extrema.nsmallest(n_minima)
+ else:
+ pred_minima = pd.Series(dtype=float)
+
+ if n_maxima > 0:
+ pred_maxima = pred_extrema.nlargest(n_maxima)
+ else:
+ pred_maxima = pd.Series(dtype=float)
+ else:
# Defensive re-validation; callers already validate via
# _extrema_selection_methods_set().
+ raise ValueError(
+ f"Unsupported extrema selection method: {extrema_selection}. "
+ f"Supported methods are {', '.join(QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS)}"
+ )
+
+ return pred_minima, pred_maxima
@staticmethod
def safe_min_pred(pred_extrema: pd.Series) -> float:
@staticmethod
def soft_extremum_min_max(
- pred_extrema: pd.Series, alpha: float
+ pred_extrema: pd.Series,
+ alpha: float,
+ extrema_selection: ExtremaSelectionMethod,
) -> tuple[float, float]:
if alpha < 0:
raise ValueError("alpha must be non-negative")
pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max(
- pred_extrema
+ pred_extrema, extrema_selection
)
soft_minimum = soft_extremum(pred_minima, alpha=-alpha)
if not np.isfinite(soft_minimum):
return soft_minimum, soft_maximum
@staticmethod
- def skimage_min_max(pred_extrema: pd.Series, method: str) -> tuple[float, float]:
+ def skimage_min_max(
+ pred_extrema: pd.Series,
+ method: str,
+ extrema_selection: ExtremaSelectionMethod,
+ ) -> tuple[float, float]:
pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max(
- pred_extrema
+ pred_extrema, extrema_selection
)
method_functions = {
def get_multi_objective_study_best_trial(
self, namespace: str, study: optuna.study.Study
) -> Optional[optuna.trial.FrozenTrial]:
- if namespace != "label":
- raise ValueError(f"Invalid namespace: {namespace}")
+ if namespace not in {self._OPTUNA_NAMESPACES[2]}: # Only "label"
+ raise ValueError(
+ f"Invalid namespace: {namespace}. "
+ f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label
+ )
n_objectives = len(study.directions)
if n_objectives < 2:
raise ValueError(
storage_dir = self.full_path
storage_filename = f"optuna-{pair.split('/')[0]}"
storage_backend = self._optuna_config.get("storage")
- if storage_backend == "sqlite":
+ if storage_backend == self._OPTUNA_STORAGE_BACKENDS[0]: # "sqlite"
storage = optuna.storages.RDBStorage(
url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite",
heartbeat_interval=60,
max_retry=3
),
)
- elif storage_backend == "file":
+ elif storage_backend == self._OPTUNA_STORAGE_BACKENDS[1]: # "file"
storage = optuna.storages.JournalStorage(
optuna.storages.journal.JournalFileBackend(
f"{storage_dir}/{storage_filename}.log"
)
else:
raise ValueError(
- f"Unsupported optuna storage backend: {storage_backend}. Supported backends are 'sqlite' and 'file'"
+ f"Unsupported optuna storage backend: {storage_backend}. "
+ f"Supported backends are {', '.join(self._OPTUNA_STORAGE_BACKENDS)}"
)
return storage
return optuna.pruners.NopPruner()
def optuna_create_sampler(self) -> optuna.samplers.BaseSampler:
# Builds the optuna sampler selected by the "sampler" config key.
# _OPTUNA_SAMPLERS is ("tpe", "auto"); index 0 ("tpe") is the default.
- sampler = self._optuna_config.get("sampler", "tpe")
- if sampler == "auto":
+ sampler = self._optuna_config.get("sampler", self._OPTUNA_SAMPLERS[0])
+ if sampler == self._OPTUNA_SAMPLERS[1]:  # "auto"
# "auto" delegates sampler choice to optunahub's AutoSampler.
return optunahub.load_module("samplers/auto_sampler").AutoSampler(
seed=self._optuna_config.get("seed")
)
- elif sampler == "tpe":
+ elif sampler == self._OPTUNA_SAMPLERS[0]:  # "tpe"
# NOTE(review): the TPESampler argument list may be truncated in this
# excerpt (e.g. a seed argument) — confirm against the full source.
return optuna.samplers.TPESampler(
n_startup_trials=self._optuna_config.get("n_startup_trials"),
multivariate=True,
)
else:
raise ValueError(
- f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'"
+ f"Unsupported sampler: {sampler}. "
+ f"Supported samplers are {', '.join(self._OPTUNA_SAMPLERS)}"
)
def optuna_create_study(
MAXIMA_THRESHOLD_COLUMN = "&s-maxima_threshold"
MINIMA_THRESHOLD_COLUMN = "&s-minima_threshold"

# Type aliases for the closed string sets used by the strategy.
TradeDirection = Literal["long", "short"]
InterpolationDirection = Literal["direct", "inverse"]
OrderType = Literal["entry", "exit"]
TradingMode = Literal["spot", "margin", "futures"]

class QuickAdapterV3(IStrategy):
"""
INTERFACE_VERSION = 3
+ _TRADE_DIRECTIONS: tuple[TradeDirection, ...] = ("long", "short")
+ _INTERPOLATION_DIRECTIONS: tuple[InterpolationDirection, ...] = (
+ "direct",
+ "inverse",
+ )
+ _ORDER_TYPES: tuple[OrderType, ...] = ("entry", "exit")
+ _TRADING_MODES: tuple[TradingMode, ...] = ("spot", "margin", "futures")
+
def version(self) -> str:
return "3.3.170"
def can_short(self) -> bool:
return self.is_short_allowed()
+ @staticmethod
+ def _trade_directions_set() -> set[TradeDirection]:
+ return {
+ QuickAdapterV3._TRADE_DIRECTIONS[0],
+ QuickAdapterV3._TRADE_DIRECTIONS[1],
+ }
+
+ @staticmethod
+ def _order_types_set() -> set[OrderType]:
+ return {QuickAdapterV3._ORDER_TYPES[0], QuickAdapterV3._ORDER_TYPES[1]}
+
@cached_property
def plot_config(self) -> dict[str, Any]:
return {
if isna(candle_label_natr_value_quantile):
return np.nan
- if interpolation_direction == "direct":
+ if interpolation_direction == self._INTERPOLATION_DIRECTIONS[0]: # "direct"
natr_ratio_percent = (
min_natr_ratio_percent
+ (max_natr_ratio_percent - min_natr_ratio_percent)
* candle_label_natr_value_quantile**quantile_exponent
)
- elif interpolation_direction == "inverse":
+ elif interpolation_direction == self._INTERPOLATION_DIRECTIONS[1]: # "inverse"
natr_ratio_percent = (
max_natr_ratio_percent
- (max_natr_ratio_percent - min_natr_ratio_percent)
)
else:
raise ValueError(
- f"Invalid interpolation_direction: {interpolation_direction}. Expected 'direct' or 'inverse'"
+ f"Invalid interpolation_direction: {interpolation_direction}. "
+ f"Expected {', '.join(self._INTERPOLATION_DIRECTIONS)}"
)
candle_deviation = (
candle_label_natr_value / 100.0
min_natr_ratio_percent=min_natr_ratio_percent,
max_natr_ratio_percent=max_natr_ratio_percent,
candle_idx=candle_idx,
- interpolation_direction="direct",
+ interpolation_direction=self._INTERPOLATION_DIRECTIONS[0], # "direct"
)
if isna(current_deviation) or current_deviation <= 0:
return np.nan
is_candle_bullish: bool = candle_close > candle_open
is_candle_bearish: bool = candle_close < candle_open
- if side == "long":
+ if side == self._TRADE_DIRECTIONS[0]: # "long"
base_price = (
QuickAdapterV3.weighted_close(candle)
if is_candle_bearish
else candle_close
)
candle_threshold = base_price * (1 + current_deviation)
- elif side == "short":
+ elif side == self._TRADE_DIRECTIONS[1]: # "short"
base_price = (
QuickAdapterV3.weighted_close(candle)
if is_candle_bullish
)
candle_threshold = base_price * (1 - current_deviation)
else:
- raise ValueError(f"Invalid side: {side}. Expected 'long' or 'short'")
+ raise ValueError(
+ f"Invalid side: {side}. Expected {', '.join(self._TRADE_DIRECTIONS)}"
+ )
self._candle_threshold_cache[cache_key] = candle_threshold
return self._candle_threshold_cache[cache_key]
"""
if df.empty:
return False
- if side not in {"long", "short"}:
+ if side not in self._sides_set():
return False
- if order not in {"entry", "exit"}:
+ if order not in self._order_types_set():
return False
trade_direction = side
candle_idx=-1,
)
current_ok = np.isfinite(current_threshold) and (
- (side == "long" and rate > current_threshold)
- or (side == "short" and rate < current_threshold)
- )
- if order == "exit":
- if side == "long":
- trade_direction = "short"
- if side == "short":
- trade_direction = "long"
+ (side == self._TRADE_DIRECTIONS[0] and rate > current_threshold) # "long"
+ or (
+ side == self._TRADE_DIRECTIONS[1] and rate < current_threshold
+ ) # "short"
+ )
+ if order == self._ORDER_TYPES[1]: # "exit"
+ if side == self._TRADE_DIRECTIONS[0]: # "long"
+ trade_direction = self._TRADE_DIRECTIONS[1] # "short"
+ if side == self._TRADE_DIRECTIONS[1]: # "short"
+ trade_direction = self._TRADE_DIRECTIONS[0] # "long"
if not current_ok:
logger.info(
f"User denied {trade_direction} {order} for {pair}: rate {format_number(rate)} did not break threshold {format_number(current_threshold)}"
):
return current_ok
- if (side == "long" and not (close_k > threshold_k)) or (
- side == "short" and not (close_k < threshold_k)
+ if (
+ side == self._TRADE_DIRECTIONS[0] and not (close_k > threshold_k)
+ ) or ( # "long"
+ side == self._TRADE_DIRECTIONS[1]
+ and not (close_k < threshold_k) # "short"
):
logger.info(
f"User denied {trade_direction} {order} for {pair}: "
trade.set_custom_data("last_outlier_date", last_candle_date.isoformat())
if (
- trade.trade_direction == "short"
+ trade.trade_direction == self._TRADE_DIRECTIONS[1] # "short"
and last_candle.get("do_predict") == 1
and last_candle.get("DI_catch") == 1
and last_candle.get(EXTREMA_COLUMN) < last_candle.get("minima_threshold")
and self.reversal_confirmed(
df,
pair,
- "long",
- "exit",
+ self._TRADE_DIRECTIONS[0], # "long"
+ self._ORDER_TYPES[1], # "exit"
current_rate,
self._reversal_lookback_period,
self._reversal_decay_ratio,
):
return "minima_detected_short"
if (
- trade.trade_direction == "long"
+ trade.trade_direction == self._TRADE_DIRECTIONS[0] # "long"
and last_candle.get("do_predict") == 1
and last_candle.get("DI_catch") == 1
and last_candle.get(EXTREMA_COLUMN) > last_candle.get("maxima_threshold")
and self.reversal_confirmed(
df,
pair,
- "short",
- "exit",
+ self._TRADE_DIRECTIONS[1], # "short"
+ self._ORDER_TYPES[1], # "exit"
current_rate,
self._reversal_lookback_period,
self._reversal_decay_ratio,
side: str,
**kwargs,
) -> bool:
- if side not in {"long", "short"}:
+ if side not in self._sides_set():
return False
- if side == "short" and not self.can_short:
+ if side == self._TRADE_DIRECTIONS[1] and not self.can_short: # "short"
logger.info(f"User denied short entry for {pair}: shorting not allowed")
return False
if Trade.get_open_trade_count() >= self.config.get("max_open_trades"):
df,
pair,
side,
- "entry",
+ self._ORDER_TYPES[0], # "entry"
rate,
self._reversal_lookback_period,
self._reversal_decay_ratio,
def is_short_allowed(self) -> bool:
trading_mode = self.config.get("trading_mode")
- if trading_mode in {"margin", "futures"}:
+ if trading_mode in {
+ self._TRADING_MODES[1],
+ self._TRADING_MODES[2],
+ }: # margin, futures
return True
- elif trading_mode == "spot":
+ elif trading_mode == self._TRADING_MODES[0]: # "spot"
return False
else:
- raise ValueError(f"Invalid trading_mode: {trading_mode}")
+ raise ValueError(
+ f"Invalid trading_mode: {trading_mode}. "
+ f"Expected {', '.join(self._TRADING_MODES)}"
+ )
def leverage(
self,