]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): more constants usage
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 19 Nov 2025 01:58:49 +0000 (02:58 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 19 Nov 2025 01:58:49 +0000 (02:58 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 37222c97f8772c82844a24df1141bf54d591e15a..37ab6e2aa3dc96f31fd741bfdf947f53df0733e3 100644 (file)
@@ -1084,7 +1084,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
         else:
             raise ValueError(
-                f"Unsupported sampler: {sampler!r}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
+                f"Unsupported sampler: {sampler}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
             )
 
     @staticmethod
@@ -1656,7 +1656,7 @@ class MyRLEnv(Base5ActionRLEnv):
             if self._entry_additive_enabled or self._exit_additive_enabled:
                 logger.info(
                     "PBRS canonical mode: additive rewards disabled with Φ(terminal)=0. PBRS invariance is preserved. "
-                    f"To use additive rewards, set exit_potential_mode={ReforceXY._EXIT_POTENTIAL_MODES[1]!r}."
+                    f"To use additive rewards, set exit_potential_mode={ReforceXY._EXIT_POTENTIAL_MODES[1]}."
                 )
                 self._entry_additive_enabled = False
                 self._exit_additive_enabled = False
@@ -3175,7 +3175,7 @@ class InfoMetricsCallback(TensorboardCallback):
                 hparam_dict.update({"n_updates": int(n_updates)})
         except Exception:
             pass
-        if "PPO" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__:  # "PPO"
             cr = getattr(self.model, "clip_range", None)
             cr_schedule, cr_iv, cr_fv = get_schedule_type(cr)
             hparam_dict.update(
@@ -3193,7 +3193,9 @@ class InfoMetricsCallback(TensorboardCallback):
             )
             if getattr(self.model, "target_kl", None) is not None:
                 hparam_dict["target_kl"] = float(self.model.target_kl)
-            if "RecurrentPPO" in self.model.__class__.__name__:
+            if (
+                ReforceXY._MODEL_TYPES[1] in self.model.__class__.__name__
+            ):  # "RecurrentPPO"
                 policy = getattr(self.model, "policy", None)
                 if policy is not None:
                     lstm_actor = getattr(policy, "lstm_actor", None)
@@ -3204,7 +3206,7 @@ class InfoMetricsCallback(TensorboardCallback):
                                 "n_lstm_layers": int(lstm_actor.num_layers),
                             }
                         )
-        if "DQN" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__:  # "DQN"
             hparam_dict.update(
                 {
                     "buffer_size": int(self.model.buffer_size),
@@ -3224,7 +3226,7 @@ class InfoMetricsCallback(TensorboardCallback):
             )
             if train_freq is not None:
                 hparam_dict.update({"train_freq": train_freq})
-            if "QRDQN" in self.model.__class__.__name__:
+            if ReforceXY._MODEL_TYPES[4] in self.model.__class__.__name__:  # "QRDQN"
                 hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)})
         metric_dict: dict[str, float | int] = {
             "eval/mean_reward": 0.0,
@@ -3240,7 +3242,7 @@ class InfoMetricsCallback(TensorboardCallback):
             "info/trade_count": 0,
             "info/trade_duration": 0,
         }
-        if "PPO" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__:  # "PPO"
             metric_dict.update(
                 {
                     "train/approx_kl": 0.0,
@@ -3252,7 +3254,7 @@ class InfoMetricsCallback(TensorboardCallback):
                     "train/explained_variance": 0.0,
                 }
             )
-        if "DQN" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__:  # "DQN"
             metric_dict.update(
                 {
                     "train/loss": 0.0,
@@ -3491,7 +3493,7 @@ class InfoMetricsCallback(TensorboardCallback):
         except Exception:
             pass
 
-        if "PPO" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__:  # "PPO"
             try:
                 cr = getattr(self.model, "clip_range", None)
                 cr = _eval_schedule(cr)
@@ -3502,7 +3504,7 @@ class InfoMetricsCallback(TensorboardCallback):
             except Exception:
                 pass
 
-        if "DQN" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__:  # "DQN"
             try:
                 er = getattr(self.model, "exploration_rate", None)
                 if _is_finite_number(er):
@@ -3738,80 +3740,100 @@ def get_schedule_type(
     if isinstance(schedule, (int, float)):
         try:
             schedule = float(schedule)
-            return "constant", schedule, schedule
+            return ReforceXY._SCHEDULE_TYPES[1], schedule, schedule  # "constant"
         except Exception:
-            return "constant", np.nan, np.nan
+            return ReforceXY._SCHEDULE_TYPES[1], np.nan, np.nan  # "constant"
     elif isinstance(schedule, ConstantSchedule):
         try:
-            return "constant", schedule(1.0), schedule(0.0)
+            return (
+                ReforceXY._SCHEDULE_TYPES[1],
+                schedule(1.0),
+                schedule(0.0),
+            )  # "constant"
         except Exception:
-            return "constant", np.nan, np.nan
+            return ReforceXY._SCHEDULE_TYPES[1], np.nan, np.nan  # "constant"
     elif isinstance(schedule, SimpleLinearSchedule):
         try:
-            return "linear", schedule(1.0), schedule(0.0)
+            return (
+                ReforceXY._SCHEDULE_TYPES[0],
+                schedule(1.0),
+                schedule(0.0),
+            )  # "linear"
         except Exception:
-            return "linear", np.nan, np.nan
+            return ReforceXY._SCHEDULE_TYPES[0], np.nan, np.nan  # "linear"
 
-    return "unknown", np.nan, np.nan
+    return ReforceXY._SCHEDULE_TYPES[2], np.nan, np.nan  # "unknown"
 
 
 def get_schedule(
     schedule_type: Literal["linear", "constant"],
     initial_value: float,
 ) -> Callable[[float], float]:
-    if schedule_type == ReforceXY._SCHEDULE_TYPES[0]:  # "linear"
+    if schedule_type == ReforceXY._SCHEDULE_TYPES[0]:
         return SimpleLinearSchedule(initial_value)
-    elif schedule_type == ReforceXY._SCHEDULE_TYPES[1]:  # "constant"
+    elif schedule_type == ReforceXY._SCHEDULE_TYPES[1]:
         return ConstantSchedule(initial_value)
     else:
         return ConstantSchedule(initial_value)
 
 
 def get_net_arch(
-    model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
+    model_type: str, net_arch_type: NetArchSize
 ) -> Union[List[int], Dict[str, List[int]]]:
     """
     Get network architecture
     """
-    if "PPO" in model_type:
+    if ReforceXY._MODEL_TYPES[0] in model_type:  # "PPO"
         return {
-            "small": {"pi": [128, 128], "vf": [128, 128]},
-            "medium": {"pi": [256, 256], "vf": [256, 256]},
-            "large": {"pi": [512, 512], "vf": [512, 512]},
-            "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
+            ReforceXY._NET_ARCH_SIZES[0]: {
+                "pi": [128, 128],
+                "vf": [128, 128],
+            },  # ReforceXY._NET_ARCH_SIZES[0]
+            ReforceXY._NET_ARCH_SIZES[1]: {
+                "pi": [256, 256],
+                "vf": [256, 256],
+            },  # ReforceXY._NET_ARCH_SIZES[1]
+            ReforceXY._NET_ARCH_SIZES[2]: {
+                "pi": [512, 512],
+                "vf": [512, 512],
+            },  # "large"
+            ReforceXY._NET_ARCH_SIZES[3]: {
+                "pi": [1024, 1024],
+                "vf": [1024, 1024],
+            },  # "extra_large"
         }.get(net_arch_type, {"pi": [128, 128], "vf": [128, 128]})
     return {
-        "small": [128, 128],
-        "medium": [256, 256],
-        "large": [512, 512],
-        "extra_large": [1024, 1024],
+        ReforceXY._NET_ARCH_SIZES[0]: [128, 128],  # "small"
+        ReforceXY._NET_ARCH_SIZES[1]: [256, 256],  # "medium"
+        ReforceXY._NET_ARCH_SIZES[2]: [512, 512],  # "large"
+        ReforceXY._NET_ARCH_SIZES[3]: [1024, 1024],  # "extra_large"
     }.get(net_arch_type, [128, 128])
 
 
 def get_activation_fn(
-    activation_fn_name: Literal["tanh", "relu", "elu", "leaky_relu"],
+    activation_fn_name: ActivationFunction,
 ) -> Type[th.nn.Module]:
     """
     Get activation function
     """
     return {
-        "tanh": th.nn.Tanh,
-        "relu": th.nn.ReLU,
-        "elu": th.nn.ELU,
-        "leaky_relu": th.nn.LeakyReLU,
+        ReforceXY._ACTIVATION_FUNCTIONS[0]: th.nn.Tanh,  # "tanh"
+        ReforceXY._ACTIVATION_FUNCTIONS[1]: th.nn.ReLU,  # "relu"
+        ReforceXY._ACTIVATION_FUNCTIONS[2]: th.nn.ELU,  # "elu"
+        ReforceXY._ACTIVATION_FUNCTIONS[3]: th.nn.LeakyReLU,  # "leaky_relu"
     }.get(activation_fn_name, th.nn.ReLU)
 
 
 def get_optimizer_class(
-    optimizer_class_name: Literal["adam", "adamw", "rmsprop"],
+    optimizer_class_name: OptimizerClass,
 ) -> Type[th.optim.Optimizer]:
     """
     Get optimizer class
     """
     return {
-        "adam": th.optim.Adam,
-        "adamw": th.optim.AdamW,
-        "rmsprop": th.optim.RMSprop,
+        ReforceXY._OPTIMIZER_CLASSES[0]: th.optim.Adam,  # "adam"
+        ReforceXY._OPTIMIZER_CLASSES[1]: th.optim.AdamW,  # "adamw"
+        ReforceXY._OPTIMIZER_CLASSES[2]: th.optim.RMSprop,  # "rmsprop"
     }.get(optimizer_class_name, th.optim.Adam)
 
 
@@ -3823,14 +3845,12 @@ def convert_optuna_params_to_model_params(
 
     lr = optuna_params.get("learning_rate")
     if lr is None:
-        raise ValueError(
-            f"missing {'learning_rate'!r} in optuna params for {model_type}"
-        )
+        raise ValueError(f"missing {'learning_rate'} in optuna params for {model_type}")
     lr = get_schedule(
         optuna_params.get("lr_schedule", ReforceXY._SCHEDULE_TYPES[1]), float(lr)
     )  # default: "constant"
 
-    if "PPO" in model_type:
+    if ReforceXY._MODEL_TYPES[0] in model_type:  # "PPO"
         required_ppo_params = [
             "clip_range",
             "n_steps",
@@ -3867,12 +3887,12 @@ def convert_optuna_params_to_model_params(
         )
         if optuna_params.get("target_kl") is not None:
             model_params["target_kl"] = float(optuna_params.get("target_kl"))
-        if "RecurrentPPO" in model_type:
+        if ReforceXY._MODEL_TYPES[1] in model_type:  # "RecurrentPPO"
             policy_kwargs["lstm_hidden_size"] = int(
                 optuna_params.get("lstm_hidden_size")
             )
             policy_kwargs["n_lstm_layers"] = int(optuna_params.get("n_lstm_layers"))
-    elif "DQN" in model_type:
+    elif ReforceXY._MODEL_TYPES[3] in model_type:  # "DQN"
         required_dqn_params = [
             "gamma",
             "batch_size",
@@ -3915,7 +3935,10 @@ def convert_optuna_params_to_model_params(
                 "learning_starts": int(optuna_params.get("learning_starts")),
             }
         )
-        if "QRDQN" in model_type and optuna_params.get("n_quantiles") is not None:
+        if (
+            ReforceXY._MODEL_TYPES[4] in model_type
+            and optuna_params.get("n_quantiles") is not None
+        ):  # "QRDQN"
             policy_kwargs["n_quantiles"] = int(optuna_params["n_quantiles"])
     else:
         raise ValueError(f"Model {model_type} not supported")
@@ -4038,7 +4061,9 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
             "batch_size", [64, 128, 256, 512, 1024]
         ),
         "learning_rate": trial.suggest_float("learning_rate", 1e-5, 3e-3, log=True),
-        "lr_schedule": trial.suggest_categorical("lr_schedule", ["linear", "constant"]),
+        "lr_schedule": trial.suggest_categorical(
+            "lr_schedule", list(ReforceXY._SCHEDULE_TYPES[:2])
+        ),  # ["linear", "constant"]
         "buffer_size": trial.suggest_categorical(
             "buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]
         ),