Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): refine typing
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 10 Sep 2025 17:36:20 +0000 (19:36 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 10 Sep 2025 17:36:20 +0000 (19:36 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 04a2d108fe32ae24fa593afde6efef59a5222436..db1ad40f0d906d6205afb0c747f2eaf6a12b9cf6 100644 (file)
@@ -10,7 +10,7 @@ from enum import IntEnum
 from functools import lru_cache
 from pathlib import Path
 from statistics import stdev
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -270,9 +270,11 @@ class ReforceXY(BaseReinforcementLearningModel):
             model_params["policy_kwargs"] = {}
 
         default_net_arch: list[int] = [128, 128]
-        net_arch: Union[list[int], Dict[str, list[int]]] = model_params.get(
-            "policy_kwargs", {}
-        ).get("net_arch", default_net_arch)
+        net_arch: Union[
+            list[int],
+            Dict[str, list[int]],
+            Literal["small", "medium", "large", "extra_large"],
+        ] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch)
 
         if "PPO" in self.model_type:
             if isinstance(net_arch, str):
@@ -1700,7 +1702,7 @@ def steps_to_days(steps: int, timeframe: str) -> float:
 
 
 def get_net_arch(
-    model_type: str, net_arch_type: str
+    model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
 ) -> Union[list[int], Dict[str, list[int]]]:
     """
     Get network architecture
@@ -1720,7 +1722,9 @@ def get_net_arch(
     }.get(net_arch_type, [128, 128])
 
 
-def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]:
+def get_activation_fn(
+    activation_fn_name: Literal["tanh", "relu", "elu", "leaky_relu"],
+) -> type[th.nn.Module]:
     """
     Get activation function
     """
@@ -1732,7 +1736,9 @@ def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]:
     }.get(activation_fn_name, th.nn.ReLU)
 
 
-def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]:
+def get_optimizer_class(
+    optimizer_class_name: Literal["adam"],
+) -> type[th.optim.Optimizer]:
     """
     Get optimizer class
     """