Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor: consolidate some tunables definitions
author     Jérôme Benoit <jerome.benoit@piment-noir.org>
           Wed, 19 Nov 2025 01:17:48 +0000 (02:17 +0100)
committer  Jérôme Benoit <jerome.benoit@piment-noir.org>
           Wed, 19 Nov 2025 01:17:48 +0000 (02:17 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
README.md
ReforceXY/reward_space_analysis/tests/helpers/warnings.py
ReforceXY/user_data/freqaimodels/ReforceXY.py
ReforceXY/user_data/strategies/RLAgentStrategy.py
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index b9d6d397c7bd6fb74d80564fabd7deff957caef0..95e865652a00c4deabfdeaf8e5cac4576719d19b 100644 (file)
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@ docker compose up -d --build
 | freqai.feature_parameters.label_knn_p_order          | `None`            | float                                                                                                                            | p-order for KNN Minkowski metric distance. (optional)                                                                                                                                                      |
 | freqai.feature_parameters.label_knn_n_neighbors      | 5                 | int >= 1                                                                                                                         | Number of neighbors for KNN.                                                                                                                                                                               |
 | _Prediction thresholds_                              |                   |                                                                                                                                  |                                                                                                                                                                                                            |
-| freqai.prediction_extrema_selection                  | `peak_values`     | enum {`peak_values`,`extrema_rank`}                                                                                              | Extrema selection method. `peak_values` uses detected peaks, `extrema_rank` uses ranked extrema values.                                                                                                    |
+| freqai.prediction_extrema_selection                  | `extrema_rank`    | enum {`peak_values`,`extrema_rank`}                                                                                              | Extrema selection method. `peak_values` uses detected peaks, `extrema_rank` uses ranked extrema values.                                                                                                    |
 | freqai.prediction_thresholds_smoothing               | `mean`            | enum {`mean`,`isodata`,`li`,`minimum`,`otsu`,`triangle`,`yen`,`soft_extremum`}                                                   | Thresholding method for prediction thresholds smoothing.                                                                                                                                                   |
 | freqai.prediction_thresholds_alpha                   | 12.0              | float > 0                                                                                                                        | Alpha for `soft_extremum`.                                                                                                                                                                                 |
 | freqai.outlier_threshold                             | 0.999             | float (0,1)                                                                                                                      | Quantile threshold for predictions outlier filtering.                                                                                                                                                      |
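
For reference, a minimal sketch (not part of this commit) of how the prediction-threshold tunables documented above could appear in a freqai configuration, expressed as a Python dict as read via freqai_info.get(...); key names follow the README table, values are the documented defaults after this change:

    freqai_config_fragment = {
        "freqai": {
            # Default switched by this commit from "peak_values" to "extrema_rank".
            "prediction_extrema_selection": "extrema_rank",
            "prediction_thresholds_smoothing": "mean",
            "prediction_thresholds_alpha": 12.0,
            "outlier_threshold": 0.999,
        }
    }
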
index 2fa753d3bfb87dff56791f61f1a0477875cf27f2..908fa0a171e1be3e08bdf5003e38379e70bb2180 100644 (file)
--- a/ReforceXY/reward_space_analysis/tests/helpers/warnings.py
+++ b/ReforceXY/reward_space_analysis/tests/helpers/warnings.py
@@ -21,7 +21,7 @@ from typing import Any, Optional
 try:
     from reward_space_analysis import RewardDiagnosticsWarning
 except ImportError:
-    RewardDiagnosticsWarning = RuntimeWarning  # type: ignore
+    RewardDiagnosticsWarning = RuntimeWarning
 
 
 @contextmanager
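
A minimal usage sketch of the fallback above (assumption: only the import shown is available): code that filters RewardDiagnosticsWarning keeps working when reward_space_analysis is not importable, since the name then simply aliases RuntimeWarning.

    import warnings

    try:
        from reward_space_analysis import RewardDiagnosticsWarning
    except ImportError:
        RewardDiagnosticsWarning = RuntimeWarning  # same fallback as above

    with warnings.catch_warnings():
        # Surface reward diagnostics as errors inside a test block.
        warnings.simplefilter("error", category=RewardDiagnosticsWarning)
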
index 655b9692e771d02d294ded959e00257dfd56c421..37222c97f8772c82844a24df1141bf54d591e15a 100644 (file)
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -8,7 +8,18 @@ import warnings
 from collections import defaultdict, deque
 from collections.abc import Mapping
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -59,6 +70,23 @@ from stable_baselines3.common.vec_env import (
     VecMonitor,
 )
 
+ModelType = Literal["PPO", "RecurrentPPO", "MaskablePPO", "DQN", "QRDQN"]
+ScheduleType = Literal["linear", "constant", "unknown"]
+ExitPotentialMode = Literal[
+    "canonical",
+    "non_canonical",
+    "progressive_release",
+    "spike_cancel",
+    "retain_previous",
+]
+TransformFunction = Literal["tanh", "softsign", "arctan", "sigmoid", "asinh", "clip"]
+ExitAttenuationMode = Literal["legacy", "sqrt", "linear", "power", "half_life"]
+ActivationFunction = Literal["tanh", "relu", "elu", "leaky_relu"]
+OptimizerClass = Literal["adam", "adamw", "rmsprop"]
+NetArchSize = Literal["small", "medium", "large", "extra_large"]
+StorageBackend = Literal["sqlite", "file"]
+SamplerType = Literal["tpe", "auto"]
+
 matplotlib.use("Agg")
 warnings.filterwarnings("ignore", category=UserWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
@@ -117,8 +145,67 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     _LOG_2 = math.log(2.0)
     DEFAULT_IDLE_DURATION_MULTIPLIER: int = 4
+
+    _MODEL_TYPES: tuple[ModelType, ...] = (
+        "PPO",
+        "RecurrentPPO",
+        "MaskablePPO",
+        "DQN",
+        "QRDQN",
+    )
+    _SCHEDULE_TYPES: tuple[ScheduleType, ...] = ("linear", "constant", "unknown")
+    _EXIT_POTENTIAL_MODES: tuple[ExitPotentialMode, ...] = (
+        "canonical",
+        "non_canonical",
+        "progressive_release",
+        "spike_cancel",
+        "retain_previous",
+    )
+    _TRANSFORM_FUNCTIONS: tuple[TransformFunction, ...] = (
+        "tanh",
+        "softsign",
+        "arctan",
+        "sigmoid",
+        "asinh",
+        "clip",
+    )
+    _EXIT_ATTENUATION_MODES: tuple[ExitAttenuationMode, ...] = (
+        "legacy",
+        "sqrt",
+        "linear",
+        "power",
+        "half_life",
+    )
+    _ACTIVATION_FUNCTIONS: tuple[ActivationFunction, ...] = (
+        "tanh",
+        "relu",
+        "elu",
+        "leaky_relu",
+    )
+    _OPTIMIZER_CLASSES: tuple[OptimizerClass, ...] = ("adam", "adamw", "rmsprop")
+    _NET_ARCH_SIZES: tuple[NetArchSize, ...] = (
+        "small",
+        "medium",
+        "large",
+        "extra_large",
+    )
+    _STORAGE_BACKENDS: tuple[StorageBackend, ...] = ("sqlite", "file")
+    _SAMPLER_TYPES: tuple[SamplerType, ...] = ("tpe", "auto")
+
     _action_masks_cache: Dict[Tuple[bool, float], NDArray[np.bool_]] = {}
 
+    @staticmethod
+    def _model_types_set() -> set[ModelType]:
+        return set(ReforceXY._MODEL_TYPES)
+
+    @staticmethod
+    def _exit_potential_modes_set() -> set[ExitPotentialMode]:
+        return set(ReforceXY._EXIT_POTENTIAL_MODES)
+
+    @staticmethod
+    def _transform_functions_set() -> set[TransformFunction]:
+        return set(ReforceXY._TRANSFORM_FUNCTIONS)
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pairs: List[str] = self.config.get("exchange", {}).get("pair_whitelist")
@@ -126,10 +213,12 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise ValueError(
                 "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
             )
-        self.action_masking: bool = self.model_type == "MaskablePPO"
+        self.action_masking: bool = (
+            self.model_type == self._MODEL_TYPES[2]
+        )  # "MaskablePPO"
         self.rl_config.setdefault("action_masking", self.action_masking)
         self.inference_masking: bool = self.rl_config.get("inference_masking", True)
-        self.recurrent: bool = self.model_type == "RecurrentPPO"
+        self.recurrent: bool = self.model_type == self._MODEL_TYPES[1]  # "RecurrentPPO"
         self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
         self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
         self.n_envs: int = self.rl_config.get("n_envs", 1)
@@ -403,19 +492,27 @@ class ReforceXY(BaseReinforcementLearningModel):
             lr = model_params.get("learning_rate", 0.0003)
             if isinstance(lr, (int, float)):
                 lr = float(lr)
-                model_params["learning_rate"] = get_schedule("linear", lr)
+                model_params["learning_rate"] = get_schedule(
+                    self._SCHEDULE_TYPES[0], lr
+                )
                 logger.info(
                     "Learning rate linear schedule enabled, initial value: %s", lr
                 )
 
-        if not self.hyperopt and "PPO" in self.model_type and self.cr_schedule:
+        # "PPO"
+        if (
+            not self.hyperopt
+            and self._MODEL_TYPES[0] in self.model_type
+            and self.cr_schedule
+        ):
             cr = model_params.get("clip_range", 0.2)
             if isinstance(cr, (int, float)):
                 cr = float(cr)
-                model_params["clip_range"] = get_schedule("linear", cr)
+                model_params["clip_range"] = get_schedule(self._SCHEDULE_TYPES[0], cr)
                 logger.info("Clip range linear schedule enabled, initial value: %s", cr)
 
-        if "DQN" in self.model_type:
+        # "DQN"
+        if self._MODEL_TYPES[3] in self.model_type:
             if model_params.get("gradient_steps") is None:
                 model_params["gradient_steps"] = compute_gradient_steps(
                     model_params.get("train_freq"), model_params.get("subsample_steps")
@@ -433,11 +530,20 @@ class ReforceXY(BaseReinforcementLearningModel):
             Literal["small", "medium", "large", "extra_large"],
         ] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch)
 
-        if "PPO" in self.model_type:
+        # "PPO"
+        if self._MODEL_TYPES[0] in self.model_type:
             if isinstance(net_arch, str):
-                model_params["policy_kwargs"]["net_arch"] = get_net_arch(
-                    self.model_type, net_arch
-                )
+                if net_arch in self._NET_ARCH_SIZES:
+                    model_params["policy_kwargs"]["net_arch"] = get_net_arch(
+                        self.model_type,
+                        cast(NetArchSize, net_arch),
+                    )
+                else:
+                    logger.warning("Invalid net_arch=%s, using default", net_arch)
+                    model_params["policy_kwargs"]["net_arch"] = {
+                        "pi": default_net_arch,
+                        "vf": default_net_arch,
+                    }
             elif isinstance(net_arch, list):
                 model_params["policy_kwargs"]["net_arch"] = {
                     "pi": net_arch,
@@ -465,9 +571,14 @@ class ReforceXY(BaseReinforcementLearningModel):
                 }
         else:
             if isinstance(net_arch, str):
-                model_params["policy_kwargs"]["net_arch"] = get_net_arch(
-                    self.model_type, net_arch
-                )
+                if net_arch in self._NET_ARCH_SIZES:
+                    model_params["policy_kwargs"]["net_arch"] = get_net_arch(
+                        self.model_type,
+                        cast(NetArchSize, net_arch),
+                    )
+                else:
+                    logger.warning("Invalid net_arch=%s, using default", net_arch)
+                    model_params["policy_kwargs"]["net_arch"] = default_net_arch
             elif isinstance(net_arch, list):
                 model_params["policy_kwargs"]["net_arch"] = net_arch
             else:
@@ -477,10 +588,14 @@ class ReforceXY(BaseReinforcementLearningModel):
                 model_params["policy_kwargs"]["net_arch"] = default_net_arch
 
         model_params["policy_kwargs"]["activation_fn"] = get_activation_fn(
-            model_params.get("policy_kwargs", {}).get("activation_fn", "relu")
+            model_params.get("policy_kwargs", {}).get(
+                "activation_fn", self._ACTIVATION_FUNCTIONS[1]
+            )  # "relu"
         )
         model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
-            model_params.get("policy_kwargs", {}).get("optimizer_class", "adamw")
+            model_params.get("policy_kwargs", {}).get(
+                "optimizer_class", self._OPTIMIZER_CLASSES[1]
+            )  # "adamw"
         )
 
         self._model_params_cache = model_params
@@ -517,7 +632,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         """
         if total_timesteps <= 0:
             return 1
-        if "PPO" in self.model_type:
+        # "PPO"
+        if self._MODEL_TYPES[0] in self.model_type:
             eval_freq: Optional[int] = None
             if model_params:
                 n_steps = model_params.get("n_steps")
@@ -668,7 +784,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             model_params = self.get_model_params()
         logger.info("%s params: %s", self.model_type, model_params)
 
-        if "PPO" in self.model_type:
+        # "PPO"
+        if self._MODEL_TYPES[0] in self.model_type:
             n_steps = model_params.get("n_steps", 0)
             min_timesteps = 2 * n_steps * self.n_envs
             if total_timesteps <= min_timesteps:
@@ -917,20 +1034,24 @@ class ReforceXY(BaseReinforcementLearningModel):
         """
         storage_dir = self.full_path
         storage_filename = f"optuna-{pair.split('/')[0]}"
-        storage_backend = self.rl_config_optuna.get("storage", "sqlite")
-        if storage_backend == "sqlite":
+        storage_backend: StorageBackend = self.rl_config_optuna.get(
+            "storage", self._STORAGE_BACKENDS[0]
+        )  # "sqlite"
+        # "sqlite"
+        if storage_backend == self._STORAGE_BACKENDS[0]:
             storage = RDBStorage(
                 url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite",
                 heartbeat_interval=60,
                 failed_trial_callback=RetryFailedTrialCallback(max_retry=3),
             )
-        elif storage_backend == "file":
+        # "file"
+        elif storage_backend == self._STORAGE_BACKENDS[1]:
             storage = JournalStorage(
                 JournalFileBackend(f"{storage_dir}/{storage_filename}.log")
             )
         else:
             raise ValueError(
-                f"Unsupported storage backend: {storage_backend}. Supported backends are: 'sqlite' and 'file'"
+                f"Unsupported storage backend: {storage_backend}. Supported backends are: {', '.join(self._STORAGE_BACKENDS)}"
             )
         return storage
 
@@ -945,12 +1066,16 @@ class ReforceXY(BaseReinforcementLearningModel):
             return False
 
     def create_sampler(self) -> BaseSampler:
-        sampler = self.rl_config_optuna.get("sampler", "tpe")
-        if sampler == "auto":
+        sampler: SamplerType = self.rl_config_optuna.get(
+            "sampler", self._SAMPLER_TYPES[0]
+        )  # "tpe"
+        # "auto"
+        if sampler == self._SAMPLER_TYPES[1]:
             return optunahub.load_module("samplers/auto_sampler").AutoSampler(
                 seed=self.rl_config_optuna.get("seed", 42)
             )
-        elif sampler == "tpe":
+        # "tpe"
+        elif sampler == self._SAMPLER_TYPES[0]:
             return TPESampler(
                 n_startup_trials=self.optuna_n_startup_trials,
                 multivariate=True,
@@ -959,7 +1084,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
         else:
             raise ValueError(
-                f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'"
+                f"Unsupported sampler: {sampler!r}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
             )
 
     @staticmethod
@@ -984,7 +1109,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         continuous = self.rl_config_optuna.get("continuous", False)
         if continuous:
             ReforceXY.delete_study(study_name, storage)
-        if "PPO" in self.model_type:
+        # "PPO"
+        if self._MODEL_TYPES[0] in self.model_type:
             resource_eval_freq = min(PPO_N_STEPS)
         else:
             resource_eval_freq = self.get_eval_freq(total_timesteps, hyperopt=True)
@@ -1190,13 +1316,17 @@ class ReforceXY(BaseReinforcementLearningModel):
         return train_env, eval_env
 
     def get_optuna_params(self, trial: Trial) -> Dict[str, Any]:
-        if "RecurrentPPO" in self.model_type:
+        # "RecurrentPPO"
+        if self._MODEL_TYPES[1] in self.model_type:
             return sample_params_recurrentppo(trial)
-        elif "PPO" in self.model_type:
+        # "PPO"
+        elif self._MODEL_TYPES[0] in self.model_type:
             return sample_params_ppo(trial)
-        elif "QRDQN" in self.model_type:
+        # "QRDQN"
+        elif self._MODEL_TYPES[4] in self.model_type:
             return sample_params_qrdqn(trial)
-        elif "DQN" in self.model_type:
+        # "DQN"
+        elif self._MODEL_TYPES[3] in self.model_type:
             return sample_params_dqn(trial)
         else:
             raise NotImplementedError(f"{self.model_type} not supported for hyperopt")
@@ -1211,7 +1341,8 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         params = self.get_optuna_params(trial)
 
-        if "PPO" in self.model_type:
+        # "PPO"
+        if self._MODEL_TYPES[0] in self.model_type:
             n_steps = params.get("n_steps")
             if n_steps * self.n_envs > total_timesteps:
                 raise TrialPruned(
@@ -1223,7 +1354,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                     f"{n_steps=} * {self.n_envs=} = {n_steps * self.n_envs} is not divisible by {batch_size=}"
                 )
 
-        if "DQN" in self.model_type:
+        # "DQN"
+        if self._MODEL_TYPES[3] in self.model_type:
             gradient_steps = params.get("gradient_steps")
             if isinstance(gradient_steps, int) and gradient_steps <= 0:
                 raise TrialPruned(f"{gradient_steps=} is negative or zero")
@@ -1242,7 +1374,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         params["seed"] = params.get("seed", 42) + trial.number
         logger.info("Trial %s params: %s", trial.number, params)
 
-        if "PPO" in self.model_type:
+        # "PPO"
+
+        if self._MODEL_TYPES[0] in self.model_type:
             n_steps = params.get("n_steps", 0)
             if n_steps > 0:
                 rollout = n_steps * self.n_envs
@@ -1432,21 +1566,21 @@ class MyRLEnv(Base5ActionRLEnv):
         #   'spike_cancel'        -> Φ(s')=Φ(s)/γ (Δ ≈ 0, cancels shaping)
         #   'retain_previous'     -> Φ(s')=Φ(s)
         self._exit_potential_mode = str(
-            model_reward_parameters.get("exit_potential_mode", "canonical")
-        )
-        _allowed_exit_modes = {
-            "canonical",
-            "non_canonical",
-            "progressive_release",
-            "spike_cancel",
-            "retain_previous",
-        }
+            model_reward_parameters.get(
+                "exit_potential_mode", ReforceXY._EXIT_POTENTIAL_MODES[0]
+            )  # "canonical"
+        )
+        _allowed_exit_modes = set(ReforceXY._EXIT_POTENTIAL_MODES)
         if self._exit_potential_mode not in _allowed_exit_modes:
             logger.warning(
-                "Unknown exit_potential_mode '%s'; defaulting to 'canonical'",
+                "Unknown exit_potential_mode %r; defaulting to %r. Valid modes: %s",
                 self._exit_potential_mode,
+                ReforceXY._EXIT_POTENTIAL_MODES[0],
+                ", ".join(ReforceXY._EXIT_POTENTIAL_MODES),
             )
-            self._exit_potential_mode = "canonical"
+            self._exit_potential_mode = ReforceXY._EXIT_POTENTIAL_MODES[
+                0
+            ]  # "canonical"
         self._exit_potential_decay: float = float(
             model_reward_parameters.get("exit_potential_decay", 0.5)
         )
@@ -1460,11 +1594,17 @@ class MyRLEnv(Base5ActionRLEnv):
         self._entry_additive_gain: float = float(
             model_reward_parameters.get("entry_additive_gain", 1.0)
         )
-        self._entry_additive_transform_pnl: str = str(
-            model_reward_parameters.get("entry_additive_transform_pnl", "tanh")
+        self._entry_additive_transform_pnl: TransformFunction = cast(
+            TransformFunction,
+            model_reward_parameters.get(
+                "entry_additive_transform_pnl", ReforceXY._TRANSFORM_FUNCTIONS[0]
+            ),  # "tanh"
         )
-        self._entry_additive_transform_duration: str = str(
-            model_reward_parameters.get("entry_additive_transform_duration", "tanh")
+        self._entry_additive_transform_duration: TransformFunction = cast(
+            TransformFunction,
+            model_reward_parameters.get(
+                "entry_additive_transform_duration", ReforceXY._TRANSFORM_FUNCTIONS[0]
+            ),  # "tanh"
         )
         # === HOLD POTENTIAL (PBRS function Φ) ===
         self._hold_potential_enabled: bool = bool(
@@ -1476,11 +1616,17 @@ class MyRLEnv(Base5ActionRLEnv):
         self._hold_potential_gain: float = float(
             model_reward_parameters.get("hold_potential_gain", 1.0)
         )
-        self._hold_potential_transform_pnl: str = str(
-            model_reward_parameters.get("hold_potential_transform_pnl", "tanh")
+        self._hold_potential_transform_pnl: TransformFunction = cast(
+            TransformFunction,
+            model_reward_parameters.get(
+                "hold_potential_transform_pnl", ReforceXY._TRANSFORM_FUNCTIONS[0]
+            ),  # "tanh"
         )
-        self._hold_potential_transform_duration: str = str(
-            model_reward_parameters.get("hold_potential_transform_duration", "tanh")
+        self._hold_potential_transform_duration: TransformFunction = cast(
+            TransformFunction,
+            model_reward_parameters.get(
+                "hold_potential_transform_duration", ReforceXY._TRANSFORM_FUNCTIONS[0]
+            ),  # "tanh"
         )
         # === EXIT ADDITIVE (non-PBRS additive term) ===
         self._exit_additive_enabled: bool = bool(
@@ -1492,22 +1638,30 @@ class MyRLEnv(Base5ActionRLEnv):
         self._exit_additive_gain: float = float(
             model_reward_parameters.get("exit_additive_gain", 1.0)
         )
-        self._exit_additive_transform_pnl: str = str(
-            model_reward_parameters.get("exit_additive_transform_pnl", "tanh")
+        self._exit_additive_transform_pnl: TransformFunction = cast(
+            TransformFunction,
+            model_reward_parameters.get(
+                "exit_additive_transform_pnl", ReforceXY._TRANSFORM_FUNCTIONS[0]
+            ),  # "tanh"
         )
-        self._exit_additive_transform_duration: str = str(
-            model_reward_parameters.get("exit_additive_transform_duration", "tanh")
+        self._exit_additive_transform_duration: TransformFunction = cast(
+            TransformFunction,
+            model_reward_parameters.get(
+                "exit_additive_transform_duration", ReforceXY._TRANSFORM_FUNCTIONS[0]
+            ),  # "tanh"
         )
         # === PBRS INVARIANCE CHECKS ===
-        if self._exit_potential_mode == "canonical":
+        # "canonical"
+        if self._exit_potential_mode == ReforceXY._EXIT_POTENTIAL_MODES[0]:
             if self._entry_additive_enabled or self._exit_additive_enabled:
                 logger.info(
                     "PBRS canonical mode: additive rewards disabled with Φ(terminal)=0. PBRS invariance is preserved. "
-                    "To use additive rewards, set exit_potential_mode='non_canonical'."
+                    f"To use additive rewards, set exit_potential_mode={ReforceXY._EXIT_POTENTIAL_MODES[1]!r}."
                 )
                 self._entry_additive_enabled = False
                 self._exit_additive_enabled = False
-        elif self._exit_potential_mode == "non_canonical":
+        # "non_canonical"
+        elif self._exit_potential_mode == ReforceXY._EXIT_POTENTIAL_MODES[1]:
             if self._entry_additive_enabled or self._exit_additive_enabled:
                 logger.info(
                     "PBRS non-canonical mode: additive rewards enabled with Φ(terminal)=0. PBRS invariance is intentionally broken."
@@ -1593,8 +1747,8 @@ class MyRLEnv(Base5ActionRLEnv):
         duration_ratio: float,
         scale: float,
         gain: float,
-        transform_pnl: str,
-        transform_duration: str,
+        transform_pnl: TransformFunction,
+        transform_duration: TransformFunction,
     ) -> float:
         """Generic bounded bi-component signal combining PnL and duration.
 
@@ -1621,9 +1775,9 @@ class MyRLEnv(Base5ActionRLEnv):
             Output scaling factor
         gain : float
             Gain multiplier before transform
-        transform_pnl : str
+        transform_pnl : TransformFunction
             Transform name for PnL component
-        transform_duration : str
+        transform_duration : TransformFunction
             Transform name for duration component
 
         Returns
@@ -1720,7 +1874,7 @@ class MyRLEnv(Base5ActionRLEnv):
             transform_duration=self._entry_additive_transform_duration,
         )
 
-    def _potential_transform(self, name: str, x: float) -> float:
+    def _potential_transform(self, name: TransformFunction, x: float) -> float:
         """Apply bounded transform function for potential and additive computations.
 
         Provides numerical stability by mapping unbounded inputs to bounded outputs
@@ -1730,8 +1884,7 @@ class MyRLEnv(Base5ActionRLEnv):
         Parameters
         ----------
         name : str
-            Transform function name: 'tanh', 'softsign', 'arctan', 'sigmoid',
-            'asinh', or 'clip'
+            Transform function name: one of ReforceXY._TRANSFORM_FUNCTIONS
         x : float
             Input value to transform
 
@@ -1740,17 +1893,17 @@ class MyRLEnv(Base5ActionRLEnv):
         float
             Bounded output in [-1, 1]
         """
-        if name == "tanh":
+        if name == ReforceXY._TRANSFORM_FUNCTIONS[0]:  # "tanh"
             return math.tanh(x)
 
-        if name == "softsign":
+        if name == ReforceXY._TRANSFORM_FUNCTIONS[1]:  # "softsign"
             ax = abs(x)
             return x / (1.0 + ax)
 
-        if name == "arctan":
+        if name == ReforceXY._TRANSFORM_FUNCTIONS[2]:  # "arctan"
             return (2.0 / math.pi) * math.atan(x)
 
-        if name == "sigmoid":
+        if name == ReforceXY._TRANSFORM_FUNCTIONS[3]:  # "sigmoid"
             try:
                 if x >= 0:
                     exp_neg_x = math.exp(-x)
@@ -1762,13 +1915,17 @@ class MyRLEnv(Base5ActionRLEnv):
             except OverflowError:
                 return 1.0 if x > 0 else -1.0
 
-        if name == "asinh":
+        if name == ReforceXY._TRANSFORM_FUNCTIONS[4]:  # "asinh"
             return x / math.hypot(1.0, x)
 
-        if name == "clip":
+        if name == ReforceXY._TRANSFORM_FUNCTIONS[5]:  # "clip"
             return max(-1.0, min(1.0, x))
 
-        logger.warning("Unknown potential transform '%s'; falling back to tanh", name)
+        logger.warning(
+            "Unknown potential transform '%s'; falling back to tanh. Valid transforms: %s",
+            name,
+            ", ".join(ReforceXY._TRANSFORM_FUNCTIONS),
+        )
         return math.tanh(x)
 
     def _compute_exit_potential(self, prev_potential: float, gamma: float) -> float:
@@ -1777,21 +1934,28 @@ class MyRLEnv(Base5ActionRLEnv):
         See ``_apply_potential_shaping`` for complete PBRS documentation.
         """
         mode = self._exit_potential_mode
-        if mode == "canonical" or mode == "non_canonical":
+        # "canonical" or "non_canonical"
+        if (
+            mode == ReforceXY._EXIT_POTENTIAL_MODES[0]
+            or mode == ReforceXY._EXIT_POTENTIAL_MODES[1]
+        ):
             return 0.0
-        if mode == "progressive_release":
+        # "progressive_release"
+        if mode == ReforceXY._EXIT_POTENTIAL_MODES[2]:
             decay = self._exit_potential_decay
             if not np.isfinite(decay) or decay < 0.0:
                 decay = 0.0
             if decay > 1.0:
                 decay = 1.0
             next_potential = prev_potential * (1.0 - decay)
-        elif mode == "spike_cancel":
+        # "spike_cancel"
+        elif mode == ReforceXY._EXIT_POTENTIAL_MODES[3]:
             if gamma <= 0.0 or not np.isfinite(gamma):
                 next_potential = prev_potential
             else:
                 next_potential = prev_potential / gamma
-        elif mode == "retain_previous":
+        # "retain_previous"
+        elif mode == ReforceXY._EXIT_POTENTIAL_MODES[4]:
             next_potential = prev_potential
         else:
             next_potential = 0.0
@@ -1814,7 +1978,8 @@ class MyRLEnv(Base5ActionRLEnv):
         bool
             True if configuration preserves theoretical PBRS invariance
         """
-        return self._exit_potential_mode == "canonical" and not (
+        # "canonical"
+        return self._exit_potential_mode == ReforceXY._EXIT_POTENTIAL_MODES[0] and not (
             self._entry_additive_enabled or self._exit_additive_enabled
         )
 
@@ -2028,8 +2193,11 @@ class MyRLEnv(Base5ActionRLEnv):
             return base_reward + reward_shaping
         elif is_exit:
             if (
-                self._exit_potential_mode == "canonical"
-                or self._exit_potential_mode == "non_canonical"
+                self._exit_potential_mode
+                == ReforceXY._EXIT_POTENTIAL_MODES[0]  # "canonical"
+            ) or (
+                self._exit_potential_mode
+                == ReforceXY._EXIT_POTENTIAL_MODES[1]  # "non_canonical"
             ):
                 next_potential = 0.0
                 exit_reward_shaping = -prev_potential
@@ -2137,7 +2305,9 @@ class MyRLEnv(Base5ActionRLEnv):
 
         model_reward_parameters = self.rl_config.get("model_reward_parameters", {})
         exit_attenuation_mode = str(
-            model_reward_parameters.get("exit_attenuation_mode", "linear")
+            model_reward_parameters.get(
+                "exit_attenuation_mode", ReforceXY._EXIT_ATTENUATION_MODES[2]
+            )  # "linear"
         )
         exit_plateau = bool(model_reward_parameters.get("exit_plateau", True))
         exit_plateau_grace = float(
@@ -2180,11 +2350,11 @@ class MyRLEnv(Base5ActionRLEnv):
             return f * math.pow(2.0, -dr / hl)
 
         strategies: Dict[str, Callable[[float, float, Mapping], float]] = {
-            "legacy": _legacy,
-            "sqrt": _sqrt,
-            "linear": _linear,
-            "power": _power,
-            "half_life": _half_life,
+            ReforceXY._EXIT_ATTENUATION_MODES[0]: _legacy,
+            ReforceXY._EXIT_ATTENUATION_MODES[1]: _sqrt,
+            ReforceXY._EXIT_ATTENUATION_MODES[2]: _linear,
+            ReforceXY._EXIT_ATTENUATION_MODES[3]: _power,
+            ReforceXY._EXIT_ATTENUATION_MODES[4]: _half_life,
         }
 
         if exit_plateau:
@@ -2198,8 +2368,9 @@ class MyRLEnv(Base5ActionRLEnv):
         strategy_fn = strategies.get(exit_attenuation_mode, None)
         if strategy_fn is None:
             logger.debug(
-                "Unknown exit_attenuation_mode '%s'; defaulting to linear",
+                "Unknown exit_attenuation_mode '%s'; defaulting to linear. Valid modes: %s",
                 exit_attenuation_mode,
+                ", ".join(ReforceXY._EXIT_ATTENUATION_MODES),
             )
             strategy_fn = _linear
 
@@ -3299,9 +3470,9 @@ class InfoMetricsCallback(TensorboardCallback):
         def _eval_schedule(schedule: Any) -> float | None:
             schedule_type, _, _ = get_schedule_type(schedule)
             try:
-                if schedule_type == "linear":
+                if schedule_type == ReforceXY._SCHEDULE_TYPES[0]:  # "linear"
                     return float(schedule(progress_remaining))
-                if schedule_type == "constant":
+                if schedule_type == ReforceXY._SCHEDULE_TYPES[1]:  # "constant"
                     if callable(schedule):
                         return float(schedule(0.0))
                     if isinstance(schedule, (int, float)):
@@ -3588,9 +3759,9 @@ def get_schedule(
     schedule_type: Literal["linear", "constant"],
     initial_value: float,
 ) -> Callable[[float], float]:
-    if schedule_type == "linear":
+    if schedule_type == ReforceXY._SCHEDULE_TYPES[0]:  # "linear"
         return SimpleLinearSchedule(initial_value)
-    elif schedule_type == "constant":
+    elif schedule_type == ReforceXY._SCHEDULE_TYPES[1]:  # "constant"
         return ConstantSchedule(initial_value)
     else:
         return ConstantSchedule(initial_value)
@@ -3652,8 +3823,12 @@ def convert_optuna_params_to_model_params(
 
     lr = optuna_params.get("learning_rate")
     if lr is None:
-        raise ValueError(f"missing 'learning_rate' in optuna params for {model_type}")
-    lr = get_schedule(optuna_params.get("lr_schedule", "constant"), float(lr))
+        raise ValueError(
+            f"missing {'learning_rate'!r} in optuna params for {model_type}"
+        )
+    lr = get_schedule(
+        optuna_params.get("lr_schedule", ReforceXY._SCHEDULE_TYPES[1]), float(lr)
+    )  # default: "constant"
 
     if "PPO" in model_type:
         required_ppo_params = [
@@ -3671,7 +3846,10 @@ def convert_optuna_params_to_model_params(
             if optuna_params.get(param) is None:
                 raise ValueError(f"missing '{param}' in optuna params for {model_type}")
         cr = optuna_params.get("clip_range")
-        cr = get_schedule(optuna_params.get("cr_schedule", "constant"), float(cr))
+        cr = get_schedule(
+            optuna_params.get("cr_schedule", ReforceXY._SCHEDULE_TYPES[1]),
+            float(cr),
+        )  # default: "constant"
 
         model_params.update(
             {
@@ -3743,17 +3921,24 @@ def convert_optuna_params_to_model_params(
         raise ValueError(f"Model {model_type} not supported")
 
     if optuna_params.get("net_arch"):
-        policy_kwargs["net_arch"] = get_net_arch(
-            model_type, str(optuna_params["net_arch"])
-        )
+        net_arch_value = str(optuna_params["net_arch"])
+        if net_arch_value in ReforceXY._NET_ARCH_SIZES:
+            policy_kwargs["net_arch"] = get_net_arch(
+                model_type,
+                cast(NetArchSize, net_arch_value),
+            )
     if optuna_params.get("activation_fn"):
-        policy_kwargs["activation_fn"] = get_activation_fn(
-            str(optuna_params["activation_fn"])
-        )
+        activation_fn_value = str(optuna_params["activation_fn"])
+        if activation_fn_value in ReforceXY._ACTIVATION_FUNCTIONS:
+            policy_kwargs["activation_fn"] = get_activation_fn(
+                cast(ActivationFunction, activation_fn_value)
+            )
     if optuna_params.get("optimizer_class"):
-        policy_kwargs["optimizer_class"] = get_optimizer_class(
-            str(optuna_params["optimizer_class"])
-        )
+        optimizer_value = str(optuna_params["optimizer_class"])
+        if optimizer_value in ReforceXY._OPTIMIZER_CLASSES:
+            policy_kwargs["optimizer_class"] = get_optimizer_class(
+                cast(OptimizerClass, optimizer_value)
+            )
     if optuna_params.get("ortho_init") is not None:
         policy_kwargs["ortho_init"] = bool(optuna_params["ortho_init"])
 
@@ -3780,20 +3965,25 @@ def get_common_ppo_optuna_params(trial: Trial) -> Dict[str, Any]:
         "gae_lambda": trial.suggest_float("gae_lambda", 0.9, 0.99, step=0.01),
         "max_grad_norm": trial.suggest_float("max_grad_norm", 0.3, 1.0, step=0.05),
         "vf_coef": trial.suggest_float("vf_coef", 0.0, 1.0, step=0.05),
-        "lr_schedule": trial.suggest_categorical("lr_schedule", ["linear", "constant"]),
-        "cr_schedule": trial.suggest_categorical("cr_schedule", ["linear", "constant"]),
+        "lr_schedule": trial.suggest_categorical(
+            "lr_schedule", list(ReforceXY._SCHEDULE_TYPES)
+        ),
+        "cr_schedule": trial.suggest_categorical(
+            "cr_schedule", list(ReforceXY._SCHEDULE_TYPES)
+        ),
         "target_kl": trial.suggest_categorical(
             "target_kl", [None, 0.01, 0.015, 0.02, 0.03, 0.04]
         ),
         "ortho_init": trial.suggest_categorical("ortho_init", [True, False]),
         "net_arch": trial.suggest_categorical(
-            "net_arch", ["small", "medium", "large", "extra_large"]
+            "net_arch", list(ReforceXY._NET_ARCH_SIZES)
         ),
         "activation_fn": trial.suggest_categorical(
-            "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
+            "activation_fn", list(ReforceXY._ACTIVATION_FUNCTIONS)
         ),
         "optimizer_class": trial.suggest_categorical(
-            "optimizer_class", ["adamw", "rmsprop"]
+            "optimizer_class",
+            [ReforceXY._OPTIMIZER_CLASSES[1], ReforceXY._OPTIMIZER_CLASSES[2]],
         ),
     }
 
@@ -3864,13 +4054,14 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
             "learning_starts", [500, 1000, 2000, 3000, 4000, 5000, 8000, 10000]
         ),
         "net_arch": trial.suggest_categorical(
-            "net_arch", ["small", "medium", "large", "extra_large"]
+            "net_arch", list(ReforceXY._NET_ARCH_SIZES)
         ),
         "activation_fn": trial.suggest_categorical(
-            "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
+            "activation_fn", list(ReforceXY._ACTIVATION_FUNCTIONS)
         ),
         "optimizer_class": trial.suggest_categorical(
-            "optimizer_class", ["adamw", "rmsprop"]
+            "optimizer_class",
+            [ReforceXY._OPTIMIZER_CLASSES[1], ReforceXY._OPTIMIZER_CLASSES[2]],
         ),
     }
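
To summarize the pattern this commit applies throughout ReforceXY.py, a minimal sketch (illustrative names, not a verbatim excerpt): one Literal alias per tunable plus a matching tuple that acts as the single source of truth for defaults, validation, and Optuna choices.

    from typing import Literal, cast

    OptimizerClass = Literal["adam", "adamw", "rmsprop"]
    _OPTIMIZER_CLASSES: tuple[OptimizerClass, ...] = ("adam", "adamw", "rmsprop")

    def resolve_optimizer_class(raw: str) -> OptimizerClass:
        """Validate a raw config value against the allowed tunable values."""
        if raw in _OPTIMIZER_CLASSES:
            return cast(OptimizerClass, raw)
        # Fall back to the documented default ("adamw") when the value is unsupported.
        return _OPTIMIZER_CLASSES[1]
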
 
index af9251e392a7e421c16404388ca877ed91873a2a..7e2330140ec30f58454c16e46deb3509274e9ae5 100644 (file)
--- a/ReforceXY/user_data/strategies/RLAgentStrategy.py
+++ b/ReforceXY/user_data/strategies/RLAgentStrategy.py
@@ -1,13 +1,16 @@
 import datetime
 import logging
 from functools import cached_property, reduce
-from typing import Any, Optional
+from typing import Any, Literal, Optional
 
 # import talib.abstract as ta
 from freqtrade.persistence import Trade
 from freqtrade.strategy import IStrategy
 from pandas import DataFrame
 
+TradingMode = Literal["margin", "futures", "spot"]
+TradeDirection = Literal["long", "short"]
+
 logger = logging.getLogger(__name__)
 
 ACTION_COLUMN = "&-action"
@@ -20,6 +23,13 @@ class RLAgentStrategy(IStrategy):
 
     INTERFACE_VERSION = 3
 
+    _TRADING_MODES: tuple[TradingMode, ...] = ("margin", "futures", "spot")
+    _TRADE_DIRECTIONS: tuple[TradeDirection, ...] = ("long", "short")
+    _ACTION_ENTER_LONG: int = 1
+    _ACTION_EXIT_LONG: int = 2
+    _ACTION_ENTER_SHORT: int = 3
+    _ACTION_EXIT_SHORT: int = 4
+
     @cached_property
     def can_short(self) -> bool:
         return self.is_short_allowed()
@@ -72,21 +82,21 @@ class RLAgentStrategy(IStrategy):
     ) -> DataFrame:
         enter_long_conditions = [
             dataframe.get("do_predict") == 1,
-            dataframe.get(ACTION_COLUMN) == 1,
+            dataframe.get(ACTION_COLUMN) == self._ACTION_ENTER_LONG,  # 1,
         ]
         dataframe.loc[
             reduce(lambda x, y: x & y, enter_long_conditions),
             ["enter_long", "enter_tag"],
-        ] = (1, "long")
+        ] = (1, self._TRADE_DIRECTIONS[0])  # "long"
 
         enter_short_conditions = [
             dataframe.get("do_predict") == 1,
-            dataframe.get(ACTION_COLUMN) == 3,
+            dataframe.get(ACTION_COLUMN) == self._ACTION_ENTER_SHORT,  # 3,
         ]
         dataframe.loc[
             reduce(lambda x, y: x & y, enter_short_conditions),
             ["enter_short", "enter_tag"],
-        ] = (1, "short")
+        ] = (1, self._TRADE_DIRECTIONS[1])  # "short"
 
         return dataframe
 
@@ -95,13 +105,13 @@ class RLAgentStrategy(IStrategy):
     ) -> DataFrame:
         exit_long_conditions = [
             dataframe.get("do_predict") == 1,
-            dataframe.get(ACTION_COLUMN) == 2,
+            dataframe.get(ACTION_COLUMN) == self._ACTION_EXIT_LONG,  # 2,
         ]
         dataframe.loc[reduce(lambda x, y: x & y, exit_long_conditions), "exit_long"] = 1
 
         exit_short_conditions = [
             dataframe.get("do_predict") == 1,
-            dataframe.get(ACTION_COLUMN) == 4,
+            dataframe.get(ACTION_COLUMN) == self._ACTION_EXIT_SHORT,  # 4,
         ]
         dataframe.loc[
             reduce(lambda x, y: x & y, exit_short_conditions), "exit_short"
@@ -148,9 +158,11 @@ class RLAgentStrategy(IStrategy):
 
     def is_short_allowed(self) -> bool:
         trading_mode = self.config.get("trading_mode")
-        if trading_mode in {"margin", "futures"}:
+        # "margin", "futures"
+        if trading_mode in {self._TRADING_MODES[0], self._TRADING_MODES[1]}:
             return True
-        elif trading_mode == "spot":
+        # "spot"
+        elif trading_mode == self._TRADING_MODES[2]:
             return False
         else:
             raise ValueError(f"Invalid trading_mode: {trading_mode}")
index e88ee9e0e982131d0c5e20b333a6f87e32a88431..6c5c977229940b466a53d539fd381d602cbb85b7 100644 (file)
--- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
+++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
@@ -33,6 +33,9 @@ from Utils import (
     zigzag,
 )
 
+ExtremaSelectionMethod = Literal["peak_values", "extrema_rank"]
+OptunaNamespace = Literal["hp", "train", "label"]
+
 debug = False
 
 TEST_SIZE = 0.1
@@ -45,9 +48,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
 
 logger = logging.getLogger(__name__)
 
-ExtremaSelectionMethod = Literal["peak_values", "extrema_rank"]
-OptunaNamespace = Literal["hp", "train", "label"]
-
 
 class QuickAdapterRegressorV3(BaseRegressionModel):
     """
@@ -66,7 +66,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     https://github.com/sponsors/robcaulk
     """
 
-    version = "3.7.120"
+    version = "3.7.121"
 
     _SQRT_2 = np.sqrt(2.0)
 
@@ -710,7 +710,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         extrema_selection = str(
             self.freqai_info.get(
                 "prediction_extrema_selection",
-                self._EXTREMA_SELECTION_METHODS[0],
+                self._EXTREMA_SELECTION_METHODS[1],
             )
         )
         if extrema_selection not in self._extrema_selection_methods_set():
@@ -718,7 +718,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 f"Unsupported extrema selection method: {extrema_selection}. "
                 f"Supported methods are {', '.join(self._EXTREMA_SELECTION_METHODS)}"
             )
-        extrema_selection: ExtremaSelectionMethod = extrema_selection  # type: ignore[assignment]
+        extrema_selection: ExtremaSelectionMethod = extrema_selection
         thresholds_smoothing = str(
             self.freqai_info.get("prediction_thresholds_smoothing", "mean")
         )
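
A small sketch of the behavioural effect of the hunk above (tuple contents assumed to match the README enum order): with no explicit prediction_extrema_selection key, the resolved method is now "extrema_rank" rather than "peak_values".

    _EXTREMA_SELECTION_METHODS = ("peak_values", "extrema_rank")  # assumed order
    freqai_info: dict = {}  # no explicit prediction_extrema_selection key

    extrema_selection = str(
        freqai_info.get("prediction_extrema_selection", _EXTREMA_SELECTION_METHODS[1])
    )
    assert extrema_selection == "extrema_rank"
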
index 8f53359296836a106b2517db2f9cdcb1ecf9d735..cbafd5fd872db40e6c92be13247fb5fcde8a9bfd 100644 (file)
--- a/quickadapter/user_data/strategies/QuickAdapterV3.py
+++ b/quickadapter/user_data/strategies/QuickAdapterV3.py
@@ -40,6 +40,11 @@ from Utils import (
     zlema,
 )
 
+TradeDirection = Literal["long", "short"]
+InterpolationDirection = Literal["direct", "inverse"]
+OrderType = Literal["entry", "exit"]
+TradingMode = Literal["spot", "margin", "futures"]
+
 DfSignature = Tuple[int, Optional[datetime.datetime]]
 CandleDeviationCacheKey = Tuple[
     str, DfSignature, float, float, int, Literal["direct", "inverse"], float
@@ -54,12 +59,6 @@ EXTREMA_COLUMN = "&s-extrema"
 MAXIMA_THRESHOLD_COLUMN = "&s-maxima_threshold"
 MINIMA_THRESHOLD_COLUMN = "&s-minima_threshold"
 
-# Type aliases
-TradeDirection = Literal["long", "short"]
-InterpolationDirection = Literal["direct", "inverse"]
-OrderType = Literal["entry", "exit"]
-TradingMode = Literal["spot", "margin", "futures"]
-
 
 class QuickAdapterV3(IStrategy):
     """
@@ -89,7 +88,7 @@ class QuickAdapterV3(IStrategy):
     _TRADING_MODES: tuple[TradingMode, ...] = ("spot", "margin", "futures")
 
     def version(self) -> str:
-        return "3.3.170"
+        return "3.3.171"
 
     timeframe = "5m"
 
@@ -139,10 +138,6 @@ class QuickAdapterV3(IStrategy):
 
     process_only_new_candles = True
 
-    @cached_property
-    def can_short(self) -> bool:
-        return self.is_short_allowed()
-
     @staticmethod
     def _trade_directions_set() -> set[TradeDirection]:
         return {
@@ -154,6 +149,10 @@ class QuickAdapterV3(IStrategy):
     def _order_types_set() -> set[OrderType]:
         return {QuickAdapterV3._ORDER_TYPES[0], QuickAdapterV3._ORDER_TYPES[1]}
 
+    @cached_property
+    def can_short(self) -> bool:
+        return self.is_short_allowed()
+
     @cached_property
     def plot_config(self) -> dict[str, Any]:
         return {
@@ -666,7 +665,7 @@ class QuickAdapterV3(IStrategy):
         dataframe.loc[
             reduce(lambda x, y: x & y, enter_long_conditions),
             ["enter_long", "enter_tag"],
-        ] = (1, "long")
+        ] = (1, self._TRADE_DIRECTIONS[0])  # "long"
 
         enter_short_conditions = [
             dataframe.get("do_predict") == 1,
@@ -676,7 +675,7 @@ class QuickAdapterV3(IStrategy):
         dataframe.loc[
             reduce(lambda x, y: x & y, enter_short_conditions),
             ["enter_short", "enter_tag"],
-        ] = (1, "short")
+        ] = (1, self._TRADE_DIRECTIONS[1])  # "short"
 
         return dataframe
 
@@ -1395,7 +1394,7 @@ class QuickAdapterV3(IStrategy):
         """
         if df.empty:
             return False
-        if side not in self._sides_set():
+        if side not in self._trade_directions_set():
             return False
         if order not in self._order_types_set():
             return False
@@ -1838,7 +1837,7 @@ class QuickAdapterV3(IStrategy):
         side: str,
         **kwargs,
     ) -> bool:
-        if side not in self._sides_set():
+        if side not in self._trade_directions_set():
             return False
         if side == self._TRADE_DIRECTIONS[1] and not self.can_short:  # "short"
             logger.info(f"User denied short entry for {pair}: shorting not allowed")