logger.info("Clip range linear schedule enabled, initial value: %s", cr)
if "DQN" in self.model_type:
- gradient_steps = int(model_params.get("gradient_steps"))
+ gradient_steps = model_params.get("gradient_steps")
if gradient_steps is None:
+ gradient_steps = int(gradient_steps)
train_freq = model_params.get("train_freq")
if isinstance(train_freq, (tuple, list)) and train_freq:
train_freq = (
train_freq, max(train_freq // subsample_steps, 1)
)
else:
- model_params["gradient_steps"] = 1
+ model_params["gradient_steps"] = -1
if not model_params.get("policy_kwargs"):
model_params["policy_kwargs"] = {}