else:
model_params["policy_kwargs"]["net_arch"] = net_arch
- model_params["policy_kwargs"]["activation_fn"] = model_params[
- "policy_kwargs"
- ].get("activation_fn", "relu")
+ model_params["policy_kwargs"]["activation_fn"] = get_activation_fn(
+ model_params["policy_kwargs"].get("activation_fn", "relu")
+ )
- model_params["policy_kwargs"]["optimizer_class"] = model_params[
- "policy_kwargs"
- ].get("optimizer_class", "adam")
+ model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
+ model_params["policy_kwargs"].get("optimizer_class", "adam")
+ )
return model_params