)
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.logger import Figure, HParam
+from stable_baselines3.common.type_aliases import TrainFreq
from stable_baselines3.common.utils import ConstantSchedule, set_random_seed
from stable_baselines3.common.vec_env import (
DummyVecEnv,
@staticmethod
def _build_train_freq(
- train_freq: Optional[Union[int, Tuple[int, ...], List[int]]],
+ train_freq: Optional[Union[int, Tuple[int], List[int]]],
) -> Optional[int]:
train_freq_val: Optional[int] = None
- try:
- if isinstance(train_freq, int):
- train_freq_val = train_freq
- elif isinstance(train_freq, (tuple, list)) and train_freq:
- if isinstance(train_freq[0], int):
- train_freq_val = train_freq[0]
- elif hasattr(train_freq, "freq"):
- freq = getattr(train_freq, "freq")
- if isinstance(freq, int):
- train_freq_val = freq
- except Exception:
- train_freq_val = None
+ if isinstance(train_freq, TrainFreq):
+ if isinstance(train_freq.freq, int):
+ train_freq_val = train_freq.freq
+ elif isinstance(train_freq, (tuple, list)) and train_freq:
+ if isinstance(train_freq[0], int):
+ train_freq_val = train_freq[0]
+ elif isinstance(train_freq, int):
+ train_freq_val = train_freq
+
return train_freq_val
def _on_training_start(self) -> None:
def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int:
tf: Optional[int] = None
+ if isinstance(train_freq, TrainFreq):
+ tf = train_freq.freq if isinstance(train_freq.freq, int) else None
if isinstance(train_freq, (tuple, list)) and train_freq:
tf = train_freq[0] if isinstance(train_freq[0], int) else None
elif isinstance(train_freq, int):