From: Jérôme Benoit
Date: Wed, 10 Sep 2025 17:36:20 +0000 (+0200)
Subject: refactor(reforcexy): refine typing
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=d737df5af4a9c6eface6032e5b74fb44f71a9af0;p=freqai-strategies.git

refactor(reforcexy): refine typing

Signed-off-by: Jérôme Benoit
---

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 04a2d10..db1ad40 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -10,7 +10,7 @@ from enum import IntEnum
 from functools import lru_cache
 from pathlib import Path
 from statistics import stdev
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -270,9 +270,11 @@ class ReforceXY(BaseReinforcementLearningModel):
             model_params["policy_kwargs"] = {}
 
         default_net_arch: list[int] = [128, 128]
-        net_arch: Union[list[int], Dict[str, list[int]]] = model_params.get(
-            "policy_kwargs", {}
-        ).get("net_arch", default_net_arch)
+        net_arch: Union[
+            list[int],
+            Dict[str, list[int]],
+            Literal["small", "medium", "large", "extra_large"],
+        ] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch)
 
         if "PPO" in self.model_type:
             if isinstance(net_arch, str):
@@ -1700,7 +1702,7 @@ def steps_to_days(steps: int, timeframe: str) -> float:
 
 
 def get_net_arch(
-    model_type: str, net_arch_type: str
+    model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
 ) -> Union[list[int], Dict[str, list[int]]]:
     """
     Get network architecture
@@ -1720,7 +1722,9 @@
     }.get(net_arch_type, [128, 128])
 
 
-def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]:
+def get_activation_fn(
+    activation_fn_name: Literal["tanh", "relu", "elu", "leaky_relu"],
+) -> type[th.nn.Module]:
     """
     Get activation function
     """
@@ -1732,7 +1736,9 @@
     }.get(activation_fn_name, th.nn.ReLU)
 
 
-def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]:
+def get_optimizer_class(
+    optimizer_class_name: Literal["adam"],
+) -> type[th.optim.Optimizer]:
    """
    Get optimizer class
    """
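
A minimal, self-contained sketch (not part of the patch) of what the Literal
annotations buy: a static checker such as mypy now rejects misspelled
architecture/activation/optimizer names at check time instead of letting them
silently hit the runtime fallback. The helper name, the NetArchType alias, and
the medium/large/extra_large values below are illustrative placeholders, not
taken from ReforceXY.py.

    from typing import Dict, List, Literal, Union

    NetArchType = Literal["small", "medium", "large", "extra_large"]

    def get_net_arch_sketch(
        net_arch_type: NetArchType,
    ) -> Union[List[int], Dict[str, List[int]]]:
        # Same lookup-with-fallback shape as the patched helpers; the fallback
        # still runs at runtime because Literal is only enforced statically.
        return {
            "small": [128, 128],
            "medium": [256, 256],
            "large": [512, 512],
            "extra_large": [1024, 1024],
        }.get(net_arch_type, [128, 128])

    print(get_net_arch_sketch("medium"))  # accepted: a permitted literal value
    # get_net_arch_sketch("huge")         # flagged by mypy: not in the Literal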