self.progressbar_callback = ProgressBarCallback()
callbacks.append(self.progressbar_callback)
- self.eval_callback = MaskableEvalCallback(
- self.eval_env,
- eval_freq=eval_freq,
- deterministic=True,
- render=False,
- best_model_save_path=data_path,
- use_masking=self.is_maskable,
- callback_on_new_best=rollout_plot_callback,
- callback_after_eval=no_improvement_callback,
- verbose=verbose,
- )
- callbacks.append(self.eval_callback)
-
if not trial:
- return callbacks
-
- self.optuna_callback = MaskableTrialEvalCallback(
- self.eval_env,
- trial,
- eval_freq=eval_freq,
- deterministic=True,
- render=False,
- best_model_save_path=data_path,
- use_masking=self.is_maskable,
- verbose=verbose,
- )
- callbacks.append(self.optuna_callback)
+ self.eval_callback = MaskableEvalCallback(
+ self.eval_env,
+ eval_freq=eval_freq,
+ deterministic=True,
+ render=False,
+ best_model_save_path=data_path,
+ use_masking=self.is_maskable,
+ callback_on_new_best=rollout_plot_callback,
+ callback_after_eval=no_improvement_callback,
+ verbose=verbose,
+ )
+ callbacks.append(self.eval_callback)
+ else:
+ self.optuna_callback = MaskableTrialEvalCallback(
+ self.eval_env,
+ trial,
+ eval_freq=eval_freq,
+ deterministic=True,
+ render=False,
+ best_model_save_path=data_path,
+ use_masking=self.is_maskable,
+ verbose=verbose,
+ )
+ callbacks.append(self.optuna_callback)
return callbacks
def fit(