else:
raise NotImplementedError
+ if "DQN" in self.model_type:
+ batch_size = params.get("batch_size")
+ gradient_steps = params.get("gradient_steps")
+ buffer_size = params.get("buffer_size")
+ if (batch_size * gradient_steps) > buffer_size:
+ raise TrialPruned(
+ "batch_size * gradient_steps is greater than buffer_size"
+ )
+
# Ensure that the sampled parameters take precedence
params = deepmerge(self.get_model_params(), params)
n_steps = trial.suggest_categorical("n_steps", list(PPO_N_STEPS))
batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
if batch_size > n_steps:
- raise TrialPruned("batch_size must be less than or equal to n_steps")
+ raise TrialPruned("batch_size is greater than n_steps")
if (n_steps * n_envs) % batch_size != 0:
- raise TrialPruned("(n_steps * n_envs) not divisible by batch_size")
+ raise TrialPruned("n_steps * n_envs is not divisible by batch_size")
return convert_optuna_params_to_model_params(
"PPO",
{
"learning_starts", [500, 1000, 2000, 3000, 4000, 5000, 8000, 10000]
)
if learning_starts >= buffer_size:
- raise TrialPruned("learning_starts must be less than buffer_size")
+ raise TrialPruned("learning_starts is greater than or equal to buffer_size")
return {
"train_freq": trial.suggest_categorical(
"train_freq", [2, 4, 8, 16, 128, 256, 512, 1024]