]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): safer type handling
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 10 Sep 2025 16:48:12 +0000 (18:48 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 10 Sep 2025 16:48:12 +0000 (18:48 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index c78c079181e7f4aaca61249c213505430c81e65f..04a2d108fe32ae24fa593afde6efef59a5222436 100644 (file)
@@ -10,7 +10,7 @@ from enum import IntEnum
 from functools import lru_cache
 from pathlib import Path
 from statistics import stdev
-from typing import Any, Callable, Dict, Optional, Tuple, Type
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -243,51 +243,46 @@ class ReforceXY(BaseReinforcementLearningModel):
         model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters)
 
         if self.lr_schedule:
-            lr = float(model_params.get("learning_rate", 0.0003))
-            model_params["learning_rate"] = linear_schedule(lr)
-            logger.info("Learning rate linear schedule enabled, initial value: %s", lr)
+            lr = model_params.get("learning_rate", 0.0003)
+            if isinstance(lr, (int, float)):
+                lr = float(lr)
+                model_params["learning_rate"] = linear_schedule(lr)
+                logger.info(
+                    "Learning rate linear schedule enabled, initial value: %s", lr
+                )
 
         if "PPO" in self.model_type and self.cr_schedule:
-            cr = float(model_params.get("clip_range", 0.2))
-            model_params["clip_range"] = linear_schedule(cr)
-            logger.info("Clip range linear schedule enabled, initial value: %s", cr)
+            cr = model_params.get("clip_range", 0.2)
+            if isinstance(cr, (int, float)):
+                cr = float(cr)
+                model_params["clip_range"] = linear_schedule(cr)
+                logger.info("Clip range linear schedule enabled, initial value: %s", cr)
 
         if "DQN" in self.model_type:
-            gradient_steps = model_params.get("gradient_steps")
-            if gradient_steps is None:
-                train_freq = model_params.get("train_freq")
-                if isinstance(train_freq, (tuple, list)) and train_freq:
-                    train_freq = (
-                        train_freq[0] if isinstance(train_freq[0], int) else None
-                    )
-                else:
-                    train_freq = train_freq if isinstance(train_freq, int) else None
-                subsample_steps = model_params.get("subsample_steps")
-                if (
-                    isinstance(train_freq, int)
-                    and train_freq > 0
-                    and isinstance(subsample_steps, int)
-                    and subsample_steps > 0
-                ):
-                    model_params["gradient_steps"] = min(
-                        train_freq, max(train_freq // subsample_steps, 1)
-                    )
-                else:
-                    model_params["gradient_steps"] = -1
+            if model_params.get("gradient_steps") is None:
+                model_params["gradient_steps"] = compute_gradient_steps(
+                    model_params.get("train_freq"), model_params.get("subsample_steps")
+                )
+            if "subsample_steps" in model_params:
+                model_params.pop("subsample_steps", None)
 
         if not model_params.get("policy_kwargs"):
             model_params["policy_kwargs"] = {}
 
-        net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
+        default_net_arch: list[int] = [128, 128]
+        net_arch: Union[list[int], Dict[str, list[int]]] = model_params.get(
+            "policy_kwargs", {}
+        ).get("net_arch", default_net_arch)
+
         if "PPO" in self.model_type:
             if isinstance(net_arch, str):
-                net_arch = get_net_arch(self.model_type, net_arch)
-                if isinstance(net_arch, dict):
-                    model_params["policy_kwargs"]["net_arch"] = net_arch
+                resolved_net_arch = get_net_arch(self.model_type, net_arch)
+                if isinstance(resolved_net_arch, dict):
+                    model_params["policy_kwargs"]["net_arch"] = resolved_net_arch
                 else:
                     model_params["policy_kwargs"]["net_arch"] = {
-                        "pi": net_arch,
-                        "vf": net_arch,
+                        "pi": resolved_net_arch,
+                        "vf": resolved_net_arch,
                     }
             elif isinstance(net_arch, list):
                 model_params["policy_kwargs"]["net_arch"] = {
@@ -295,10 +290,16 @@ class ReforceXY(BaseReinforcementLearningModel):
                     "vf": net_arch,
                 }
             elif isinstance(net_arch, dict):
-                if not ("pi" in net_arch and "vf" in net_arch):
+                pi: Optional[list[int]] = net_arch.get("pi")
+                vf: Optional[list[int]] = net_arch.get("vf")
+                if not isinstance(pi, list) or not isinstance(vf, list):
                     model_params["policy_kwargs"]["net_arch"] = {
-                        "pi": net_arch.get("pi", net_arch.get("vf", [128, 128])),
-                        "vf": net_arch.get("vf", net_arch.get("pi", [128, 128])),
+                        "pi": pi
+                        if isinstance(pi, list)
+                        else (vf if isinstance(vf, list) else default_net_arch),
+                        "vf": vf
+                        if isinstance(vf, list)
+                        else (pi if isinstance(pi, list) else default_net_arch),
                     }
                 else:
                     model_params["policy_kwargs"]["net_arch"] = net_arch
@@ -313,7 +314,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         model_params["policy_kwargs"]["activation_fn"] = get_activation_fn(
             model_params.get("policy_kwargs", {}).get("activation_fn", "relu")
         )
-
         model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
             model_params.get("policy_kwargs", {}).get("optimizer_class", "adam")
         )
@@ -322,7 +322,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         return copy.deepcopy(self._model_params_cache)
 
     def get_callbacks(
-        self, eval_freq: int, data_path: str, trial: Trial = None
+        self, eval_freq: int, data_path: str, trial: Optional[Trial] = None
     ) -> list[BaseCallback]:
         """
         Get the model specific callbacks
@@ -330,7 +330,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         callbacks: list[BaseCallback] = []
         no_improvement_callback = None
         rollout_plot_callback = None
-        verbose = self.get_model_params().get("verbose", 0)
+        verbose = int(self.get_model_params().get("verbose", 0))
 
         if self.plot_new_best:
             rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
@@ -1481,21 +1481,24 @@ class InfoMetricsCallback(TensorboardCallback):
         )
 
     def _on_step(self) -> bool:
-        infos_list: list[dict[str, Any]] | None = self.locals.get("infos")
+        infos_list: list[Dict[str, Any]] | None = self.locals.get("infos")
         aggregated_info: Dict[str, Any] = {}
-        if infos_list:
-            numeric_acc: dict[str, list[float]] = defaultdict(list)
-            non_numeric_acc: dict[str, set[Any]] = defaultdict(set)
+
+        if isinstance(infos_list, list) and infos_list:
+            numeric_acc: Dict[str, list[float]] = defaultdict(list)
+            non_numeric_acc: Dict[str, set[Any]] = defaultdict(set)
+
             for info_dict in infos_list:
                 if not isinstance(info_dict, dict):
                     continue
                 for k, v in info_dict.items():
-                    if k in ("episode", "terminal_observation", "TimeLimit.truncated"):
+                    if k in {"episode", "terminal_observation", "TimeLimit.truncated"}:
                         continue
                     if isinstance(v, (int, float)) and not isinstance(v, bool):
                         numeric_acc[k].append(float(v))
                     else:
                         non_numeric_acc[k].add(v)
+
             for k, values in numeric_acc.items():
                 if not values:
                     continue
@@ -1506,39 +1509,34 @@ class InfoMetricsCallback(TensorboardCallback):
                         aggregated_info[f"{k}_std"] = stdev(values)
                     except Exception:
                         pass
+
             for key in ("reward", "pnl"):
-                values = numeric_acc.get(key, [])
+                values = numeric_acc.get(key)
                 if values:
-                    try:
-                        aggregated_info[f"{key}_min"] = float(min(values))
-                        aggregated_info[f"{key}_max"] = float(max(values))
-                    except Exception:
-                        pass
+                    aggregated_info[f"{key}_min"] = float(min(values))
+                    aggregated_info[f"{key}_max"] = float(max(values))
+
             for k, values in non_numeric_acc.items():
-                if len(values) == 1:
-                    aggregated_info[k] = next(iter(values))
-                else:
-                    aggregated_info[k] = "mixed"
+                aggregated_info[k] = next(iter(values)) if len(values) == 1 else "mixed"
 
         if self.training_env is None:
             return True
 
         try:
             tensorboard_metrics_list = self.training_env.get_attr("tensorboard_metrics")
+            tensorboard_metrics = (
+                tensorboard_metrics_list[0] if tensorboard_metrics_list else {}
+            )
         except Exception:
-            tensorboard_metrics_list = []
-        tensorboard_metrics = (
-            tensorboard_metrics_list[0] if tensorboard_metrics_list else {}
-        )
+            tensorboard_metrics = {}
 
         for metric, value in aggregated_info.items():
             self.logger.record(f"info/{metric}", value)
 
         for category, metrics in tensorboard_metrics.items():
-            if not isinstance(metrics, dict):
-                continue
-            for metric, value in metrics.items():
-                self.logger.record(f"{category}/{metric}", value)
+            if isinstance(metrics, dict):
+                for metric, value in metrics.items():
+                    self.logger.record(f"{category}/{metric}", value)
         return True
 
 
@@ -1646,7 +1644,7 @@ def make_env(
     return _init
 
 
-def deepmerge(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
+def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     """Recursively merge two dicts without mutating inputs"""
     dst_copy = copy.deepcopy(dst)
     for k, v in src.items():
@@ -1668,6 +1666,20 @@ def linear_schedule(initial_value: float) -> Callable[[float], float]:
     return func
 
 
+def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int:
+    tf: Optional[int] = None
+    if isinstance(train_freq, (tuple, list)) and train_freq:
+        tf = train_freq[0] if isinstance(train_freq[0], int) else None
+    elif isinstance(train_freq, int):
+        tf = train_freq
+
+    ss: Optional[int] = subsample_steps if isinstance(subsample_steps, int) else None
+
+    if isinstance(tf, int) and tf > 0 and isinstance(ss, int) and ss > 0:
+        return min(tf, max(tf // ss, 1))
+    return -1
+
+
 @lru_cache(maxsize=32)
 def hours_to_seconds(hours: float) -> float:
     """
@@ -1689,7 +1701,7 @@ def steps_to_days(steps: int, timeframe: str) -> float:
 
 def get_net_arch(
     model_type: str, net_arch_type: str
-) -> Dict[str, list[int] | Dict[str, list[int]]]:
+) -> Union[list[int], Dict[str, list[int]]]:
     """
     Get network architecture
     """
@@ -1732,76 +1744,89 @@ def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]:
 def convert_optuna_params_to_model_params(
     model_type: str, optuna_params: Dict[str, Any]
 ) -> Dict[str, Any]:
-    model_params = {}
-    policy_kwargs = {}
+    model_params: Dict[str, Any] = {}
+    policy_kwargs: Dict[str, Any] = {}
 
     lr = optuna_params.get("learning_rate")
-    if optuna_params.get("lr_schedule", "") == "linear":
+    if lr is None:
+        raise ValueError(f"missing 'learning_rate' in optuna params for {model_type}")
+    lr: float | Callable[[float], float] = float(lr)
+    if optuna_params.get("lr_schedule") == "linear":
         lr = linear_schedule(lr)
 
     if "PPO" in model_type:
         cr = optuna_params.get("clip_range")
+        if cr is None:
+            raise ValueError(f"missing 'clip_range' in optuna params for {model_type}")
+        cr: float | Callable[[float], float] = float(cr)
         if optuna_params.get("cr_schedule") == "linear":
             cr = linear_schedule(cr)
+
         model_params.update(
             {
-                "n_steps": optuna_params.get("n_steps"),
-                "batch_size": optuna_params.get("batch_size"),
-                "gamma": optuna_params.get("gamma"),
+                "n_steps": int(optuna_params.get("n_steps")),
+                "batch_size": int(optuna_params.get("batch_size")),
+                "gamma": float(optuna_params.get("gamma")),
                 "learning_rate": lr,
-                "ent_coef": optuna_params.get("ent_coef"),
+                "ent_coef": float(optuna_params.get("ent_coef")),
                 "clip_range": cr,
-                "n_epochs": optuna_params.get("n_epochs"),
-                "gae_lambda": optuna_params.get("gae_lambda"),
-                "max_grad_norm": optuna_params.get("max_grad_norm"),
-                "vf_coef": optuna_params.get("vf_coef"),
+                "n_epochs": int(optuna_params.get("n_epochs")),
+                "gae_lambda": float(optuna_params.get("gae_lambda")),
+                "max_grad_norm": float(optuna_params.get("max_grad_norm")),
+                "vf_coef": float(optuna_params.get("vf_coef")),
             }
         )
         if optuna_params.get("target_kl") is not None:
-            model_params["target_kl"] = optuna_params.get("target_kl")
+            model_params["target_kl"] = float(optuna_params.get("target_kl"))
     elif "DQN" in model_type:
         train_freq = optuna_params.get("train_freq")
         subsample_steps = optuna_params.get("subsample_steps")
+        gradient_steps = compute_gradient_steps(train_freq, subsample_steps)
+
         model_params.update(
             {
-                "gamma": optuna_params.get("gamma"),
-                "batch_size": optuna_params.get("batch_size"),
+                "gamma": float(optuna_params.get("gamma")),
+                "batch_size": int(optuna_params.get("batch_size")),
                 "learning_rate": lr,
-                "buffer_size": optuna_params.get("buffer_size"),
+                "buffer_size": int(optuna_params.get("buffer_size")),
                 "train_freq": train_freq,
-                "gradient_steps": min(
-                    train_freq,
-                    max(
-                        train_freq // subsample_steps,
-                        1,
-                    ),
+                "gradient_steps": gradient_steps,
+                "exploration_fraction": float(
+                    optuna_params.get("exploration_fraction")
+                ),
+                "exploration_initial_eps": float(
+                    optuna_params.get("exploration_initial_eps")
+                ),
+                "exploration_final_eps": float(
+                    optuna_params.get("exploration_final_eps")
                 ),
-                "exploration_fraction": optuna_params.get("exploration_fraction"),
-                "exploration_initial_eps": optuna_params.get("exploration_initial_eps"),
-                "exploration_final_eps": optuna_params.get("exploration_final_eps"),
-                "target_update_interval": optuna_params.get("target_update_interval"),
-                "learning_starts": optuna_params.get("learning_starts"),
+                "target_update_interval": int(
+                    optuna_params.get("target_update_interval")
+                ),
+                "learning_starts": int(optuna_params.get("learning_starts")),
             }
         )
-        if "QRDQN" in model_type and optuna_params.get("n_quantiles"):
-            policy_kwargs["n_quantiles"] = optuna_params["n_quantiles"]
+        if "QRDQN" in model_type and optuna_params.get("n_quantiles") is not None:
+            policy_kwargs["n_quantiles"] = int(optuna_params["n_quantiles"])
     else:
         raise ValueError(f"Model {model_type} not supported")
 
     if optuna_params.get("net_arch"):
-        policy_kwargs["net_arch"] = get_net_arch(model_type, optuna_params["net_arch"])
+        policy_kwargs["net_arch"] = get_net_arch(
+            model_type, str(optuna_params["net_arch"])
+        )
     if optuna_params.get("activation_fn"):
         policy_kwargs["activation_fn"] = get_activation_fn(
-            optuna_params["activation_fn"]
+            str(optuna_params["activation_fn"])
         )
     if optuna_params.get("optimizer_class"):
         policy_kwargs["optimizer_class"] = get_optimizer_class(
-            optuna_params["optimizer_class"]
+            str(optuna_params["optimizer_class"])
         )
     if optuna_params.get("ortho_init") is not None:
-        policy_kwargs["ortho_init"] = optuna_params["ortho_init"]
-    model_params["policy_kwargs"] = policy_kwargs
+        policy_kwargs["ortho_init"] = bool(optuna_params["ortho_init"])
 
+    model_params["policy_kwargs"] = policy_kwargs
     return model_params