sampler = self._optuna_config.get(
"sampler",
)
- if sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.tpe:
- return optuna.samplers.TPESampler(
- n_startup_trials=self._optuna_config.get(
- "n_startup_trials",
- QuickAdapterRegressorV3.OPTUNA_N_STARTUP_TRIALS_DEFAULT,
- ),
- multivariate=True,
- group=True,
- constant_liar=self._optuna_config.get(
- "n_jobs", QuickAdapterRegressorV3.OPTUNA_N_JOBS_DEFAULT
+ match sampler:
+ case None:
+ raise ValueError(
+ f"Invalid optuna sampler value {sampler!r}: "
+ f"supported values are {', '.join(QuickAdapterRegressorV3._OPTUNA_SAMPLERS)}"
)
- > 1,
- seed=self._optuna_config.get(
- "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
- ),
- )
- elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.auto:
- return optunahub.load_module("samplers/auto_sampler").AutoSampler(
- seed=self._optuna_config.get(
- "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+ case QuickAdapterRegressorV3._OPTUNA_SAMPLERS.tpe:
+ return optuna.samplers.TPESampler(
+ n_startup_trials=self._optuna_config.get(
+ "n_startup_trials",
+ QuickAdapterRegressorV3.OPTUNA_N_STARTUP_TRIALS_DEFAULT,
+ ),
+ multivariate=True,
+ group=True,
+ constant_liar=self._optuna_config.get(
+ "n_jobs", QuickAdapterRegressorV3.OPTUNA_N_JOBS_DEFAULT
+ )
+ > 1,
+ seed=self._optuna_config.get(
+ "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+ ),
)
- )
- elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaii:
- return optuna.samplers.NSGAIISampler(
- seed=self._optuna_config.get(
- "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
- ),
- )
- elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaiii:
- return optuna.samplers.NSGAIIISampler(
- seed=self._optuna_config.get(
- "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
- ),
- )
- else:
- raise ValueError(
- f"Invalid optuna sampler value {sampler!r}: "
- f"supported values are {', '.join(QuickAdapterRegressorV3._OPTUNA_SAMPLERS)}"
- )
+ case QuickAdapterRegressorV3._OPTUNA_SAMPLERS.auto:
+ return optunahub.load_module("samplers/auto_sampler").AutoSampler(
+ seed=self._optuna_config.get(
+ "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+ )
+ )
+ case QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaii:
+ return optuna.samplers.NSGAIISampler(
+ seed=self._optuna_config.get(
+ "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+ ),
+ )
+ case QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaiii:
+ return optuna.samplers.NSGAIIISampler(
+ seed=self._optuna_config.get(
+ "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+ ),
+ )
+ case _:
+ assert_never(sampler)
@lru_cache(maxsize=8)
def optuna_samplers_by_namespace(