]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): factor out DQN train_freq handling
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 22 Sep 2025 11:35:12 +0000 (13:35 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 22 Sep 2025 11:35:12 +0000 (13:35 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 886e858bfeb425ca71f264c028a9311cba7976ad..68a14f3adba9e18fe3bc91dc8e2a6b7ce03e9629 100644 (file)
@@ -8,7 +8,7 @@ from collections import defaultdict
 from collections.abc import Mapping
 from enum import IntEnum
 from pathlib import Path
-from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -114,7 +114,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.pairs: list[str] = self.config.get("exchange", {}).get("pair_whitelist")
+        self.pairs: List[str] = self.config.get("exchange", {}).get("pair_whitelist")
         if not self.pairs:
             raise ValueError(
                 "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
@@ -333,10 +333,10 @@ class ReforceXY(BaseReinforcementLearningModel):
         if not model_params.get("policy_kwargs"):
             model_params["policy_kwargs"] = {}
 
-        default_net_arch: list[int] = [128, 128]
+        default_net_arch: List[int] = [128, 128]
         net_arch: Union[
-            list[int],
-            Dict[str, list[int]],
+            List[int],
+            Dict[str, List[int]],
             Literal["small", "medium", "large", "extra_large"],
         ] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch)
 
@@ -356,8 +356,8 @@ class ReforceXY(BaseReinforcementLearningModel):
                     "vf": net_arch,
                 }
             elif isinstance(net_arch, dict):
-                pi: Optional[list[int]] = net_arch.get("pi")
-                vf: Optional[list[int]] = net_arch.get("vf")
+                pi: Optional[List[int]] = net_arch.get("pi")
+                vf: Optional[List[int]] = net_arch.get("vf")
                 if not isinstance(pi, list) or not isinstance(vf, list):
                     model_params["policy_kwargs"]["net_arch"] = {
                         "pi": pi
@@ -426,11 +426,11 @@ class ReforceXY(BaseReinforcementLearningModel):
         eval_freq: int,
         data_path: str,
         trial: Optional[Trial] = None,
-    ) -> list[BaseCallback]:
+    ) -> List[BaseCallback]:
         """
         Get the model specific callbacks
         """
-        callbacks: list[BaseCallback] = []
+        callbacks: List[BaseCallback] = []
         no_improvement_callback = None
         rollout_plot_callback = None
         verbose = int(self.get_model_params().get("verbose", 0))
@@ -540,11 +540,12 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         if "PPO" in self.model_type:
             min_timesteps = 2 * model_params.get("n_steps", 0) * self.n_envs
-            if total_timesteps < min_timesteps:
+            if total_timesteps <= min_timesteps:
                 logger.warning(
-                    "total_timesteps=%s is less than 2*n_steps*n_envs=%s. This may lead to suboptimal training results",
+                    "total_timesteps=%s is less than or equal to 2*n_steps*n_envs=%s. This may lead to suboptimal training results for model %s",
                     total_timesteps,
                     min_timesteps,
+                    self.model_type,
                 )
 
         if self.activate_tensorboard:
@@ -624,7 +625,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 return Positions.Neutral
             return position
 
-        frame_buffer: list[NDArray[np.float32]] = []
+        frame_buffer: List[NDArray[np.float32]] = []
 
         def _predict(window) -> int:
             observation: DataFrame = dataframe.iloc[window.index]
@@ -645,7 +646,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             np_observation = observation.to_numpy(dtype=np.float32)
 
-            fb: list[NDArray[np.float32]] = frame_buffer
+            fb: List[NDArray[np.float32]] = frame_buffer
             frame_stacking = self.frame_stacking
             if frame_stacking and frame_stacking > 1:
                 fb.append(np_observation.copy())
@@ -672,7 +673,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
             return int(action)
 
-        predicted_actions: list[int] = []
+        predicted_actions: List[int] = []
         for window_end in range(self.CONV_WIDTH, len(dataframe) + 1):
             window = dataframe.iloc[window_end - self.CONV_WIDTH : window_end]
             action = _predict(window)
@@ -1830,6 +1831,25 @@ class InfoMetricsCallback(TensorboardCallback):
                 logger.error("logger.record retry on stdout failed at %r: %r", key, e)
                 pass
 
+    @staticmethod
+    def build_train_freq(
+        train_freq: Optional[Union[int, Tuple[int, ...], List[int]]],
+    ) -> Optional[int]:
+        train_freq_val: Optional[int] = None
+        try:
+            if isinstance(train_freq, int):
+                train_freq_val = train_freq
+            elif isinstance(train_freq, (tuple, list)) and train_freq:
+                if isinstance(train_freq[0], int):
+                    train_freq_val = train_freq[0]
+            elif hasattr(train_freq, "freq"):
+                freq = getattr(train_freq, "freq")
+                if isinstance(freq, int):
+                    train_freq_val = freq
+        except Exception:
+            train_freq_val = None
+        return train_freq_val
+
     def _on_training_start(self) -> None:
         lr = getattr(self.model, "learning_rate", None)
         lr_schedule, lr_iv, lr_fv = get_schedule_type(lr)
@@ -1886,22 +1906,11 @@ class InfoMetricsCallback(TensorboardCallback):
                     "exploration_rate": float(self.model.exploration_rate),
                 }
             )
-            train_freq = getattr(self.model, "train_freq", None)
-            train_freq_val: int | None = None
-            try:
-                if isinstance(train_freq, int):
-                    train_freq_val = train_freq
-                elif isinstance(train_freq, (tuple, list)) and train_freq:
-                    if isinstance(train_freq[0], int):
-                        train_freq_val = train_freq[0]
-                elif hasattr(train_freq, "freq"):
-                    freq = getattr(train_freq, "freq")
-                    if isinstance(freq, int):
-                        train_freq_val = freq
-            except Exception:
-                train_freq_val = None
-            if train_freq_val is not None:
-                hparam_dict.update({"train_freq": train_freq_val})
+            train_freq = InfoMetricsCallback.build_train_freq(
+                getattr(self.model, "train_freq", None)
+            )
+            if train_freq is not None:
+                hparam_dict.update({"train_freq": train_freq})
             if "QRDQN" in self.model.__class__.__name__:
                 hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)})
         metric_dict: dict[str, float | int] = {
@@ -1960,11 +1969,11 @@ class InfoMetricsCallback(TensorboardCallback):
             except Exception:
                 return False
 
-        infos_list: list[Dict[str, Any]] | None = self.locals.get("infos")
+        infos_list: List[Dict[str, Any]] | None = self.locals.get("infos")
         aggregated_info: Dict[str, Any] = {}
 
         if isinstance(infos_list, list) and infos_list:
-            numeric_acc: Dict[str, list[float]] = defaultdict(list)
+            numeric_acc: Dict[str, List[float]] = defaultdict(list)
             non_numeric_counts: Dict[str, Dict[Any, int]] = defaultdict(
                 lambda: defaultdict(int)
             )
@@ -2412,7 +2421,7 @@ def get_schedule(
 
 def get_net_arch(
     model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
-) -> Union[list[int], Dict[str, list[int]]]:
+) -> Union[List[int], Dict[str, List[int]]]:
     """
     Get network architecture
     """