else self.get_storage()
)
eval_freq = max(1, len(train_df) // self.n_envs)
+ max_resource = max(1, (total_timesteps + eval_freq - 1) // eval_freq)
+ min_resource = min(3, max_resource)
study: Study = create_study(
study_name=study_name,
sampler=TPESampler(
seed=self.rl_config_optuna.get("seed", 42),
),
pruner=HyperbandPruner(
- min_resource=3,
- max_resource=(total_timesteps + eval_freq - 1) // eval_freq,
+ min_resource=min_resource,
+ max_resource=max_resource,
reduction_factor=3,
),
direction=StudyDirection.MAXIMIZE,
Defines a single trial for hyperparameter optimization using Optuna
"""
if "PPO" in self.model_type:
- params = sample_params_ppo(trial)
+ params = sample_params_ppo(trial, self.n_envs)
if params.get("n_steps", 0) > total_timesteps:
- raise TrialPruned("n_steps exceeds total_timesteps")
+ raise TrialPruned("n_steps is greater than total_timesteps")
elif "QRDQN" in self.model_type:
params = sample_params_qrdqn(trial)
elif "DQN" in self.model_type:
try:
model.learn(total_timesteps=total_timesteps, callback=callbacks)
except AssertionError:
- logger.warning("Optuna encountered NaN")
+ logger.warning("Optuna encountered NaN (AssertionError)")
nan_encountered = True
+ except ValueError as e:
+ if "NaN" in str(e):
+ logger.warning("Optuna encountered NaN (ValueError)")
+ nan_encountered = True
+ else:
+ raise
finally:
if self.progressbar_callback:
self.progressbar_callback.on_training_end()
model.env.close()
if nan_encountered:
- return np.nan
+ raise TrialPruned("NaN encountered during training")
if self.optuna_callback.is_pruned:
raise TrialPruned()
return model_params
-def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
+def sample_params_ppo(trial: Trial, n_envs: int) -> Dict[str, Any]:
"""
Sampler for PPO hyperparams
"""
n_steps = trial.suggest_categorical("n_steps", [512, 1024, 2048, 4096])
batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512, 1024])
+ if batch_size > n_steps:
+ raise TrialPruned("batch_size must be less than or equal to n_steps")
+ if (n_steps * n_envs) % batch_size != 0:
+ raise TrialPruned("(n_steps * n_envs) not divisible by batch_size")
return convert_optuna_params_to_model_params(
"PPO",
{
"buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]
)
learning_starts = trial.suggest_categorical(
- "learning_starts", [500, 1000, 2000, 3000, 4000, 5000]
+ "learning_starts", [500, 1000, 2000, 3000, 4000, 5000, 8000, 10000]
)
+ if learning_starts >= buffer_size:
+ raise TrialPruned("learning_starts must be less than buffer_size")
return {
"train_freq": train_freq,
- "gradient_steps": max(
- train_freq // trial.suggest_categorical("subsample_steps", [2, 4, 8]), 1
+ "gradient_steps": min(
+ train_freq,
+ max(
+ train_freq // trial.suggest_categorical("subsample_steps", [2, 4, 8]),
+ 1,
+ ),
),
"gamma": trial.suggest_categorical(
"gamma", [0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.997, 0.999, 0.9999]