"n_startup_trials", 15
)
self.optuna_callback: Optional[MaskableTrialEvalCallback] = None
+ self._model_params_cache: Optional[Dict[str, Any]] = None
self.unset_unsupported()
def unset_unsupported(self) -> None:
"""
Get model parameters
"""
+ if self._model_params_cache is not None:
+ return copy.deepcopy(self._model_params_cache)
+
model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters)
if self.lr_schedule:
net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
if "PPO" in self.model_type:
- if not isinstance(net_arch, dict):
+ if isinstance(net_arch, str):
+ net_arch = get_net_arch(self.model_type, net_arch)
+ if isinstance(net_arch, dict):
+ model_params["policy_kwargs"]["net_arch"] = net_arch
+ else:
+ model_params["policy_kwargs"]["net_arch"] = {
+ "pi": net_arch,
+ "vf": net_arch,
+ }
+ elif isinstance(net_arch, list):
model_params["policy_kwargs"]["net_arch"] = {
"pi": net_arch,
"vf": net_arch,
}
+ elif isinstance(net_arch, dict):
+ if not ("pi" in net_arch and "vf" in net_arch):
+ model_params["policy_kwargs"]["net_arch"] = {
+ "pi": net_arch.get("pi", net_arch.get("vf", [128, 128])),
+ "vf": net_arch.get("vf", net_arch.get("pi", [128, 128])),
+ }
+ else:
+ model_params["policy_kwargs"]["net_arch"] = net_arch
else:
model_params["policy_kwargs"]["net_arch"] = net_arch
model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
model_params.get("policy_kwargs", {}).get("optimizer_class", "adam")
)
-
- return model_params
+ self._model_params_cache = copy.deepcopy(model_params)
+ return copy.deepcopy(self._model_params_cache)
def get_callbacks(
self, eval_freq: int, data_path: str, trial: Trial = None
exploration_initial_eps = trial.suggest_float(
"exploration_initial_eps", exploration_final_eps, 1.0
)
+ if exploration_initial_eps >= 0.9:
+ min_fraction = 0.2
+ elif (exploration_initial_eps - exploration_final_eps) > 0.5:
+ min_fraction = 0.15
+ else:
+ min_fraction = 0.05
exploration_fraction = trial.suggest_float(
- "exploration_fraction", 0.05, 0.9, step=0.02
+ "exploration_fraction", min_fraction, 0.9, step=0.02
)
buffer_size = trial.suggest_categorical(
"buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]