from functools import lru_cache
from pathlib import Path
from statistics import stdev
-from typing import Any, Callable, Dict, Optional, Tuple, Type
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
import matplotlib
import matplotlib.pyplot as plt
model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters)
if self.lr_schedule:
- lr = float(model_params.get("learning_rate", 0.0003))
- model_params["learning_rate"] = linear_schedule(lr)
- logger.info("Learning rate linear schedule enabled, initial value: %s", lr)
+ lr = model_params.get("learning_rate", 0.0003)
+ if isinstance(lr, (int, float)):
+ lr = float(lr)
+ model_params["learning_rate"] = linear_schedule(lr)
+ logger.info(
+ "Learning rate linear schedule enabled, initial value: %s", lr
+ )
if "PPO" in self.model_type and self.cr_schedule:
- cr = float(model_params.get("clip_range", 0.2))
- model_params["clip_range"] = linear_schedule(cr)
- logger.info("Clip range linear schedule enabled, initial value: %s", cr)
+ cr = model_params.get("clip_range", 0.2)
+ if isinstance(cr, (int, float)):
+ cr = float(cr)
+ model_params["clip_range"] = linear_schedule(cr)
+ logger.info("Clip range linear schedule enabled, initial value: %s", cr)
if "DQN" in self.model_type:
- gradient_steps = model_params.get("gradient_steps")
- if gradient_steps is None:
- train_freq = model_params.get("train_freq")
- if isinstance(train_freq, (tuple, list)) and train_freq:
- train_freq = (
- train_freq[0] if isinstance(train_freq[0], int) else None
- )
- else:
- train_freq = train_freq if isinstance(train_freq, int) else None
- subsample_steps = model_params.get("subsample_steps")
- if (
- isinstance(train_freq, int)
- and train_freq > 0
- and isinstance(subsample_steps, int)
- and subsample_steps > 0
- ):
- model_params["gradient_steps"] = min(
- train_freq, max(train_freq // subsample_steps, 1)
- )
- else:
- model_params["gradient_steps"] = -1
+ if model_params.get("gradient_steps") is None:
+ model_params["gradient_steps"] = compute_gradient_steps(
+ model_params.get("train_freq"), model_params.get("subsample_steps")
+ )
+ if "subsample_steps" in model_params:
+ model_params.pop("subsample_steps", None)
if not model_params.get("policy_kwargs"):
model_params["policy_kwargs"] = {}
- net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
+ default_net_arch: list[int] = [128, 128]
+ net_arch: Union[str, list[int], Dict[str, list[int]]] = model_params.get(
+ "policy_kwargs", {}
+ ).get("net_arch", default_net_arch)
+
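+ # Normalize net_arch into the explicit {"pi": [...], "vf": [...]} mapping:
+ # strings name a preset resolved via get_net_arch, plain lists are applied
+ # to both the policy and value heads.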
if "PPO" in self.model_type:
if isinstance(net_arch, str):
- net_arch = get_net_arch(self.model_type, net_arch)
- if isinstance(net_arch, dict):
- model_params["policy_kwargs"]["net_arch"] = net_arch
+ resolved_net_arch = get_net_arch(self.model_type, net_arch)
+ if isinstance(resolved_net_arch, dict):
+ model_params["policy_kwargs"]["net_arch"] = resolved_net_arch
else:
model_params["policy_kwargs"]["net_arch"] = {
- "pi": net_arch,
- "vf": net_arch,
+ "pi": resolved_net_arch,
+ "vf": resolved_net_arch,
}
elif isinstance(net_arch, list):
model_params["policy_kwargs"]["net_arch"] = {
"vf": net_arch,
}
elif isinstance(net_arch, dict):
- if not ("pi" in net_arch and "vf" in net_arch):
+ pi: Optional[list[int]] = net_arch.get("pi")
+ vf: Optional[list[int]] = net_arch.get("vf")
+ if not isinstance(pi, list) or not isinstance(vf, list):
model_params["policy_kwargs"]["net_arch"] = {
- "pi": net_arch.get("pi", net_arch.get("vf", [128, 128])),
- "vf": net_arch.get("vf", net_arch.get("pi", [128, 128])),
+ "pi": pi
+ if isinstance(pi, list)
+ else (vf if isinstance(vf, list) else default_net_arch),
+ "vf": vf
+ if isinstance(vf, list)
+ else (pi if isinstance(pi, list) else default_net_arch),
}
else:
model_params["policy_kwargs"]["net_arch"] = net_arch
model_params["policy_kwargs"]["activation_fn"] = get_activation_fn(
model_params.get("policy_kwargs", {}).get("activation_fn", "relu")
)
-
model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
model_params.get("policy_kwargs", {}).get("optimizer_class", "adam")
)
return copy.deepcopy(self._model_params_cache)
def get_callbacks(
- self, eval_freq: int, data_path: str, trial: Trial = None
+ self, eval_freq: int, data_path: str, trial: Optional[Trial] = None
) -> list[BaseCallback]:
"""
Get the model specific callbacks
"""
callbacks: list[BaseCallback] = []
no_improvement_callback = None
rollout_plot_callback = None
- verbose = self.get_model_params().get("verbose", 0)
+ verbose = int(self.get_model_params().get("verbose", 0))
if self.plot_new_best:
rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
)
def _on_step(self) -> bool:
- infos_list: list[dict[str, Any]] | None = self.locals.get("infos")
+ infos_list: list[Dict[str, Any]] | None = self.locals.get("infos")
aggregated_info: Dict[str, Any] = {}
- if infos_list:
- numeric_acc: dict[str, list[float]] = defaultdict(list)
- non_numeric_acc: dict[str, set[Any]] = defaultdict(set)
+
+ if isinstance(infos_list, list) and infos_list:
+ numeric_acc: Dict[str, list[float]] = defaultdict(list)
+ non_numeric_acc: Dict[str, set[Any]] = defaultdict(set)
+
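+ # Split each env's info dict into numeric series (summarized below) and
+ # non-numeric values, which are assumed hashable and reported verbatim
+ # when consistent across envs, else as "mixed".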
for info_dict in infos_list:
if not isinstance(info_dict, dict):
continue
for k, v in info_dict.items():
- if k in ("episode", "terminal_observation", "TimeLimit.truncated"):
+ if k in {"episode", "terminal_observation", "TimeLimit.truncated"}:
continue
if isinstance(v, (int, float)) and not isinstance(v, bool):
numeric_acc[k].append(float(v))
else:
non_numeric_acc[k].add(v)
+
for k, values in numeric_acc.items():
if not values:
continue
aggregated_info[f"{k}_std"] = stdev(values)
except Exception:
pass
+
for key in ("reward", "pnl"):
- values = numeric_acc.get(key, [])
+ values = numeric_acc.get(key)
if values:
- try:
- aggregated_info[f"{key}_min"] = float(min(values))
- aggregated_info[f"{key}_max"] = float(max(values))
- except Exception:
- pass
+ aggregated_info[f"{key}_min"] = float(min(values))
+ aggregated_info[f"{key}_max"] = float(max(values))
+
for k, values in non_numeric_acc.items():
- if len(values) == 1:
- aggregated_info[k] = next(iter(values))
- else:
- aggregated_info[k] = "mixed"
+ aggregated_info[k] = next(iter(values)) if len(values) == 1 else "mixed"
if self.training_env is None:
return True
try:
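+ # get_attr may raise (e.g. on a closed VecEnv); treat that as "no metrics".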
tensorboard_metrics_list = self.training_env.get_attr("tensorboard_metrics")
+ tensorboard_metrics = (
+ tensorboard_metrics_list[0] if tensorboard_metrics_list else {}
+ )
except Exception:
- tensorboard_metrics_list = []
- tensorboard_metrics = (
- tensorboard_metrics_list[0] if tensorboard_metrics_list else {}
- )
+ tensorboard_metrics = {}
for metric, value in aggregated_info.items():
self.logger.record(f"info/{metric}", value)
for category, metrics in tensorboard_metrics.items():
- if not isinstance(metrics, dict):
- continue
- for metric, value in metrics.items():
- self.logger.record(f"{category}/{metric}", value)
+ if isinstance(metrics, dict):
+ for metric, value in metrics.items():
+ self.logger.record(f"{category}/{metric}", value)
return True
return _init
-def deepmerge(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
+def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively merge two dicts without mutating inputs"""
dst_copy = copy.deepcopy(dst)
for k, v in src.items():
return func
+def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int:
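+ """
+ Derive a DQN gradient_steps value from train_freq and subsample_steps,
+ e.g. train_freq=(4, "step"), subsample_steps=2 -> min(4, max(4 // 2, 1)) == 2.
+ Returns -1 (SB3 semantics: as many gradient steps as env steps collected)
+ when either value is missing or non-positive.
+ """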
+ tf: Optional[int] = None
+ if isinstance(train_freq, (tuple, list)) and train_freq:
+ tf = train_freq[0] if isinstance(train_freq[0], int) else None
+ elif isinstance(train_freq, int):
+ tf = train_freq
+
+ ss: Optional[int] = subsample_steps if isinstance(subsample_steps, int) else None
+
+ if isinstance(tf, int) and tf > 0 and isinstance(ss, int) and ss > 0:
+ return min(tf, max(tf // ss, 1))
+ return -1
+
+
@lru_cache(maxsize=32)
def hours_to_seconds(hours: float) -> float:
"""
def get_net_arch(
model_type: str, net_arch_type: str
-) -> Dict[str, list[int] | Dict[str, list[int]]]:
+) -> Union[list[int], Dict[str, list[int]]]:
"""
Get network architecture
"""
def convert_optuna_params_to_model_params(
model_type: str, optuna_params: Dict[str, Any]
) -> Dict[str, Any]:
- model_params = {}
- policy_kwargs = {}
+ model_params: Dict[str, Any] = {}
+ policy_kwargs: Dict[str, Any] = {}
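+ # Validate and cast explicitly: values reloaded from Optuna storage are not
+ # guaranteed to come back as the exact Python types SB3 expects.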
lr = optuna_params.get("learning_rate")
- if optuna_params.get("lr_schedule", "") == "linear":
+ if lr is None:
+ raise ValueError(f"missing 'learning_rate' in optuna params for {model_type}")
+ lr: float | Callable[[float], float] = float(lr)
+ if optuna_params.get("lr_schedule") == "linear":
lr = linear_schedule(lr)
if "PPO" in model_type:
cr = optuna_params.get("clip_range")
+ if cr is None:
+ raise ValueError(f"missing 'clip_range' in optuna params for {model_type}")
+ cr: float | Callable[[float], float] = float(cr)
if optuna_params.get("cr_schedule") == "linear":
cr = linear_schedule(cr)
+
model_params.update(
{
- "n_steps": optuna_params.get("n_steps"),
- "batch_size": optuna_params.get("batch_size"),
- "gamma": optuna_params.get("gamma"),
+ "n_steps": int(optuna_params.get("n_steps")),
+ "batch_size": int(optuna_params.get("batch_size")),
+ "gamma": float(optuna_params.get("gamma")),
"learning_rate": lr,
- "ent_coef": optuna_params.get("ent_coef"),
+ "ent_coef": float(optuna_params.get("ent_coef")),
"clip_range": cr,
- "n_epochs": optuna_params.get("n_epochs"),
- "gae_lambda": optuna_params.get("gae_lambda"),
- "max_grad_norm": optuna_params.get("max_grad_norm"),
- "vf_coef": optuna_params.get("vf_coef"),
+ "n_epochs": int(optuna_params.get("n_epochs")),
+ "gae_lambda": float(optuna_params.get("gae_lambda")),
+ "max_grad_norm": float(optuna_params.get("max_grad_norm")),
+ "vf_coef": float(optuna_params.get("vf_coef")),
}
)
if optuna_params.get("target_kl") is not None:
- model_params["target_kl"] = optuna_params.get("target_kl")
+ model_params["target_kl"] = float(optuna_params.get("target_kl"))
elif "DQN" in model_type:
train_freq = optuna_params.get("train_freq")
subsample_steps = optuna_params.get("subsample_steps")
+ gradient_steps = compute_gradient_steps(train_freq, subsample_steps)
+
model_params.update(
{
- "gamma": optuna_params.get("gamma"),
- "batch_size": optuna_params.get("batch_size"),
+ "gamma": float(optuna_params.get("gamma")),
+ "batch_size": int(optuna_params.get("batch_size")),
"learning_rate": lr,
- "buffer_size": optuna_params.get("buffer_size"),
+ "buffer_size": int(optuna_params.get("buffer_size")),
"train_freq": train_freq,
- "gradient_steps": min(
- train_freq,
- max(
- train_freq // subsample_steps,
- 1,
- ),
+ "gradient_steps": gradient_steps,
+ "exploration_fraction": float(
+ optuna_params.get("exploration_fraction")
+ ),
+ "exploration_initial_eps": float(
+ optuna_params.get("exploration_initial_eps")
+ ),
+ "exploration_final_eps": float(
+ optuna_params.get("exploration_final_eps")
),
- "exploration_fraction": optuna_params.get("exploration_fraction"),
- "exploration_initial_eps": optuna_params.get("exploration_initial_eps"),
- "exploration_final_eps": optuna_params.get("exploration_final_eps"),
- "target_update_interval": optuna_params.get("target_update_interval"),
- "learning_starts": optuna_params.get("learning_starts"),
+ "target_update_interval": int(
+ optuna_params.get("target_update_interval")
+ ),
+ "learning_starts": int(optuna_params.get("learning_starts")),
}
)
- if "QRDQN" in model_type and optuna_params.get("n_quantiles"):
- policy_kwargs["n_quantiles"] = optuna_params["n_quantiles"]
+ if "QRDQN" in model_type and optuna_params.get("n_quantiles") is not None:
+ policy_kwargs["n_quantiles"] = int(optuna_params["n_quantiles"])
else:
raise ValueError(f"Model {model_type} not supported")
if optuna_params.get("net_arch"):
- policy_kwargs["net_arch"] = get_net_arch(model_type, optuna_params["net_arch"])
+ policy_kwargs["net_arch"] = get_net_arch(
+ model_type, str(optuna_params["net_arch"])
+ )
if optuna_params.get("activation_fn"):
policy_kwargs["activation_fn"] = get_activation_fn(
- optuna_params["activation_fn"]
+ str(optuna_params["activation_fn"])
)
if optuna_params.get("optimizer_class"):
policy_kwargs["optimizer_class"] = get_optimizer_class(
- optuna_params["optimizer_class"]
+ str(optuna_params["optimizer_class"])
)
if optuna_params.get("ortho_init") is not None:
- policy_kwargs["ortho_init"] = optuna_params["ortho_init"]
- model_params["policy_kwargs"] = policy_kwargs
+ policy_kwargs["ortho_init"] = bool(optuna_params["ortho_init"])
+ model_params["policy_kwargs"] = policy_kwargs
return model_params