From 5e29309fc37134deb1579bc92b9d91119dc2f401 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Wed, 24 Sep 2025 23:05:18 +0200 Subject: [PATCH] refactor(reforcexy): cleanup train frequency handling MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index d931920..cbf6f1a 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -48,6 +48,7 @@ from stable_baselines3.common.callbacks import ( ) from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.logger import Figure, HParam +from stable_baselines3.common.type_aliases import TrainFreq from stable_baselines3.common.utils import ConstantSchedule, set_random_seed from stable_baselines3.common.vec_env import ( DummyVecEnv, @@ -1925,21 +1926,18 @@ class InfoMetricsCallback(TensorboardCallback): @staticmethod def _build_train_freq( - train_freq: Optional[Union[int, Tuple[int, ...], List[int]]], + train_freq: Optional[Union[int, Tuple[int], List[int]]], ) -> Optional[int]: train_freq_val: Optional[int] = None - try: - if isinstance(train_freq, int): - train_freq_val = train_freq - elif isinstance(train_freq, (tuple, list)) and train_freq: - if isinstance(train_freq[0], int): - train_freq_val = train_freq[0] - elif hasattr(train_freq, "freq"): - freq = getattr(train_freq, "freq") - if isinstance(freq, int): - train_freq_val = freq - except Exception: - train_freq_val = None + if isinstance(train_freq, TrainFreq): + if isinstance(train_freq.frequency, int): + train_freq_val = train_freq.frequency + elif isinstance(train_freq, (tuple, list)) and train_freq: + if isinstance(train_freq[0], int): + train_freq_val = train_freq[0] + 
elif isinstance(train_freq, int): + train_freq_val = train_freq + return train_freq_val def _on_training_start(self) -> None: @@ -2449,6 +2447,8 @@ def _compute_gradient_steps(tf: int, ss: int) -> int: def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int: tf: Optional[int] = None + if isinstance(train_freq, TrainFreq): + tf = train_freq.frequency if isinstance(train_freq.frequency, int) else None if isinstance(train_freq, (tuple, list)) and train_freq: tf = train_freq[0] if isinstance(train_freq[0], int) else None elif isinstance(train_freq, int): -- 2.43.0