]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): cleanup train frequency handling
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 24 Sep 2025 21:05:18 +0000 (23:05 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 24 Sep 2025 21:05:18 +0000 (23:05 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index d93192052c5ade6a4b279e3ff69ba341c4817e89..cbf6f1a594e1e320d6d54bcbcb5e381dbbf53de8 100644 (file)
@@ -48,6 +48,7 @@ from stable_baselines3.common.callbacks import (
 )
 from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.logger import Figure, HParam
+from stable_baselines3.common.type_aliases import TrainFreq
 from stable_baselines3.common.utils import ConstantSchedule, set_random_seed
 from stable_baselines3.common.vec_env import (
     DummyVecEnv,
@@ -1925,21 +1926,18 @@ class InfoMetricsCallback(TensorboardCallback):
 
     @staticmethod
     def _build_train_freq(
-        train_freq: Optional[Union[int, Tuple[int, ...], List[int]]],
+        train_freq: Optional[Union[int, Tuple[int], List[int]]],
     ) -> Optional[int]:
         train_freq_val: Optional[int] = None
-        try:
-            if isinstance(train_freq, int):
-                train_freq_val = train_freq
-            elif isinstance(train_freq, (tuple, list)) and train_freq:
-                if isinstance(train_freq[0], int):
-                    train_freq_val = train_freq[0]
-            elif hasattr(train_freq, "freq"):
-                freq = getattr(train_freq, "freq")
-                if isinstance(freq, int):
-                    train_freq_val = freq
-        except Exception:
-            train_freq_val = None
+        if isinstance(train_freq, TrainFreq):
+            if isinstance(train_freq.freq, int):
+                train_freq_val = train_freq.freq
+        elif isinstance(train_freq, (tuple, list)) and train_freq:
+            if isinstance(train_freq[0], int):
+                train_freq_val = train_freq[0]
+        elif isinstance(train_freq, int):
+            train_freq_val = train_freq
+
         return train_freq_val
 
     def _on_training_start(self) -> None:
@@ -2449,6 +2447,8 @@ def _compute_gradient_steps(tf: int, ss: int) -> int:
 
 def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int:
     tf: Optional[int] = None
+    if isinstance(train_freq, TrainFreq):
+        tf = train_freq.freq if isinstance(train_freq.freq, int) else None
     if isinstance(train_freq, (tuple, list)) and train_freq:
         tf = train_freq[0] if isinstance(train_freq[0], int) else None
     elif isinstance(train_freq, int):