]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor: cleanup access to constants properties
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 19 Nov 2025 13:07:26 +0000 (14:07 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 19 Nov 2025 13:07:26 +0000 (14:07 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
ReforceXY/user_data/strategies/RLAgentStrategy.py
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 6a1edbdf19038abdbb31cf75be37f5712c672251..da734e38c15284729a847fff579731f79aca32a7 100644 (file)
@@ -74,7 +74,7 @@ from stable_baselines3.common.vec_env import (
 
 ModelType = Literal["PPO", "RecurrentPPO", "MaskablePPO", "DQN", "QRDQN"]
 ScheduleType = Literal["linear", "constant", "unknown"]
-ScheduleTypeKnown = Literal["linear", "constant"]  # Subset for get_schedule() function
+ScheduleTypeKnown = Literal["linear", "constant"]
 ExitPotentialMode = Literal[
     "canonical",
     "non_canonical",
@@ -217,11 +217,13 @@ class ReforceXY(BaseReinforcementLearningModel):
                 "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
             )
         self.action_masking: bool = (
-            self.model_type == self._MODEL_TYPES[2]
+            self.model_type == ReforceXY._MODEL_TYPES[2]
         )  # "MaskablePPO"
         self.rl_config.setdefault("action_masking", self.action_masking)
         self.inference_masking: bool = self.rl_config.get("inference_masking", True)
-        self.recurrent: bool = self.model_type == self._MODEL_TYPES[1]  # "RecurrentPPO"
+        self.recurrent: bool = (
+            self.model_type == ReforceXY._MODEL_TYPES[1]
+        )  # "RecurrentPPO"
         self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
         self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
         self.n_envs: int = self.rl_config.get("n_envs", 1)
@@ -496,7 +498,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             if isinstance(lr, (int, float)):
                 lr = float(lr)
                 model_params["learning_rate"] = get_schedule(
-                    cast(ScheduleTypeKnown, self._SCHEDULE_TYPES[0]), lr
+                    cast(ScheduleTypeKnown, ReforceXY._SCHEDULE_TYPES[0]), lr
                 )
                 logger.info(
                     "Learning rate linear schedule enabled, initial value: %s", lr
@@ -505,19 +507,19 @@ class ReforceXY(BaseReinforcementLearningModel):
         # "PPO"
         if (
             not self.hyperopt
-            and self._MODEL_TYPES[0] in self.model_type
+            and ReforceXY._MODEL_TYPES[0] in self.model_type
             and self.cr_schedule
         ):
             cr = model_params.get("clip_range", 0.2)
             if isinstance(cr, (int, float)):
                 cr = float(cr)
                 model_params["clip_range"] = get_schedule(
-                    cast(ScheduleTypeKnown, self._SCHEDULE_TYPES[0]), cr
+                    cast(ScheduleTypeKnown, ReforceXY._SCHEDULE_TYPES[0]), cr
                 )
                 logger.info("Clip range linear schedule enabled, initial value: %s", cr)
 
         # "DQN"
-        if self._MODEL_TYPES[3] in self.model_type:
+        if ReforceXY._MODEL_TYPES[3] in self.model_type:
             if model_params.get("gradient_steps") is None:
                 model_params["gradient_steps"] = compute_gradient_steps(
                     model_params.get("train_freq"), model_params.get("subsample_steps")
@@ -536,9 +538,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         ] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch)
 
         # "PPO"
-        if self._MODEL_TYPES[0] in self.model_type:
+        if ReforceXY._MODEL_TYPES[0] in self.model_type:
             if isinstance(net_arch, str):
-                if net_arch in self._NET_ARCH_SIZES:
+                if net_arch in ReforceXY._NET_ARCH_SIZES:
                     model_params["policy_kwargs"]["net_arch"] = get_net_arch(
                         self.model_type,
                         cast(NetArchSize, net_arch),
@@ -576,7 +578,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 }
         else:
             if isinstance(net_arch, str):
-                if net_arch in self._NET_ARCH_SIZES:
+                if net_arch in ReforceXY._NET_ARCH_SIZES:
                     model_params["policy_kwargs"]["net_arch"] = get_net_arch(
                         self.model_type,
                         cast(NetArchSize, net_arch),
@@ -594,12 +596,12 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         model_params["policy_kwargs"]["activation_fn"] = get_activation_fn(
             model_params.get("policy_kwargs", {}).get(
-                "activation_fn", self._ACTIVATION_FUNCTIONS[1]
+                "activation_fn", ReforceXY._ACTIVATION_FUNCTIONS[1]
             )  # "relu"
         )
         model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
             model_params.get("policy_kwargs", {}).get(
-                "optimizer_class", self._OPTIMIZER_CLASSES[1]
+                "optimizer_class", ReforceXY._OPTIMIZER_CLASSES[1]
             )  # "adamw"
         )
 
@@ -638,7 +640,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         if total_timesteps <= 0:
             return 1
         # "PPO"
-        if self._MODEL_TYPES[0] in self.model_type:
+        if ReforceXY._MODEL_TYPES[0] in self.model_type:
             eval_freq: Optional[int] = None
             if model_params:
                 n_steps = model_params.get("n_steps")
@@ -790,7 +792,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         logger.info("%s params: %s", self.model_type, model_params)
 
         # "PPO"
-        if self._MODEL_TYPES[0] in self.model_type:
+        if ReforceXY._MODEL_TYPES[0] in self.model_type:
             n_steps = model_params.get("n_steps", 0)
             min_timesteps = 2 * n_steps * self.n_envs
             if total_timesteps <= min_timesteps:
@@ -1040,23 +1042,23 @@ class ReforceXY(BaseReinforcementLearningModel):
         storage_dir = self.full_path
         storage_filename = f"optuna-{pair.split('/')[0]}"
         storage_backend: StorageBackend = self.rl_config_optuna.get(
-            "storage", self._STORAGE_BACKENDS[0]
+            "storage", ReforceXY._STORAGE_BACKENDS[0]
         )  # "sqlite"
         # "sqlite"
-        if storage_backend == self._STORAGE_BACKENDS[0]:
+        if storage_backend == ReforceXY._STORAGE_BACKENDS[0]:
             storage = RDBStorage(
                 url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite",
                 heartbeat_interval=60,
                 failed_trial_callback=RetryFailedTrialCallback(max_retry=3),
             )
         # "file"
-        elif storage_backend == self._STORAGE_BACKENDS[1]:
+        elif storage_backend == ReforceXY._STORAGE_BACKENDS[1]:
             storage = JournalStorage(
                 JournalFileBackend(f"{storage_dir}/{storage_filename}.log")
             )
         else:
             raise ValueError(
-                f"Unsupported storage backend: {storage_backend}. Supported backends are: {', '.join(self._STORAGE_BACKENDS)}"
+                f"Unsupported storage backend: {storage_backend}. Supported backends are: {', '.join(ReforceXY._STORAGE_BACKENDS)}"
             )
         return storage
 
@@ -1072,15 +1074,15 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def create_sampler(self) -> BaseSampler:
         sampler: SamplerType = self.rl_config_optuna.get(
-            "sampler", self._SAMPLER_TYPES[0]
+            "sampler", ReforceXY._SAMPLER_TYPES[0]
         )  # "tpe"
         # "auto"
-        if sampler == self._SAMPLER_TYPES[1]:
+        if sampler == ReforceXY._SAMPLER_TYPES[1]:
             return optunahub.load_module("samplers/auto_sampler").AutoSampler(
                 seed=self.rl_config_optuna.get("seed", 42)
             )
         # "tpe"
-        elif sampler == self._SAMPLER_TYPES[0]:
+        elif sampler == ReforceXY._SAMPLER_TYPES[0]:
             return TPESampler(
                 n_startup_trials=self.optuna_n_startup_trials,
                 multivariate=True,
@@ -1089,7 +1091,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
         else:
             raise ValueError(
-                f"Unsupported sampler: {sampler}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
+                f"Unsupported sampler: {sampler}. Supported samplers: {', '.join(ReforceXY._SAMPLER_TYPES)}"
             )
 
     @staticmethod
@@ -1115,7 +1117,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         if continuous:
             ReforceXY.delete_study(study_name, storage)
         # "PPO"
-        if self._MODEL_TYPES[0] in self.model_type:
+        if ReforceXY._MODEL_TYPES[0] in self.model_type:
             resource_eval_freq = min(PPO_N_STEPS)
         else:
             resource_eval_freq = self.get_eval_freq(total_timesteps, hyperopt=True)
@@ -1322,16 +1324,16 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def get_optuna_params(self, trial: Trial) -> Dict[str, Any]:
         # "RecurrentPPO"
-        if self._MODEL_TYPES[1] in self.model_type:
+        if ReforceXY._MODEL_TYPES[1] in self.model_type:
             return sample_params_recurrentppo(trial)
         # "PPO"
-        elif self._MODEL_TYPES[0] in self.model_type:
+        elif ReforceXY._MODEL_TYPES[0] in self.model_type:
             return sample_params_ppo(trial)
         # "QRDQN"
-        elif self._MODEL_TYPES[4] in self.model_type:
+        elif ReforceXY._MODEL_TYPES[4] in self.model_type:
             return sample_params_qrdqn(trial)
         # "DQN"
-        elif self._MODEL_TYPES[3] in self.model_type:
+        elif ReforceXY._MODEL_TYPES[3] in self.model_type:
             return sample_params_dqn(trial)
         else:
             raise NotImplementedError(f"{self.model_type} not supported for hyperopt")
@@ -1347,7 +1349,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         params = self.get_optuna_params(trial)
 
         # "PPO"
-        if self._MODEL_TYPES[0] in self.model_type:
+        if ReforceXY._MODEL_TYPES[0] in self.model_type:
             n_steps = params.get("n_steps")
             if n_steps * self.n_envs > total_timesteps:
                 raise TrialPruned(
@@ -1360,7 +1362,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 )
 
         # "DQN"
-        if self._MODEL_TYPES[3] in self.model_type:
+        if ReforceXY._MODEL_TYPES[3] in self.model_type:
             gradient_steps = params.get("gradient_steps")
             if isinstance(gradient_steps, int) and gradient_steps <= 0:
                 raise TrialPruned(f"{gradient_steps=} is negative or zero")
@@ -1380,7 +1382,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         logger.info("Trial %s params: %s", trial.number, params)
 
         # "PPO"
-        if self._MODEL_TYPES[0] in self.model_type:
+        if ReforceXY._MODEL_TYPES[0] in self.model_type:
             n_steps = params.get("n_steps", 0)
             if n_steps > 0:
                 rollout = n_steps * self.n_envs
index 1c37ecf15ec9d442c6c4d5364c22137baaf3afc2..e3bc5ff2ded4771bc659398dc40da0ccda625199 100644 (file)
@@ -82,21 +82,21 @@ class RLAgentStrategy(IStrategy):
     ) -> DataFrame:
         enter_long_conditions = [
             dataframe.get("do_predict") == 1,
-            dataframe.get(ACTION_COLUMN) == self._ACTION_ENTER_LONG,  # 1,
+            dataframe.get(ACTION_COLUMN) == RLAgentStrategy._ACTION_ENTER_LONG,  # 1,
         ]
         dataframe.loc[
             reduce(lambda x, y: x & y, enter_long_conditions),
             ["enter_long", "enter_tag"],
-        ] = (1, self._TRADE_DIRECTIONS[0])  # "long"
+        ] = (1, RLAgentStrategy._TRADE_DIRECTIONS[0])  # "long"
 
         enter_short_conditions = [
             dataframe.get("do_predict") == 1,
-            dataframe.get(ACTION_COLUMN) == self._ACTION_ENTER_SHORT,  # 3,
+            dataframe.get(ACTION_COLUMN) == RLAgentStrategy._ACTION_ENTER_SHORT,  # 3,
         ]
         dataframe.loc[
             reduce(lambda x, y: x & y, enter_short_conditions),
             ["enter_short", "enter_tag"],
-        ] = (1, self._TRADE_DIRECTIONS[1])  # "short"
+        ] = (1, RLAgentStrategy._TRADE_DIRECTIONS[1])  # "short"
 
         return dataframe
 
@@ -105,13 +105,13 @@ class RLAgentStrategy(IStrategy):
     ) -> DataFrame:
         exit_long_conditions = [
             dataframe.get("do_predict") == 1,
-            dataframe.get(ACTION_COLUMN) == self._ACTION_EXIT_LONG,  # 2,
+            dataframe.get(ACTION_COLUMN) == RLAgentStrategy._ACTION_EXIT_LONG,  # 2,
         ]
         dataframe.loc[reduce(lambda x, y: x & y, exit_long_conditions), "exit_long"] = 1
 
         exit_short_conditions = [
             dataframe.get("do_predict") == 1,
-            dataframe.get(ACTION_COLUMN) == self._ACTION_EXIT_SHORT,  # 4,
+            dataframe.get(ACTION_COLUMN) == RLAgentStrategy._ACTION_EXIT_SHORT,  # 4,
         ]
         dataframe.loc[
             reduce(lambda x, y: x & y, exit_short_conditions), "exit_short"
@@ -159,10 +159,13 @@ class RLAgentStrategy(IStrategy):
     def is_short_allowed(self) -> bool:
         trading_mode = self.config.get("trading_mode")
         # "margin", "futures"
-        if trading_mode in {self._TRADING_MODES[0], self._TRADING_MODES[1]}:
+        if trading_mode in {
+            RLAgentStrategy._TRADING_MODES[0],
+            RLAgentStrategy._TRADING_MODES[1],
+        }:
             return True
         # "spot"
-        elif trading_mode == self._TRADING_MODES[2]:
+        elif trading_mode == RLAgentStrategy._TRADING_MODES[2]:
             return False
         else:
             raise ValueError(f"Invalid trading_mode: {trading_mode}")
index e2143465da756f8086ffd423b1bc3bc65573a149..26dca00f100f8a49ee57d430ed9c4e41a9379342 100644 (file)
@@ -97,8 +97,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 .get("n_jobs", 1),
                 max(int(self.max_system_threads / 4), 1),
             ),
-            "sampler": self._OPTUNA_SAMPLERS[0],  # "tpe"
-            "storage": self._OPTUNA_STORAGE_BACKENDS[1],  # "file"
+            "sampler": QuickAdapterRegressorV3._OPTUNA_SAMPLERS[0],  # "tpe"
+            "storage": QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS[1],  # "file"
             "continuous": True,
             "warm_start": True,
             "n_startup_trials": 15,
@@ -226,22 +226,30 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             self._optuna_train_value[pair] = -1
             self._optuna_label_values[pair] = [-1, -1]
             self._optuna_hp_params[pair] = (
-                self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0])  # "hp"
-                if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[0])
+                self.optuna_load_best_params(
+                    pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]
+                )  # "hp"
+                if self.optuna_load_best_params(
+                    pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]
+                )
                 else {}
             )
             self._optuna_train_params[pair] = (
                 self.optuna_load_best_params(
-                    pair, self._OPTUNA_NAMESPACES[1]
+                    pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]
                 )  # "train"
-                if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[1])
+                if self.optuna_load_best_params(
+                    pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]
+                )
                 else {}
             )
             self._optuna_label_params[pair] = (
                 self.optuna_load_best_params(
-                    pair, self._OPTUNA_NAMESPACES[2]
+                    pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]
                 )  # "label"
-                if self.optuna_load_best_params(pair, self._OPTUNA_NAMESPACES[2])
+                if self.optuna_load_best_params(
+                    pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]
+                )
                 else {
                     "label_period_candles": self.ft_params.get(
                         "label_period_candles",
@@ -263,76 +271,76 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         )
 
     def get_optuna_params(self, pair: str, namespace: str) -> dict[str, Any]:
-        if namespace == self._OPTUNA_NAMESPACES[0]:  # "hp"
+        if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]:  # "hp"
             params = self._optuna_hp_params.get(pair)
-        elif namespace == self._OPTUNA_NAMESPACES[1]:  # "train"
+        elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]:  # "train"
             params = self._optuna_train_params.get(pair)
-        elif namespace == self._OPTUNA_NAMESPACES[2]:  # "label"
+        elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]:  # "label"
             params = self._optuna_label_params.get(pair)
         else:
             raise ValueError(
                 f"Invalid namespace: {namespace}. "
-                f"Expected {', '.join(self._OPTUNA_NAMESPACES)}"
+                f"Expected {', '.join(QuickAdapterRegressorV3._OPTUNA_NAMESPACES)}"
             )
         return params
 
     def set_optuna_params(
         self, pair: str, namespace: str, params: dict[str, Any]
     ) -> None:
-        if namespace == self._OPTUNA_NAMESPACES[0]:  # "hp"
+        if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]:  # "hp"
             self._optuna_hp_params[pair] = params
-        elif namespace == self._OPTUNA_NAMESPACES[1]:  # "train"
+        elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]:  # "train"
             self._optuna_train_params[pair] = params
-        elif namespace == self._OPTUNA_NAMESPACES[2]:  # "label"
+        elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]:  # "label"
             self._optuna_label_params[pair] = params
         else:
             raise ValueError(
                 f"Invalid namespace: {namespace}. "
-                f"Expected {', '.join(self._OPTUNA_NAMESPACES)}"
+                f"Expected {', '.join(QuickAdapterRegressorV3._OPTUNA_NAMESPACES)}"
             )
 
     def get_optuna_value(self, pair: str, namespace: str) -> float:
-        if namespace == self._OPTUNA_NAMESPACES[0]:  # "hp"
+        if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]:  # "hp"
             value = self._optuna_hp_value.get(pair)
-        elif namespace == self._OPTUNA_NAMESPACES[1]:  # "train"
+        elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]:  # "train"
             value = self._optuna_train_value.get(pair)
         else:
             raise ValueError(
                 f"Invalid namespace: {namespace}. "
-                f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}"  # Only hp and train
+                f"Expected {', '.join(QuickAdapterRegressorV3._OPTUNA_NAMESPACES[:2])}"  # Only hp and train
             )
         return value
 
     def set_optuna_value(self, pair: str, namespace: str, value: float) -> None:
-        if namespace == self._OPTUNA_NAMESPACES[0]:  # "hp"
+        if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]:  # "hp"
             self._optuna_hp_value[pair] = value
-        elif namespace == self._OPTUNA_NAMESPACES[1]:  # "train"
+        elif namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]:  # "train"
             self._optuna_train_value[pair] = value
         else:
             raise ValueError(
                 f"Invalid namespace: {namespace}. "
-                f"Expected {', '.join(self._OPTUNA_NAMESPACES[:2])}"  # Only hp and train
+                f"Expected {', '.join(QuickAdapterRegressorV3._OPTUNA_NAMESPACES[:2])}"  # Only hp and train
             )
 
     def get_optuna_values(self, pair: str, namespace: str) -> list[float | int]:
-        if namespace == self._OPTUNA_NAMESPACES[2]:  # "label"
+        if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]:  # "label"
             values = self._optuna_label_values.get(pair)
         else:
             raise ValueError(
                 f"Invalid namespace: {namespace}. "
-                f"Expected {self._OPTUNA_NAMESPACES[2]}"  # Only label
+                f"Expected {QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]}"  # Only label
             )
         return values
 
     def set_optuna_values(
         self, pair: str, namespace: str, values: list[float | int]
     ) -> None:
-        if namespace == self._OPTUNA_NAMESPACES[2]:  # "label"
+        if namespace == QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]:  # "label"
             self._optuna_label_values[pair] = values
         else:
             raise ValueError(
                 f"Invalid namespace: {namespace}. "
-                f"Expected {self._OPTUNA_NAMESPACES[2]}"  # Only label
+                f"Expected {QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]}"  # Only label
             )
 
     def init_optuna_label_candle_pool(self) -> None:
@@ -406,7 +414,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         if self._optuna_hyperopt:
             self.optuna_optimize(
                 pair=dk.pair,
-                namespace=self._OPTUNA_NAMESPACES[0],  # "hp"
+                namespace=QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0],  # "hp"
                 objective=lambda trial: hp_objective(
                     trial,
                     str(self.freqai_info.get("regressor", "xgboost")),
@@ -416,7 +424,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                     X_test,
                     y_test,
                     test_weights,
-                    self.get_optuna_params(dk.pair, self._OPTUNA_NAMESPACES[0]),  # "hp"
+                    self.get_optuna_params(
+                        dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]
+                    ),  # "hp"
                     model_training_parameters,
                     self._optuna_config.get("space_reduction"),
                     self._optuna_config.get("expansion_ratio"),
@@ -425,7 +435,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             )
 
             optuna_hp_params = self.get_optuna_params(
-                dk.pair, self._OPTUNA_NAMESPACES[0]
+                dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]
             )  # "hp"
             if optuna_hp_params:
                 model_training_parameters = {
@@ -435,7 +445,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
 
             train_study = self.optuna_optimize(
                 pair=dk.pair,
-                namespace=self._OPTUNA_NAMESPACES[1],  # "train"
+                namespace=QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1],  # "train"
                 objective=lambda trial: train_objective(
                     trial,
                     str(self.freqai_info.get("regressor", "xgboost")),
@@ -454,18 +464,18 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             )
 
             optuna_hp_value = self.get_optuna_value(
-                dk.pair, self._OPTUNA_NAMESPACES[0]
+                dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0]
             )  # "hp"
             optuna_train_params = self.get_optuna_params(
-                dk.pair, self._OPTUNA_NAMESPACES[1]
+                dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]
             )  # "train"
             optuna_train_value = self.get_optuna_value(
-                dk.pair, self._OPTUNA_NAMESPACES[1]
+                dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1]
             )  # "train"
             if (
                 optuna_train_params
                 and self.optuna_validate_params(
-                    dk.pair, self._OPTUNA_NAMESPACES[1], train_study
+                    dk.pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[1], train_study
                 )  # "train"
                 and optuna_train_value < optuna_hp_value
             ):
@@ -513,10 +523,12 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         namespace: str,
         callback: Callable[[], None],
     ) -> None:
-        if namespace not in {self._OPTUNA_NAMESPACES[2]}:  # Only "label"
+        if namespace not in {
+            QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]
+        }:  # Only "label"
             raise ValueError(
                 f"Invalid namespace: {namespace}. "
-                f"Expected {self._OPTUNA_NAMESPACES[2]}"  # Only label
+                f"Expected {QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]}"  # Only label
             )
         if not callable(callback):
             raise ValueError("callback must be callable")
@@ -554,10 +566,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         if self._optuna_hyperopt:
             self.optuna_throttle_callback(
                 pair=pair,
-                namespace=self._OPTUNA_NAMESPACES[2],  # "label"
+                namespace=QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2],  # "label"
                 callback=lambda: self.optuna_optimize(
                     pair=pair,
-                    namespace=self._OPTUNA_NAMESPACES[2],  # "label"
+                    namespace=QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2],  # "label"
                     objective=lambda trial: label_objective(
                         trial,
                         self.data_provider.get_pair_dataframe(
@@ -610,9 +622,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             min_pred, max_pred = self.min_max_pred(
                 pred_df,
                 fit_live_predictions_candles,
-                self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get(
-                    "label_period_candles"
-                ),  # "label"
+                self.get_optuna_params(
+                    pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]
+                ).get("label_period_candles"),  # "label"
             )
             dk.data["extra_returns_per_train"][MINIMA_THRESHOLD_COLUMN] = min_pred
             dk.data["extra_returns_per_train"][MAXIMA_THRESHOLD_COLUMN] = max_pred
@@ -650,17 +662,17 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         dk.data["extra_returns_per_train"]["DI_cutoff"] = cutoff
 
         dk.data["extra_returns_per_train"]["label_period_candles"] = (
-            self.get_optuna_params(pair, self._OPTUNA_NAMESPACES[2]).get(
-                "label_period_candles"
-            )  # "label"
+            self.get_optuna_params(
+                pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]
+            ).get("label_period_candles")  # "label"
         )
         dk.data["extra_returns_per_train"]["label_natr_ratio"] = self.get_optuna_params(
             pair,
-            self._OPTUNA_NAMESPACES[2],  # "label"
+            QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2],  # "label"
         ).get("label_natr_ratio")
 
         hp_rmse = self.optuna_validate_value(
-            self.get_optuna_value(pair, self._OPTUNA_NAMESPACES[0])
+            self.get_optuna_value(pair, QuickAdapterRegressorV3._OPTUNA_NAMESPACES[0])
         )  # "hp"
         dk.data["extra_returns_per_train"]["hp_rmse"] = (
             hp_rmse if hp_rmse is not None else np.inf
@@ -711,7 +723,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         extrema_selection = str(
             self.freqai_info.get(
                 "prediction_extrema_selection",
-                self._EXTREMA_SELECTION_METHODS[1],
+                QuickAdapterRegressorV3._EXTREMA_SELECTION_METHODS[1],
             )
         )
         if extrema_selection not in self._extrema_selection_methods_set():
@@ -1004,10 +1016,12 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     def get_multi_objective_study_best_trial(
         self, namespace: str, study: optuna.study.Study
     ) -> Optional[optuna.trial.FrozenTrial]:
-        if namespace not in {self._OPTUNA_NAMESPACES[2]}:  # Only "label"
+        if namespace not in {
+            QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]
+        }:  # Only "label"
             raise ValueError(
                 f"Invalid namespace: {namespace}. "
-                f"Expected {self._OPTUNA_NAMESPACES[2]}"  # Only label
+                f"Expected {QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2]}"  # Only label
             )
         n_objectives = len(study.directions)
         if n_objectives < 2:
@@ -1643,7 +1657,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         storage_dir = self.full_path
         storage_filename = f"optuna-{pair.split('/')[0]}"
         storage_backend = self._optuna_config.get("storage")
-        if storage_backend == self._OPTUNA_STORAGE_BACKENDS[0]:  # "sqlite"
+        if (
+            storage_backend == QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS[0]
+        ):  # "sqlite"
             storage = optuna.storages.RDBStorage(
                 url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite",
                 heartbeat_interval=60,
@@ -1651,7 +1667,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                     max_retry=3
                 ),
             )
-        elif storage_backend == self._OPTUNA_STORAGE_BACKENDS[1]:  # "file"
+        elif (
+            storage_backend == QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS[1]
+        ):  # "file"
             storage = optuna.storages.JournalStorage(
                 optuna.storages.journal.JournalFileBackend(
                     f"{storage_dir}/{storage_filename}.log"
@@ -1660,7 +1678,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         else:
             raise ValueError(
                 f"Unsupported optuna storage backend: {storage_backend}. "
-                f"Supported backends are {', '.join(self._OPTUNA_STORAGE_BACKENDS)}"
+                f"Supported backends are {', '.join(QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS)}"
             )
         return storage
 
@@ -1675,12 +1693,14 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             return optuna.pruners.NopPruner()
 
     def optuna_create_sampler(self) -> optuna.samplers.BaseSampler:
-        sampler = self._optuna_config.get("sampler", self._OPTUNA_SAMPLERS[0])
-        if sampler == self._OPTUNA_SAMPLERS[1]:  # "auto"
+        sampler = self._optuna_config.get(
+            "sampler", QuickAdapterRegressorV3._OPTUNA_SAMPLERS[0]
+        )
+        if sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS[1]:  # "auto"
             return optunahub.load_module("samplers/auto_sampler").AutoSampler(
                 seed=self._optuna_config.get("seed")
             )
-        elif sampler == self._OPTUNA_SAMPLERS[0]:  # "tpe"
+        elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS[0]:  # "tpe"
             return optuna.samplers.TPESampler(
                 n_startup_trials=self._optuna_config.get("n_startup_trials"),
                 multivariate=True,
@@ -1690,7 +1710,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         else:
             raise ValueError(
                 f"Unsupported sampler: {sampler}. "
-                f"Supported samplers are {', '.join(self._OPTUNA_SAMPLERS)}"
+                f"Supported samplers are {', '.join(QuickAdapterRegressorV3._OPTUNA_SAMPLERS)}"
             )
 
     def optuna_create_study(
index a7ac5efe306f29f2b08366a1043e629d1d4719de..00e083234846cd8dea2772b94404ac98c6410364 100644 (file)
@@ -665,7 +665,7 @@ class QuickAdapterV3(IStrategy):
         dataframe.loc[
             reduce(lambda x, y: x & y, enter_long_conditions),
             ["enter_long", "enter_tag"],
-        ] = (1, self._TRADE_DIRECTIONS[0])  # "long"
+        ] = (1, QuickAdapterV3._TRADE_DIRECTIONS[0])  # "long"
 
         enter_short_conditions = [
             dataframe.get("do_predict") == 1,
@@ -675,7 +675,7 @@ class QuickAdapterV3(IStrategy):
         dataframe.loc[
             reduce(lambda x, y: x & y, enter_short_conditions),
             ["enter_short", "enter_tag"],
-        ] = (1, self._TRADE_DIRECTIONS[1])  # "short"
+        ] = (1, QuickAdapterV3._TRADE_DIRECTIONS[1])  # "short"
 
         return dataframe
 
@@ -1223,13 +1223,17 @@ class QuickAdapterV3(IStrategy):
         if isna(candle_label_natr_value_quantile):
             return np.nan
 
-        if interpolation_direction == self._INTERPOLATION_DIRECTIONS[0]:  # "direct"
+        if (
+            interpolation_direction == QuickAdapterV3._INTERPOLATION_DIRECTIONS[0]
+        ):  # "direct"
             natr_ratio_percent = (
                 min_natr_ratio_percent
                 + (max_natr_ratio_percent - min_natr_ratio_percent)
                 * candle_label_natr_value_quantile**quantile_exponent
             )
-        elif interpolation_direction == self._INTERPOLATION_DIRECTIONS[1]:  # "inverse"
+        elif (
+            interpolation_direction == QuickAdapterV3._INTERPOLATION_DIRECTIONS[1]
+        ):  # "inverse"
             natr_ratio_percent = (
                 max_natr_ratio_percent
                 - (max_natr_ratio_percent - min_natr_ratio_percent)
@@ -1238,7 +1242,7 @@ class QuickAdapterV3(IStrategy):
         else:
             raise ValueError(
                 f"Invalid interpolation_direction: {interpolation_direction}. "
-                f"Expected {', '.join(self._INTERPOLATION_DIRECTIONS)}"
+                f"Expected {', '.join(QuickAdapterV3._INTERPOLATION_DIRECTIONS)}"
             )
         candle_deviation = (
             candle_label_natr_value / 100.0
@@ -1278,7 +1282,9 @@ class QuickAdapterV3(IStrategy):
             min_natr_ratio_percent=min_natr_ratio_percent,
             max_natr_ratio_percent=max_natr_ratio_percent,
             candle_idx=candle_idx,
-            interpolation_direction=self._INTERPOLATION_DIRECTIONS[0],  # "direct"
+            interpolation_direction=QuickAdapterV3._INTERPOLATION_DIRECTIONS[
+                0
+            ],  # "direct"
         )
         if isna(current_deviation) or current_deviation <= 0:
             return np.nan
@@ -1293,14 +1299,14 @@ class QuickAdapterV3(IStrategy):
         is_candle_bullish: bool = candle_close > candle_open
         is_candle_bearish: bool = candle_close < candle_open
 
-        if side == self._TRADE_DIRECTIONS[0]:  # "long"
+        if side == QuickAdapterV3._TRADE_DIRECTIONS[0]:  # "long"
             base_price = (
                 QuickAdapterV3.weighted_close(candle)
                 if is_candle_bearish
                 else candle_close
             )
             candle_threshold = base_price * (1 + current_deviation)
-        elif side == self._TRADE_DIRECTIONS[1]:  # "short"
+        elif side == QuickAdapterV3._TRADE_DIRECTIONS[1]:  # "short"
             base_price = (
                 QuickAdapterV3.weighted_close(candle)
                 if is_candle_bullish
@@ -1309,7 +1315,7 @@ class QuickAdapterV3(IStrategy):
             candle_threshold = base_price * (1 - current_deviation)
         else:
             raise ValueError(
-                f"Invalid side: {side}. Expected {', '.join(self._TRADE_DIRECTIONS)}"
+                f"Invalid side: {side}. Expected {', '.join(QuickAdapterV3._TRADE_DIRECTIONS)}"
             )
         self._candle_threshold_cache[cache_key] = candle_threshold
         return self._candle_threshold_cache[cache_key]
@@ -1424,16 +1430,18 @@ class QuickAdapterV3(IStrategy):
             candle_idx=-1,
         )
         current_ok = np.isfinite(current_threshold) and (
-            (side == self._TRADE_DIRECTIONS[0] and rate > current_threshold)  # "long"
+            (
+                side == QuickAdapterV3._TRADE_DIRECTIONS[0] and rate > current_threshold
+            )  # "long"
             or (
-                side == self._TRADE_DIRECTIONS[1] and rate < current_threshold
+                side == QuickAdapterV3._TRADE_DIRECTIONS[1] and rate < current_threshold
             )  # "short"
         )
-        if order == self._ORDER_TYPES[1]:  # "exit"
-            if side == self._TRADE_DIRECTIONS[0]:  # "long"
-                trade_direction = self._TRADE_DIRECTIONS[1]  # "short"
-            if side == self._TRADE_DIRECTIONS[1]:  # "short"
-                trade_direction = self._TRADE_DIRECTIONS[0]  # "long"
+        if order == QuickAdapterV3._ORDER_TYPES[1]:  # "exit"
+            if side == QuickAdapterV3._TRADE_DIRECTIONS[0]:  # "long"
+                trade_direction = QuickAdapterV3._TRADE_DIRECTIONS[1]  # "short"
+            if side == QuickAdapterV3._TRADE_DIRECTIONS[1]:  # "short"
+                trade_direction = QuickAdapterV3._TRADE_DIRECTIONS[0]  # "long"
         if not current_ok:
             logger.info(
                 f"User denied {trade_direction} {order} for {pair}: rate {format_number(rate)} did not break threshold {format_number(current_threshold)}"
@@ -1471,9 +1479,10 @@ class QuickAdapterV3(IStrategy):
                 return current_ok
 
             if (
-                side == self._TRADE_DIRECTIONS[0] and not (close_k > threshold_k)
+                side == QuickAdapterV3._TRADE_DIRECTIONS[0]
+                and not (close_k > threshold_k)
             ) or (  # "long"
-                side == self._TRADE_DIRECTIONS[1]
+                side == QuickAdapterV3._TRADE_DIRECTIONS[1]
                 and not (close_k < threshold_k)  # "short"
             ):
                 logger.info(
@@ -1702,15 +1711,15 @@ class QuickAdapterV3(IStrategy):
                 trade.set_custom_data("last_outlier_date", last_candle_date.isoformat())
 
         if (
-            trade.trade_direction == self._TRADE_DIRECTIONS[1]  # "short"
+            trade.trade_direction == QuickAdapterV3._TRADE_DIRECTIONS[1]  # "short"
             and last_candle.get("do_predict") == 1
             and last_candle.get("DI_catch") == 1
             and last_candle.get(EXTREMA_COLUMN) < last_candle.get("minima_threshold")
             and self.reversal_confirmed(
                 df,
                 pair,
-                self._TRADE_DIRECTIONS[0],  # "long"
-                self._ORDER_TYPES[1],  # "exit"
+                QuickAdapterV3._TRADE_DIRECTIONS[0],  # "long"
+                QuickAdapterV3._ORDER_TYPES[1],  # "exit"
                 current_rate,
                 self._reversal_lookback_period,
                 self._reversal_decay_ratio,
@@ -1720,15 +1729,15 @@ class QuickAdapterV3(IStrategy):
         ):
             return "minima_detected_short"
         if (
-            trade.trade_direction == self._TRADE_DIRECTIONS[0]  # "long"
+            trade.trade_direction == QuickAdapterV3._TRADE_DIRECTIONS[0]  # "long"
             and last_candle.get("do_predict") == 1
             and last_candle.get("DI_catch") == 1
             and last_candle.get(EXTREMA_COLUMN) > last_candle.get("maxima_threshold")
             and self.reversal_confirmed(
                 df,
                 pair,
-                self._TRADE_DIRECTIONS[1],  # "short"
-                self._ORDER_TYPES[1],  # "exit"
+                QuickAdapterV3._TRADE_DIRECTIONS[1],  # "short"
+                QuickAdapterV3._ORDER_TYPES[1],  # "exit"
                 current_rate,
                 self._reversal_lookback_period,
                 self._reversal_decay_ratio,
@@ -1839,7 +1848,9 @@ class QuickAdapterV3(IStrategy):
     ) -> bool:
         if side not in self._trade_directions_set():
             return False
-        if side == self._TRADE_DIRECTIONS[1] and not self.can_short:  # "short"
+        if (
+            side == QuickAdapterV3._TRADE_DIRECTIONS[1] and not self.can_short
+        ):  # "short"
             logger.info(f"User denied short entry for {pair}: shorting not allowed")
             return False
         if Trade.get_open_trade_count() >= self.config.get("max_open_trades"):
@@ -1863,7 +1874,7 @@ class QuickAdapterV3(IStrategy):
             df,
             pair,
             side,
-            self._ORDER_TYPES[0],  # "entry"
+            QuickAdapterV3._ORDER_TYPES[0],  # "entry"
             rate,
             self._reversal_lookback_period,
             self._reversal_decay_ratio,
@@ -1876,16 +1887,16 @@ class QuickAdapterV3(IStrategy):
     def is_short_allowed(self) -> bool:
         trading_mode = self.config.get("trading_mode")
         if trading_mode in {
-            self._TRADING_MODES[1],
-            self._TRADING_MODES[2],
+            QuickAdapterV3._TRADING_MODES[1],
+            QuickAdapterV3._TRADING_MODES[2],
         }:  # margin, futures
             return True
-        elif trading_mode == self._TRADING_MODES[0]:  # "spot"
+        elif trading_mode == QuickAdapterV3._TRADING_MODES[0]:  # "spot"
             return False
         else:
             raise ValueError(
                 f"Invalid trading_mode: {trading_mode}. "
-                f"Expected {', '.join(self._TRADING_MODES)}"
+                f"Expected {', '.join(QuickAdapterV3._TRADING_MODES)}"
             )
 
     def leverage(