From: Jérôme Benoit
Date: Wed, 10 Sep 2025 16:48:12 +0000 (+0200)
Subject: refactor(reforcexy): safer type handling
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=9baf51671f7e03f5b865b534c245788e8d20eb23;p=freqai-strategies.git

refactor(reforcexy): safer type handling

Signed-off-by: Jérôme Benoit
---

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index c78c079..04a2d10 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -10,7 +10,7 @@ from enum import IntEnum
 from functools import lru_cache
 from pathlib import Path
 from statistics import stdev
-from typing import Any, Callable, Dict, Optional, Tuple, Type
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -243,51 +243,46 @@ class ReforceXY(BaseReinforcementLearningModel):
         model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters)
 
         if self.lr_schedule:
-            lr = float(model_params.get("learning_rate", 0.0003))
-            model_params["learning_rate"] = linear_schedule(lr)
-            logger.info("Learning rate linear schedule enabled, initial value: %s", lr)
+            lr = model_params.get("learning_rate", 0.0003)
+            if isinstance(lr, (int, float)):
+                lr = float(lr)
+                model_params["learning_rate"] = linear_schedule(lr)
+                logger.info(
+                    "Learning rate linear schedule enabled, initial value: %s", lr
+                )
 
         if "PPO" in self.model_type and self.cr_schedule:
-            cr = float(model_params.get("clip_range", 0.2))
-            model_params["clip_range"] = linear_schedule(cr)
-            logger.info("Clip range linear schedule enabled, initial value: %s", cr)
+            cr = model_params.get("clip_range", 0.2)
+            if isinstance(cr, (int, float)):
+                cr = float(cr)
+                model_params["clip_range"] = linear_schedule(cr)
+                logger.info("Clip range linear schedule enabled, initial value: %s", cr)
 
         if "DQN" in self.model_type:
-            gradient_steps = model_params.get("gradient_steps")
-            if gradient_steps is None:
-                train_freq = model_params.get("train_freq")
-                if isinstance(train_freq, (tuple, list)) and train_freq:
-                    train_freq = (
-                        train_freq[0] if isinstance(train_freq[0], int) else None
-                    )
-                else:
-                    train_freq = train_freq if isinstance(train_freq, int) else None
-                subsample_steps = model_params.get("subsample_steps")
-                if (
-                    isinstance(train_freq, int)
-                    and train_freq > 0
-                    and isinstance(subsample_steps, int)
-                    and subsample_steps > 0
-                ):
-                    model_params["gradient_steps"] = min(
-                        train_freq, max(train_freq // subsample_steps, 1)
-                    )
-                else:
-                    model_params["gradient_steps"] = -1
+            if model_params.get("gradient_steps") is None:
+                model_params["gradient_steps"] = compute_gradient_steps(
+                    model_params.get("train_freq"), model_params.get("subsample_steps")
+                )
+            if "subsample_steps" in model_params:
+                model_params.pop("subsample_steps", None)
 
         if not model_params.get("policy_kwargs"):
             model_params["policy_kwargs"] = {}
 
-        net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
+        default_net_arch: list[int] = [128, 128]
+        net_arch: Union[list[int], Dict[str, list[int]]] = model_params.get(
+            "policy_kwargs", {}
+        ).get("net_arch", default_net_arch)
+
         if "PPO" in self.model_type:
             if isinstance(net_arch, str):
-                net_arch = get_net_arch(self.model_type, net_arch)
-            if isinstance(net_arch, dict):
-                model_params["policy_kwargs"]["net_arch"] = net_arch
+                resolved_net_arch = get_net_arch(self.model_type, net_arch)
+                if isinstance(resolved_net_arch, dict):
+                    model_params["policy_kwargs"]["net_arch"] = resolved_net_arch
model_params["policy_kwargs"]["net_arch"] = resolved_net_arch else: model_params["policy_kwargs"]["net_arch"] = { - "pi": net_arch, - "vf": net_arch, + "pi": resolved_net_arch, + "vf": resolved_net_arch, } elif isinstance(net_arch, list): model_params["policy_kwargs"]["net_arch"] = { @@ -295,10 +290,16 @@ class ReforceXY(BaseReinforcementLearningModel): "vf": net_arch, } elif isinstance(net_arch, dict): - if not ("pi" in net_arch and "vf" in net_arch): + pi: Optional[list[int]] = net_arch.get("pi") + vf: Optional[list[int]] = net_arch.get("vf") + if not isinstance(pi, list) or not isinstance(vf, list): model_params["policy_kwargs"]["net_arch"] = { - "pi": net_arch.get("pi", net_arch.get("vf", [128, 128])), - "vf": net_arch.get("vf", net_arch.get("pi", [128, 128])), + "pi": pi + if isinstance(pi, list) + else (vf if isinstance(vf, list) else default_net_arch), + "vf": vf + if isinstance(vf, list) + else (pi if isinstance(pi, list) else default_net_arch), } else: model_params["policy_kwargs"]["net_arch"] = net_arch @@ -313,7 +314,6 @@ class ReforceXY(BaseReinforcementLearningModel): model_params["policy_kwargs"]["activation_fn"] = get_activation_fn( model_params.get("policy_kwargs", {}).get("activation_fn", "relu") ) - model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class( model_params.get("policy_kwargs", {}).get("optimizer_class", "adam") ) @@ -322,7 +322,7 @@ class ReforceXY(BaseReinforcementLearningModel): return copy.deepcopy(self._model_params_cache) def get_callbacks( - self, eval_freq: int, data_path: str, trial: Trial = None + self, eval_freq: int, data_path: str, trial: Optional[Trial] = None ) -> list[BaseCallback]: """ Get the model specific callbacks @@ -330,7 +330,7 @@ class ReforceXY(BaseReinforcementLearningModel): callbacks: list[BaseCallback] = [] no_improvement_callback = None rollout_plot_callback = None - verbose = self.get_model_params().get("verbose", 0) + verbose = int(self.get_model_params().get("verbose", 0)) if self.plot_new_best: rollout_plot_callback = RolloutPlotCallback(verbose=verbose) @@ -1481,21 +1481,24 @@ class InfoMetricsCallback(TensorboardCallback): ) def _on_step(self) -> bool: - infos_list: list[dict[str, Any]] | None = self.locals.get("infos") + infos_list: list[Dict[str, Any]] | None = self.locals.get("infos") aggregated_info: Dict[str, Any] = {} - if infos_list: - numeric_acc: dict[str, list[float]] = defaultdict(list) - non_numeric_acc: dict[str, set[Any]] = defaultdict(set) + + if isinstance(infos_list, list) and infos_list: + numeric_acc: Dict[str, list[float]] = defaultdict(list) + non_numeric_acc: Dict[str, set[Any]] = defaultdict(set) + for info_dict in infos_list: if not isinstance(info_dict, dict): continue for k, v in info_dict.items(): - if k in ("episode", "terminal_observation", "TimeLimit.truncated"): + if k in {"episode", "terminal_observation", "TimeLimit.truncated"}: continue if isinstance(v, (int, float)) and not isinstance(v, bool): numeric_acc[k].append(float(v)) else: non_numeric_acc[k].add(v) + for k, values in numeric_acc.items(): if not values: continue @@ -1506,39 +1509,34 @@ class InfoMetricsCallback(TensorboardCallback): aggregated_info[f"{k}_std"] = stdev(values) except Exception: pass + for key in ("reward", "pnl"): - values = numeric_acc.get(key, []) + values = numeric_acc.get(key) if values: - try: - aggregated_info[f"{key}_min"] = float(min(values)) - aggregated_info[f"{key}_max"] = float(max(values)) - except Exception: - pass + aggregated_info[f"{key}_min"] = float(min(values)) + 
aggregated_info[f"{key}_max"] = float(max(values)) + for k, values in non_numeric_acc.items(): - if len(values) == 1: - aggregated_info[k] = next(iter(values)) - else: - aggregated_info[k] = "mixed" + aggregated_info[k] = next(iter(values)) if len(values) == 1 else "mixed" if self.training_env is None: return True try: tensorboard_metrics_list = self.training_env.get_attr("tensorboard_metrics") + tensorboard_metrics = ( + tensorboard_metrics_list[0] if tensorboard_metrics_list else {} + ) except Exception: - tensorboard_metrics_list = [] - tensorboard_metrics = ( - tensorboard_metrics_list[0] if tensorboard_metrics_list else {} - ) + tensorboard_metrics = {} for metric, value in aggregated_info.items(): self.logger.record(f"info/{metric}", value) for category, metrics in tensorboard_metrics.items(): - if not isinstance(metrics, dict): - continue - for metric, value in metrics.items(): - self.logger.record(f"{category}/{metric}", value) + if isinstance(metrics, dict): + for metric, value in metrics.items(): + self.logger.record(f"{category}/{metric}", value) return True @@ -1646,7 +1644,7 @@ def make_env( return _init -def deepmerge(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]: +def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]: """Recursively merge two dicts without mutating inputs""" dst_copy = copy.deepcopy(dst) for k, v in src.items(): @@ -1668,6 +1666,20 @@ def linear_schedule(initial_value: float) -> Callable[[float], float]: return func +def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int: + tf: Optional[int] = None + if isinstance(train_freq, (tuple, list)) and train_freq: + tf = train_freq[0] if isinstance(train_freq[0], int) else None + elif isinstance(train_freq, int): + tf = train_freq + + ss: Optional[int] = subsample_steps if isinstance(subsample_steps, int) else None + + if isinstance(tf, int) and tf > 0 and isinstance(ss, int) and ss > 0: + return min(tf, max(tf // ss, 1)) + return -1 + + @lru_cache(maxsize=32) def hours_to_seconds(hours: float) -> float: """ @@ -1689,7 +1701,7 @@ def steps_to_days(steps: int, timeframe: str) -> float: def get_net_arch( model_type: str, net_arch_type: str -) -> Dict[str, list[int] | Dict[str, list[int]]]: +) -> Union[list[int], Dict[str, list[int]]]: """ Get network architecture """ @@ -1732,76 +1744,89 @@ def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]: def convert_optuna_params_to_model_params( model_type: str, optuna_params: Dict[str, Any] ) -> Dict[str, Any]: - model_params = {} - policy_kwargs = {} + model_params: Dict[str, Any] = {} + policy_kwargs: Dict[str, Any] = {} lr = optuna_params.get("learning_rate") - if optuna_params.get("lr_schedule", "") == "linear": + if lr is None: + raise ValueError(f"missing 'learning_rate' in optuna params for {model_type}") + lr: float | Callable[[float], float] = float(lr) + if optuna_params.get("lr_schedule") == "linear": lr = linear_schedule(lr) if "PPO" in model_type: cr = optuna_params.get("clip_range") + if cr is None: + raise ValueError(f"missing 'clip_range' in optuna params for {model_type}") + cr: float | Callable[[float], float] = float(cr) if optuna_params.get("cr_schedule") == "linear": cr = linear_schedule(cr) + model_params.update( { - "n_steps": optuna_params.get("n_steps"), - "batch_size": optuna_params.get("batch_size"), - "gamma": optuna_params.get("gamma"), + "n_steps": int(optuna_params.get("n_steps")), + "batch_size": int(optuna_params.get("batch_size")), + "gamma": 
float(optuna_params.get("gamma")), "learning_rate": lr, - "ent_coef": optuna_params.get("ent_coef"), + "ent_coef": float(optuna_params.get("ent_coef")), "clip_range": cr, - "n_epochs": optuna_params.get("n_epochs"), - "gae_lambda": optuna_params.get("gae_lambda"), - "max_grad_norm": optuna_params.get("max_grad_norm"), - "vf_coef": optuna_params.get("vf_coef"), + "n_epochs": int(optuna_params.get("n_epochs")), + "gae_lambda": float(optuna_params.get("gae_lambda")), + "max_grad_norm": float(optuna_params.get("max_grad_norm")), + "vf_coef": float(optuna_params.get("vf_coef")), } ) if optuna_params.get("target_kl") is not None: - model_params["target_kl"] = optuna_params.get("target_kl") + model_params["target_kl"] = float(optuna_params.get("target_kl")) elif "DQN" in model_type: train_freq = optuna_params.get("train_freq") subsample_steps = optuna_params.get("subsample_steps") + gradient_steps = compute_gradient_steps(train_freq, subsample_steps) + model_params.update( { - "gamma": optuna_params.get("gamma"), - "batch_size": optuna_params.get("batch_size"), + "gamma": float(optuna_params.get("gamma")), + "batch_size": int(optuna_params.get("batch_size")), "learning_rate": lr, - "buffer_size": optuna_params.get("buffer_size"), + "buffer_size": int(optuna_params.get("buffer_size")), "train_freq": train_freq, - "gradient_steps": min( - train_freq, - max( - train_freq // subsample_steps, - 1, - ), + "gradient_steps": gradient_steps, + "exploration_fraction": float( + optuna_params.get("exploration_fraction") + ), + "exploration_initial_eps": float( + optuna_params.get("exploration_initial_eps") + ), + "exploration_final_eps": float( + optuna_params.get("exploration_final_eps") ), - "exploration_fraction": optuna_params.get("exploration_fraction"), - "exploration_initial_eps": optuna_params.get("exploration_initial_eps"), - "exploration_final_eps": optuna_params.get("exploration_final_eps"), - "target_update_interval": optuna_params.get("target_update_interval"), - "learning_starts": optuna_params.get("learning_starts"), + "target_update_interval": int( + optuna_params.get("target_update_interval") + ), + "learning_starts": int(optuna_params.get("learning_starts")), } ) - if "QRDQN" in model_type and optuna_params.get("n_quantiles"): - policy_kwargs["n_quantiles"] = optuna_params["n_quantiles"] + if "QRDQN" in model_type and optuna_params.get("n_quantiles") is not None: + policy_kwargs["n_quantiles"] = int(optuna_params["n_quantiles"]) else: raise ValueError(f"Model {model_type} not supported") if optuna_params.get("net_arch"): - policy_kwargs["net_arch"] = get_net_arch(model_type, optuna_params["net_arch"]) + policy_kwargs["net_arch"] = get_net_arch( + model_type, str(optuna_params["net_arch"]) + ) if optuna_params.get("activation_fn"): policy_kwargs["activation_fn"] = get_activation_fn( - optuna_params["activation_fn"] + str(optuna_params["activation_fn"]) ) if optuna_params.get("optimizer_class"): policy_kwargs["optimizer_class"] = get_optimizer_class( - optuna_params["optimizer_class"] + str(optuna_params["optimizer_class"]) ) if optuna_params.get("ortho_init") is not None: - policy_kwargs["ortho_init"] = optuna_params["ortho_init"] - model_params["policy_kwargs"] = policy_kwargs + policy_kwargs["ortho_init"] = bool(optuna_params["ortho_init"]) + model_params["policy_kwargs"] = policy_kwargs return model_params
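
A quick standalone check of the refactor's core helper (the compute_gradient_steps
body is copied verbatim from the patch above; linear_schedule is assumed to follow
the standard Stable-Baselines3 linear-decay form, which the patch does not show in
full; all sample inputs are hypothetical):

from typing import Any, Callable, Optional


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    # Assumed SB3-style schedule: called with progress_remaining, 1.0 -> 0.0.
    def func(progress_remaining: float) -> float:
        return progress_remaining * initial_value

    return func


def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int:
    # train_freq may be an int or an SB3-style (freq, unit) tuple/list.
    tf: Optional[int] = None
    if isinstance(train_freq, (tuple, list)) and train_freq:
        tf = train_freq[0] if isinstance(train_freq[0], int) else None
    elif isinstance(train_freq, int):
        tf = train_freq

    ss: Optional[int] = subsample_steps if isinstance(subsample_steps, int) else None

    if isinstance(tf, int) and tf > 0 and isinstance(ss, int) and ss > 0:
        return min(tf, max(tf // ss, 1))
    return -1  # SB3 convention: one gradient step per collected env step


if __name__ == "__main__":
    assert compute_gradient_steps(16, 4) == 4            # 16 // 4
    assert compute_gradient_steps((16, "step"), 4) == 4  # tuple form of train_freq
    assert compute_gradient_steps(16, 32) == 1           # clamped to at least 1
    assert compute_gradient_steps(None, 4) == -1         # malformed input falls back
    schedule = linear_schedule(3e-4)
    assert schedule(1.0) == 3e-4 and schedule(0.0) == 0.0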