def get_optimizer_class(
- optimizer_class_name: Literal["adam"],
+ optimizer_class_name: Literal["adam", "adamw"],
) -> type[th.optim.Optimizer]:
"""
Get optimizer class
"""
return {
"adam": th.optim.Adam,
+ "adamw": th.optim.AdamW,
}.get(optimizer_class_name, th.optim.Adam)
"activation_fn": trial.suggest_categorical(
"activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
),
- "optimizer_class": trial.suggest_categorical("optimizer_class", ["adam"]),
+ "optimizer_class": trial.suggest_categorical("optimizer_class", ["adamw"]),
},
)