from functools import lru_cache
from pathlib import Path
from statistics import stdev
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union
import matplotlib
import matplotlib.pyplot as plt
model_params["policy_kwargs"] = {}
default_net_arch: list[int] = [128, 128]
- net_arch: Union[list[int], Dict[str, list[int]]] = model_params.get(
- "policy_kwargs", {}
- ).get("net_arch", default_net_arch)
+ net_arch: Union[
+ list[int],
+ Dict[str, list[int]],
+ Literal["small", "medium", "large", "extra_large"],
+ ] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch)
if "PPO" in self.model_type:
if isinstance(net_arch, str):
def get_net_arch(
- model_type: str, net_arch_type: str
+ model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
) -> Union[list[int], Dict[str, list[int]]]:
"""
Get network architecture
}.get(net_arch_type, [128, 128])
-def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]:
+def get_activation_fn(
+ activation_fn_name: Literal["tanh", "relu", "elu", "leaky_relu"],
+) -> type[th.nn.Module]:
"""
Get activation function
"""
}.get(activation_fn_name, th.nn.ReLU)
-def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]:
+def get_optimizer_class(
+ optimizer_class_name: Literal["adam"],
+) -> type[th.optim.Optimizer]:
"""
Get optimizer class
"""