From: Jérôme Benoit Date: Wed, 19 Nov 2025 13:07:26 +0000 (+0100) Subject: refactor: cleanup access to constants properties X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=b962f5a482877f834290ce39582f0e201558c112;p=freqai-strategies.git refactor: cleanup access to constants properties Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 6a1edbd..da734e3 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -74,7 +74,7 @@ from stable_baselines3.common.vec_env import ( ModelType = Literal["PPO", "RecurrentPPO", "MaskablePPO", "DQN", "QRDQN"] ScheduleType = Literal["linear", "constant", "unknown"] -ScheduleTypeKnown = Literal["linear", "constant"] # Subset for get_schedule() function +ScheduleTypeKnown = Literal["linear", "constant"] ExitPotentialMode = Literal[ "canonical", "non_canonical", @@ -217,11 +217,13 @@ class ReforceXY(BaseReinforcementLearningModel): "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration" ) self.action_masking: bool = ( - self.model_type == self._MODEL_TYPES[2] + self.model_type == ReforceXY._MODEL_TYPES[2] ) # "MaskablePPO" self.rl_config.setdefault("action_masking", self.action_masking) self.inference_masking: bool = self.rl_config.get("inference_masking", True) - self.recurrent: bool = self.model_type == self._MODEL_TYPES[1] # "RecurrentPPO" + self.recurrent: bool = ( + self.model_type == ReforceXY._MODEL_TYPES[1] + ) # "RecurrentPPO" self.lr_schedule: bool = self.rl_config.get("lr_schedule", False) self.cr_schedule: bool = self.rl_config.get("cr_schedule", False) self.n_envs: int = self.rl_config.get("n_envs", 1) @@ -496,7 +498,7 @@ class ReforceXY(BaseReinforcementLearningModel): if isinstance(lr, (int, float)): lr = float(lr) model_params["learning_rate"] = get_schedule( - cast(ScheduleTypeKnown, self._SCHEDULE_TYPES[0]), lr + cast(ScheduleTypeKnown, ReforceXY._SCHEDULE_TYPES[0]), lr ) logger.info( "Learning rate linear schedule enabled, initial value: %s", lr @@ -505,19 +507,19 @@ class ReforceXY(BaseReinforcementLearningModel): # "PPO" if ( not self.hyperopt - and self._MODEL_TYPES[0] in self.model_type + and ReforceXY._MODEL_TYPES[0] in self.model_type and self.cr_schedule ): cr = model_params.get("clip_range", 0.2) if isinstance(cr, (int, float)): cr = float(cr) model_params["clip_range"] = get_schedule( - cast(ScheduleTypeKnown, self._SCHEDULE_TYPES[0]), cr + cast(ScheduleTypeKnown, ReforceXY._SCHEDULE_TYPES[0]), cr ) logger.info("Clip range linear schedule enabled, initial value: %s", cr) # "DQN" - if self._MODEL_TYPES[3] in self.model_type: + if ReforceXY._MODEL_TYPES[3] in self.model_type: if model_params.get("gradient_steps") is None: model_params["gradient_steps"] = compute_gradient_steps( model_params.get("train_freq"), model_params.get("subsample_steps") @@ -536,9 +538,9 @@ class ReforceXY(BaseReinforcementLearningModel): ] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch) # "PPO" - if self._MODEL_TYPES[0] in self.model_type: + if ReforceXY._MODEL_TYPES[0] in self.model_type: if isinstance(net_arch, str): - if net_arch in self._NET_ARCH_SIZES: + if net_arch in ReforceXY._NET_ARCH_SIZES: model_params["policy_kwargs"]["net_arch"] = get_net_arch( self.model_type, cast(NetArchSize, net_arch), @@ -576,7 +578,7 @@ class ReforceXY(BaseReinforcementLearningModel): } else: if isinstance(net_arch, str): - if net_arch in self._NET_ARCH_SIZES: + if net_arch in ReforceXY._NET_ARCH_SIZES: model_params["policy_kwargs"]["net_arch"] = get_net_arch( self.model_type, cast(NetArchSize, net_arch), @@ -594,12 +596,12 @@ class ReforceXY(BaseReinforcementLearningModel): model_params["policy_kwargs"]["activation_fn"] = get_activation_fn( model_params.get("policy_kwargs", {}).get( - "activation_fn", self._ACTIVATION_FUNCTIONS[1] + "activation_fn", ReforceXY._ACTIVATION_FUNCTIONS[1] ) # "relu" ) model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class( model_params.get("policy_kwargs", {}).get( - "optimizer_class", self._OPTIMIZER_CLASSES[1] + "optimizer_class", ReforceXY._OPTIMIZER_CLASSES[1] ) # "adamw" ) @@ -638,7 +640,7 @@ class ReforceXY(BaseReinforcementLearningModel): if total_timesteps <= 0: return 1 # "PPO" - if self._MODEL_TYPES[0] in self.model_type: + if ReforceXY._MODEL_TYPES[0] in self.model_type: eval_freq: Optional[int] = None if model_params: n_steps = model_params.get("n_steps") @@ -790,7 +792,7 @@ class ReforceXY(BaseReinforcementLearningModel): logger.info("%s params: %s", self.model_type, model_params) # "PPO" - if self._MODEL_TYPES[0] in self.model_type: + if ReforceXY._MODEL_TYPES[0] in self.model_type: n_steps = model_params.get("n_steps", 0) min_timesteps = 2 * n_steps * self.n_envs if total_timesteps <= min_timesteps: @@ -1040,23 +1042,23 @@ class ReforceXY(BaseReinforcementLearningModel): storage_dir = self.full_path storage_filename = f"optuna-{pair.split('/')[0]}" storage_backend: StorageBackend = self.rl_config_optuna.get( - "storage", self._STORAGE_BACKENDS[0] + "storage", ReforceXY._STORAGE_BACKENDS[0] ) # "sqlite" # "sqlite" - if storage_backend == self._STORAGE_BACKENDS[0]: + if storage_backend == ReforceXY._STORAGE_BACKENDS[0]: storage = RDBStorage( url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite", heartbeat_interval=60, failed_trial_callback=RetryFailedTrialCallback(max_retry=3), ) # "file" - elif storage_backend == self._STORAGE_BACKENDS[1]: + elif storage_backend == ReforceXY._STORAGE_BACKENDS[1]: storage = JournalStorage( JournalFileBackend(f"{storage_dir}/{storage_filename}.log") ) else: raise ValueError( - f"Unsupported storage backend: {storage_backend}. Supported backends are: {', '.join(self._STORAGE_BACKENDS)}" + f"Unsupported storage backend: {storage_backend}. Supported backends are: {', '.join(ReforceXY._STORAGE_BACKENDS)}" ) return storage @@ -1072,15 +1074,15 @@ class ReforceXY(BaseReinforcementLearningModel): def create_sampler(self) -> BaseSampler: sampler: SamplerType = self.rl_config_optuna.get( - "sampler", self._SAMPLER_TYPES[0] + "sampler", ReforceXY._SAMPLER_TYPES[0] ) # "tpe" # "auto" - if sampler == self._SAMPLER_TYPES[1]: + if sampler == ReforceXY._SAMPLER_TYPES[1]: return optunahub.load_module("samplers/auto_sampler").AutoSampler( seed=self.rl_config_optuna.get("seed", 42) ) # "tpe" - elif sampler == self._SAMPLER_TYPES[0]: + elif sampler == ReforceXY._SAMPLER_TYPES[0]: return TPESampler( n_startup_trials=self.optuna_n_startup_trials, multivariate=True, @@ -1089,7 +1091,7 @@ class ReforceXY(BaseReinforcementLearningModel): ) else: raise ValueError( - f"Unsupported sampler: {sampler}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}" + f"Unsupported sampler: {sampler}. Supported samplers: {', '.join(ReforceXY._SAMPLER_TYPES)}" ) @staticmethod @@ -1115,7 +1117,7 @@ class ReforceXY(BaseReinforcementLearningModel): if continuous: ReforceXY.delete_study(study_name, storage) # "PPO" - if self._MODEL_TYPES[0] in self.model_type: + if ReforceXY._MODEL_TYPES[0] in self.model_type: resource_eval_freq = min(PPO_N_STEPS) else: resource_eval_freq = self.get_eval_freq(total_timesteps, hyperopt=True) @@ -1322,16 +1324,16 @@ class ReforceXY(BaseReinforcementLearningModel): def get_optuna_params(self, trial: Trial) -> Dict[str, Any]: # "RecurrentPPO" - if self._MODEL_TYPES[1] in self.model_type: + if ReforceXY._MODEL_TYPES[1] in self.model_type: return sample_params_recurrentppo(trial) # "PPO" - elif self._MODEL_TYPES[0] in self.model_type: + elif ReforceXY._MODEL_TYPES[0] in self.model_type: return sample_params_ppo(trial) # "QRDQN" - elif self._MODEL_TYPES[4] in self.model_type: + elif ReforceXY._MODEL_TYPES[4] in self.model_type: return sample_params_qrdqn(trial) # "DQN" - elif self._MODEL_TYPES[3] in self.model_type: + elif ReforceXY._MODEL_TYPES[3] in self.model_type: return sample_params_dqn(trial) else: raise NotImplementedError(f"{self.model_type} not supported for hyperopt") @@ -1347,7 +1349,7 @@ class ReforceXY(BaseReinforcementLearningModel): params = self.get_optuna_params(trial) # "PPO" - if self._MODEL_TYPES[0] in self.model_type: + if ReforceXY._MODEL_TYPES[0] in self.model_type: n_steps = params.get("n_steps") if n_steps * self.n_envs > total_timesteps: raise TrialPruned( @@ -1360,7 +1362,7 @@ class ReforceXY(BaseReinforcementLearningModel): ) # "DQN" - if self._MODEL_TYPES[3] in self.model_type: + if ReforceXY._MODEL_TYPES[3] in self.model_type: gradient_steps = params.get("gradient_steps") if isinstance(gradient_steps, int) and gradient_steps <= 0: raise TrialPruned(f"{gradient_steps=} is negative or zero") @@ -1380,7 +1382,7 @@ class ReforceXY(BaseReinforcementLearningModel): logger.info("Trial %s params: %s", trial.number, params) # "PPO" - if self._MODEL_TYPES[0] in self.model_type: + if ReforceXY._MODEL_TYPES[0] in self.model_type: n_steps = params.get("n_steps", 0) if n_steps > 0: rollout = n_steps * self.n_envs diff --git a/ReforceXY/user_data/strategies/RLAgentStrategy.py b/ReforceXY/user_data/strategies/RLAgentStrategy.py index 1c37ecf..e3bc5ff 100644 --- a/ReforceXY/user_data/strategies/RLAgentStrategy.py +++ b/ReforceXY/user_data/strategies/RLAgentStrategy.py @@ -82,21 +82,21 @@ class RLAgentStrategy(IStrategy): ) -> DataFrame: enter_long_conditions = [ dataframe.get("do_predict") == 1, - dataframe.get(ACTION_COLUMN) == self._ACTION_ENTER_LONG, # 1, + dataframe.get(ACTION_COLUMN) == RLAgentStrategy._ACTION_ENTER_LONG, # 1, ] dataframe.loc[ reduce(lambda x, y: x & y, enter_long_conditions), ["enter_long", "enter_tag"], - ] = (1, self._TRADE_DIRECTIONS[0]) # "long" + ] = (1, RLAgentStrategy._TRADE_DIRECTIONS[0]) # "long" enter_short_conditions = [ dataframe.get("do_predict") == 1, - dataframe.get(ACTION_COLUMN) == self._ACTION_ENTER_SHORT, # 3, + dataframe.get(ACTION_COLUMN) == RLAgentStrategy._ACTION_ENTER_SHORT, # 3, ] dataframe.loc[ reduce(lambda x, y: x & y, enter_short_conditions), ["enter_short", "enter_tag"], - ] = (1, self._TRADE_DIRECTIONS[1]) # "short" + ] = (1, RLAgentStrategy._TRADE_DIRECTIONS[1]) # "short" return dataframe @@ -105,13 +105,13 @@ class RLAgentStrategy(IStrategy): ) -> DataFrame: exit_long_conditions = [ dataframe.get("do_predict") == 1, - dataframe.get(ACTION_COLUMN) == self._ACTION_EXIT_LONG, # 2, + dataframe.get(ACTION_COLUMN) == RLAgentStrategy._ACTION_EXIT_LONG, # 2, ] dataframe.loc[reduce(lambda x, y: x & y, exit_long_conditions), "exit_long"] = 1 exit_short_conditions = [ dataframe.get("do_predict") == 1, - dataframe.get(ACTION_COLUMN) == self._ACTION_EXIT_SHORT, # 4, + dataframe.get(ACTION_COLUMN) == RLAgentStrategy._ACTION_EXIT_SHORT, # 4, ] dataframe.loc[ reduce(lambda x, y: x & y, exit_short_conditions), "exit_short" @@ -159,10 +159,13 @@ class RLAgentStrategy(IStrategy): def is_short_allowed(self) -> bool: trading_mode = self.config.get("trading_mode") # "margin", "futures" - if trading_mode in {self._TRADING_MODES[0], self._TRADING_MODES[1]}: + if trading_mode in { + RLAgentStrategy._TRADING_MODES[0], + RLAgentStrategy._TRADING_MODES[1], + }: return True # "spot" - elif trading_mode == self._TRADING_MODES[2]: + elif trading_mode == RLAgentStrategy._TRADING_MODES[2]: return False else: raise ValueError(f"Invalid trading_mode: {trading_mode}") diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index e214346..26dca00 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -97,8 +97,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel): .get("n_jobs", 1), max(int(self.max_system_threads / 4), 1), ), - "sampler": self._OPTUNA_SAMPLERS[0], # "tpe" - "storage": self._OPTUNA_STORAGE_BACKENDS[1], # "file" + "sampler": QuickAdapterRegressorV3._OPTUNA_SAMPLERS[0], # "tpe" + "storage": QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS[1], # "file" "continuous": True, "warm_start": True, "n_startup_trials": 15, @@ -226,22 +226,30 @@ class QuickAdapterRegressorV3(BaseRegressionModel): self._optuna_train_value[pair] = -1 self._optuna_label_values[pair] = [-1, -1] self._optuna_hp_params[pair] = ( - self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0]) # "hp" - if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0]) + self.optuna_load_best_params( + pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0] + ) # "hp" + if self.optuna_load_best_params( + pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0] + ) else {} ) self._optuna_train_params[pair] = ( self.optuna_load_best_params( - pair, self._OPTUNA_NAMESPACES[1] + pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1] ) # "train" - if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[1]) + if self.optuna_load_best_params( + pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1] + ) else {} ) self._optuna_label_params[pair] = ( self.optuna_load_best_params( - pair, self._OPTUNA_NAMESPACES[2] + pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2] ) # "label" - if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[2]) + if self.optuna_load_best_params( + pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2] + ) else { "label_period_candles": self.ft_params.get( "label_period_candles", @@ -263,76 +271,76 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) def get_optuna_params(self, pair: str, namespace: str) -> dict[str, Any]: - if namespace == self._OPTUNA_NAMESPACES[0]: # "hp" + if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]: # "hp" params = self._optuna_hp_params.get(pair) - elif namespace == self._OPTUNA_NAMESPACES[1]: # "train" + elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]: # "train" params = self._optuna_train_params.get(pair) - elif namespace == self._OPTUNA_NAMESPACES[2]: # "label" + elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]: # "label" params = self._optuna_label_params.get(pair) else: raise ValueError( f"Invalid namespace: {namespace}. " - f"Expected {', '.join(self._OPTUNA_NAMESPACES)}" + f"Expected {', '.join(QuickAdapterRegressorV3._OPTUNA_NAMESPACES)}" ) return params def set_optuna_params( self, pair: str, namespace: str, params: dict[str, Any] ) -> None: - if namespace == self._OPTUNA_NAMESPACES[0]: # "hp" + if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]: # "hp" self._optuna_hp_params[pair] = params - elif namespace == self._OPTUNA_NAMESPACES[1]: # "train" + elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]: # "train" self._optuna_train_params[pair] = params - elif namespace == self._OPTUNA_NAMESPACES[2]: # "label" + elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]: # "label" self._optuna_label_params[pair] = params else: raise ValueError( f"Invalid namespace: {namespace}. " - f"Expected {', '.join(self._OPTUNA_NAMESPACES)}" + f"Expected {', '.join(QuickAdapterRegressorV3._OPTUNA_NAMESPACES)}" ) def get_optuna_value(self, pair: str, namespace: str) -> float: - if namespace == self._OPTUNA_NAMESPACES[0]: # "hp" + if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]: # "hp" value = self._optuna_hp_value.get(pair) - elif namespace == self._OPTUNA_NAMESPACES[1]: # "train" + elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]: # "train" value = self._optuna_train_value.get(pair) else: raise ValueError( f"Invalid namespace: {namespace}. " - f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}" # Only hp and train + f"Expected {', '.join(QuickAdapterRegressorV3._OPTUNA_NAMESPACES[:2])}" # Only hp and train ) return value def set_optuna_value(self, pair: str, namespace: str, value: float) -> None: - if namespace == self._OPTUNA_NAMESPACES[0]: # "hp" + if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]: # "hp" self._optuna_hp_value[pair] = value - elif namespace == self._OPTUNA_NAMESPACES[1]: # "train" + elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]: # "train" self._optuna_train_value[pair] = value else: raise ValueError( f"Invalid namespace: {namespace}. " - f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}" # Only hp and train + f"Expected {', '.join(QuickAdapterRegressorV3._OPTUNA_NAMESPACES[:2])}" # Only hp and train ) def get_optuna_values(self, pair: str, namespace: str) -> list[float | int]: - if namespace == self._OPTUNA_NAMESPACES[2]: # "label" + if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]: # "label" values = self._optuna_label_values.get(pair) else: raise ValueError( f"Invalid namespace: {namespace}. " - f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label + f"Expected {QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]}" # Only label ) return values def set_optuna_values( self, pair: str, namespace: str, values: list[float | int] ) -> None: - if namespace == self._OPTUNA_NAMESPACES[2]: # "label" + if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]: # "label" self._optuna_label_values[pair] = values else: raise ValueError( f"Invalid namespace: {namespace}. " - f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label + f"Expected {QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]}" # Only label ) def init_optuna_label_candle_pool(self) -> None: @@ -406,7 +414,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): if self._optuna_hyperopt: self.optuna_optimize( pair=dk.pair, - namespace=self._OPTUNA_NAMESPACES[0], # "hp" + namespace=QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0], # "hp" objective=lambda trial: hp_objective( trial, str(self.freqai_info.get("regressor", "xgboost")), @@ -416,7 +424,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): X_test, y_test, test_weights, - self.get_optuna_params(dk.pair, self._OPTUNA_NAMESPACES[0]), # "hp" + self.get_optuna_params( + dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0] + ), # "hp" model_training_parameters, self._optuna_config.get("space_reduction"), self._optuna_config.get("expansion_ratio"), @@ -425,7 +435,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) optuna_hp_params = self.get_optuna_params( - dk.pair, self._OPTUNA_NAMESPACES[0] + dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0] ) # "hp" if optuna_hp_params: model_training_parameters = { @@ -435,7 +445,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): train_study = self.optuna_optimize( pair=dk.pair, - namespace=self._OPTUNA_NAMESPACES[1], # "train" + namespace=QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1], # "train" objective=lambda trial: train_objective( trial, str(self.freqai_info.get("regressor", "xgboost")), @@ -454,18 +464,18 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) optuna_hp_value = self.get_optuna_value( - dk.pair, self._OPTUNA_NAMESPACES[0] + dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0] ) # "hp" optuna_train_params = self.get_optuna_params( - dk.pair, self._OPTUNA_NAMESPACES[1] + dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1] ) # "train" optuna_train_value = self.get_optuna_value( - dk.pair, self._OPTUNA_NAMESPACES[1] + dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1] ) # "train" if ( optuna_train_params and self.optuna_validate_params( - dk.pair, self._OPTUNA_NAMESPACES[1], train_study + dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1], train_study ) # "train" and optuna_train_value < optuna_hp_value ): @@ -513,10 +523,12 @@ class QuickAdapterRegressorV3(BaseRegressionModel): namespace: str, callback: Callable[[], None], ) -> None: - if namespace not in {self._OPTUNA_NAMESPACES[2]}: # Only "label" + if namespace not in { + QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2] + }: # Only "label" raise ValueError( f"Invalid namespace: {namespace}. " - f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label + f"Expected {QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]}" # Only label ) if not callable(callback): raise ValueError("callback must be callable") @@ -554,10 +566,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel): if self._optuna_hyperopt: self.optuna_throttle_callback( pair=pair, - namespace=self._OPTUNA_NAMESPACES[2], # "label" + namespace=QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2], # "label" callback=lambda: self.optuna_optimize( pair=pair, - namespace=self._OPTUNA_NAMESPACES[2], # "label" + namespace=QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2], # "label" objective=lambda trial: label_objective( trial, self.data_provider.get_pair_dataframe( @@ -610,9 +622,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): min_pred, max_pred = self.min_max_pred( pred_df, fit_live_predictions_candles, - self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get( - "label_period_candles" - ), # "label" + self.get_optuna_params( + pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2] + ).get("label_period_candles"), # "label" ) dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = min_pred dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = max_pred @@ -650,17 +662,17 @@ class QuickAdapterRegressorV3(BaseRegressionModel): dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff dk.data["extra_returns_per_train"]["label_period_candles"] = ( - self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get( - "label_period_candles" - ) # "label" + self.get_optuna_params( + pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2] + ).get("label_period_candles") # "label" ) dk.data["extra_returns_per_train"]["label_natr_ratio"] = self.get_optuna_params( pair, - self._OPTUNA_NAMESPACES[2], # "label" + QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2], # "label" ).get("label_natr_ratio") hp_rmse = self.optuna_validate_value( - self.get_optuna_value(pair, self._OPTUNA_NAMESPACES[0]) + self.get_optuna_value(pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]) ) # "hp" dk.data["extra_returns_per_train"]["hp_rmse"] = ( hp_rmse if hp_rmse is not None else np.inf @@ -711,7 +723,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): extrema_selection = str( self.freqai_info.get( "prediction_extrema_selection", - self._EXTREMA_SELECTION_METHODS[1], + QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS[1], ) ) if extrema_selection not in self._extrema_selection_methods_set(): @@ -1004,10 +1016,12 @@ class QuickAdapterRegressorV3(BaseRegressionModel): def get_multi_objective_study_best_trial( self, namespace: str, study: optuna.study.Study ) -> Optional[optuna.trial.FrozenTrial]: - if namespace not in {self._OPTUNA_NAMESPACES[2]}: # Only "label" + if namespace not in { + QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2] + }: # Only "label" raise ValueError( f"Invalid namespace: {namespace}. " - f"Expected {self._OPTUNA_NAMESPACES[2]}" # Only label + f"Expected {QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]}" # Only label ) n_objectives = len(study.directions) if n_objectives < 2: @@ -1643,7 +1657,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): storage_dir = self.full_path storage_filename = f"optuna-{pair.split('/')[0]}" storage_backend = self._optuna_config.get("storage") - if storage_backend == self._OPTUNA_STORAGE_BACKENDS[0]: # "sqlite" + if ( + storage_backend == QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS[0] + ): # "sqlite" storage = optuna.storages.RDBStorage( url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite", heartbeat_interval=60, @@ -1651,7 +1667,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): max_retry=3 ), ) - elif storage_backend == self._OPTUNA_STORAGE_BACKENDS[1]: # "file" + elif ( + storage_backend == QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS[1] + ): # "file" storage = optuna.storages.JournalStorage( optuna.storages.journal.JournalFileBackend( f"{storage_dir}/{storage_filename}.log" @@ -1660,7 +1678,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): else: raise ValueError( f"Unsupported optuna storage backend: {storage_backend}. " - f"Supported backends are {', '.join(self._OPTUNA_STORAGE_BACKENDS)}" + f"Supported backends are {', '.join(QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS)}" ) return storage @@ -1675,12 +1693,14 @@ class QuickAdapterRegressorV3(BaseRegressionModel): return optuna.pruners.NopPruner() def optuna_create_sampler(self) -> optuna.samplers.BaseSampler: - sampler = self._optuna_config.get("sampler", self._OPTUNA_SAMPLERS[0]) - if sampler == self._OPTUNA_SAMPLERS[1]: # "auto" + sampler = self._optuna_config.get( + "sampler", QuickAdapterRegressorV3._OPTUNA_SAMPLERS[0] + ) + if sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS[1]: # "auto" return optunahub.load_module("samplers/auto_sampler").AutoSampler( seed=self._optuna_config.get("seed") ) - elif sampler == self._OPTUNA_SAMPLERS[0]: # "tpe" + elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS[0]: # "tpe" return optuna.samplers.TPESampler( n_startup_trials=self._optuna_config.get("n_startup_trials"), multivariate=True, @@ -1690,7 +1710,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): else: raise ValueError( f"Unsupported sampler: {sampler}. " - f"Supported samplers are {', '.join(self._OPTUNA_SAMPLERS)}" + f"Supported samplers are {', '.join(QuickAdapterRegressorV3._OPTUNA_SAMPLERS)}" ) def optuna_create_study( diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index a7ac5ef..00e0832 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -665,7 +665,7 @@ class QuickAdapterV3(IStrategy): dataframe.loc[ reduce(lambda x, y: x & y, enter_long_conditions), ["enter_long", "enter_tag"], - ] = (1, self._TRADE_DIRECTIONS[0]) # "long" + ] = (1, QuickAdapterV3._TRADE_DIRECTIONS[0]) # "long" enter_short_conditions = [ dataframe.get("do_predict") == 1, @@ -675,7 +675,7 @@ class QuickAdapterV3(IStrategy): dataframe.loc[ reduce(lambda x, y: x & y, enter_short_conditions), ["enter_short", "enter_tag"], - ] = (1, self._TRADE_DIRECTIONS[1]) # "short" + ] = (1, QuickAdapterV3._TRADE_DIRECTIONS[1]) # "short" return dataframe @@ -1223,13 +1223,17 @@ class QuickAdapterV3(IStrategy): if isna(candle_label_natr_value_quantile): return np.nan - if interpolation_direction == self._INTERPOLATION_DIRECTIONS[0]: # "direct" + if ( + interpolation_direction == QuickAdapterV3._INTERPOLATION_DIRECTIONS[0] + ): # "direct" natr_ratio_percent = ( min_natr_ratio_percent + (max_natr_ratio_percent - min_natr_ratio_percent) * candle_label_natr_value_quantile**quantile_exponent ) - elif interpolation_direction == self._INTERPOLATION_DIRECTIONS[1]: # "inverse" + elif ( + interpolation_direction == QuickAdapterV3._INTERPOLATION_DIRECTIONS[1] + ): # "inverse" natr_ratio_percent = ( max_natr_ratio_percent - (max_natr_ratio_percent - min_natr_ratio_percent) @@ -1238,7 +1242,7 @@ class QuickAdapterV3(IStrategy): else: raise ValueError( f"Invalid interpolation_direction: {interpolation_direction}. " - f"Expected {', '.join(self._INTERPOLATION_DIRECTIONS)}" + f"Expected {', '.join(QuickAdapterV3._INTERPOLATION_DIRECTIONS)}" ) candle_deviation = ( candle_label_natr_value / 100.0 @@ -1278,7 +1282,9 @@ class QuickAdapterV3(IStrategy): min_natr_ratio_percent=min_natr_ratio_percent, max_natr_ratio_percent=max_natr_ratio_percent, candle_idx=candle_idx, - interpolation_direction=self._INTERPOLATION_DIRECTIONS[0], # "direct" + interpolation_direction=QuickAdapterV3._INTERPOLATION_DIRECTIONS[ + 0 + ], # "direct" ) if isna(current_deviation) or current_deviation <= 0: return np.nan @@ -1293,14 +1299,14 @@ class QuickAdapterV3(IStrategy): is_candle_bullish: bool = candle_close > candle_open is_candle_bearish: bool = candle_close < candle_open - if side == self._TRADE_DIRECTIONS[0]: # "long" + if side == QuickAdapterV3._TRADE_DIRECTIONS[0]: # "long" base_price = ( QuickAdapterV3.weighted_close(candle) if is_candle_bearish else candle_close ) candle_threshold = base_price * (1 + current_deviation) - elif side == self._TRADE_DIRECTIONS[1]: # "short" + elif side == QuickAdapterV3._TRADE_DIRECTIONS[1]: # "short" base_price = ( QuickAdapterV3.weighted_close(candle) if is_candle_bullish @@ -1309,7 +1315,7 @@ class QuickAdapterV3(IStrategy): candle_threshold = base_price * (1 - current_deviation) else: raise ValueError( - f"Invalid side: {side}. Expected {', '.join(self._TRADE_DIRECTIONS)}" + f"Invalid side: {side}. Expected {', '.join(QuickAdapterV3._TRADE_DIRECTIONS)}" ) self._candle_threshold_cache[cache_key] = candle_threshold return self._candle_threshold_cache[cache_key] @@ -1424,16 +1430,18 @@ class QuickAdapterV3(IStrategy): candle_idx=-1, ) current_ok = np.isfinite(current_threshold) and ( - (side == self._TRADE_DIRECTIONS[0] and rate > current_threshold) # "long" + ( + side == QuickAdapterV3._TRADE_DIRECTIONS[0] and rate > current_threshold + ) # "long" or ( - side == self._TRADE_DIRECTIONS[1] and rate < current_threshold + side == QuickAdapterV3._TRADE_DIRECTIONS[1] and rate < current_threshold ) # "short" ) - if order == self._ORDER_TYPES[1]: # "exit" - if side == self._TRADE_DIRECTIONS[0]: # "long" - trade_direction = self._TRADE_DIRECTIONS[1] # "short" - if side == self._TRADE_DIRECTIONS[1]: # "short" - trade_direction = self._TRADE_DIRECTIONS[0] # "long" + if order == QuickAdapterV3._ORDER_TYPES[1]: # "exit" + if side == QuickAdapterV3._TRADE_DIRECTIONS[0]: # "long" + trade_direction = QuickAdapterV3._TRADE_DIRECTIONS[1] # "short" + if side == QuickAdapterV3._TRADE_DIRECTIONS[1]: # "short" + trade_direction = QuickAdapterV3._TRADE_DIRECTIONS[0] # "long" if not current_ok: logger.info( f"User denied {trade_direction} {order} for {pair}: rate {format_number(rate)} did not break threshold {format_number(current_threshold)}" @@ -1471,9 +1479,10 @@ class QuickAdapterV3(IStrategy): return current_ok if ( - side == self._TRADE_DIRECTIONS[0] and not (close_k > threshold_k) + side == QuickAdapterV3._TRADE_DIRECTIONS[0] + and not (close_k > threshold_k) ) or ( # "long" - side == self._TRADE_DIRECTIONS[1] + side == QuickAdapterV3._TRADE_DIRECTIONS[1] and not (close_k < threshold_k) # "short" ): logger.info( @@ -1702,15 +1711,15 @@ class QuickAdapterV3(IStrategy): trade.set_custom_data("last_outlier_date", last_candle_date.isoformat()) if ( - trade.trade_direction == self._TRADE_DIRECTIONS[1] # "short" + trade.trade_direction == QuickAdapterV3._TRADE_DIRECTIONS[1] # "short" and last_candle.get("do_predict") == 1 and last_candle.get("DI_catch") == 1 and last_candle.get(EXTREMA_COLUMN) < last_candle.get("minima_threshold") and self.reversal_confirmed( df, pair, - self._TRADE_DIRECTIONS[0], # "long" - self._ORDER_TYPES[1], # "exit" + QuickAdapterV3._TRADE_DIRECTIONS[0], # "long" + QuickAdapterV3._ORDER_TYPES[1], # "exit" current_rate, self._reversal_lookback_period, self._reversal_decay_ratio, @@ -1720,15 +1729,15 @@ class QuickAdapterV3(IStrategy): ): return "minima_detected_short" if ( - trade.trade_direction == self._TRADE_DIRECTIONS[0] # "long" + trade.trade_direction == QuickAdapterV3._TRADE_DIRECTIONS[0] # "long" and last_candle.get("do_predict") == 1 and last_candle.get("DI_catch") == 1 and last_candle.get(EXTREMA_COLUMN) > last_candle.get("maxima_threshold") and self.reversal_confirmed( df, pair, - self._TRADE_DIRECTIONS[1], # "short" - self._ORDER_TYPES[1], # "exit" + QuickAdapterV3._TRADE_DIRECTIONS[1], # "short" + QuickAdapterV3._ORDER_TYPES[1], # "exit" current_rate, self._reversal_lookback_period, self._reversal_decay_ratio, @@ -1839,7 +1848,9 @@ class QuickAdapterV3(IStrategy): ) -> bool: if side not in self._trade_directions_set(): return False - if side == self._TRADE_DIRECTIONS[1] and not self.can_short: # "short" + if ( + side == QuickAdapterV3._TRADE_DIRECTIONS[1] and not self.can_short + ): # "short" logger.info(f"User denied short entry for {pair}: shorting not allowed") return False if Trade.get_open_trade_count() >= self.config.get("max_open_trades"): @@ -1863,7 +1874,7 @@ class QuickAdapterV3(IStrategy): df, pair, side, - self._ORDER_TYPES[0], # "entry" + QuickAdapterV3._ORDER_TYPES[0], # "entry" rate, self._reversal_lookback_period, self._reversal_decay_ratio, @@ -1876,16 +1887,16 @@ class QuickAdapterV3(IStrategy): def is_short_allowed(self) -> bool: trading_mode = self.config.get("trading_mode") if trading_mode in { - self._TRADING_MODES[1], - self._TRADING_MODES[2], + QuickAdapterV3._TRADING_MODES[1], + QuickAdapterV3._TRADING_MODES[2], }: # margin, futures return True - elif trading_mode == self._TRADING_MODES[0]: # "spot" + elif trading_mode == QuickAdapterV3._TRADING_MODES[0]: # "spot" return False else: raise ValueError( f"Invalid trading_mode: {trading_mode}. " - f"Expected {', '.join(self._TRADING_MODES)}" + f"Expected {', '.join(QuickAdapterV3._TRADING_MODES)}" ) def leverage(