ClassVar,
Final,
Literal,
+ NamedTuple,
Optional,
Union,
assert_never,
)
+class _OptunaSamplers(NamedTuple):
+ tpe: Literal["tpe"] = "tpe"
+ auto: Literal["auto"] = "auto"
+ nsgaii: Literal["nsgaii"] = "nsgaii"
+ nsgaiii: Literal["nsgaiii"] = "nsgaiii"
+
+
+class _OptunaHpoSamplers(NamedTuple):
+ tpe: Literal["tpe"] = "tpe"
+ auto: Literal["auto"] = "auto"
+
+
+class _OptunaLabelSamplers(NamedTuple):
+ auto: Literal["auto"] = "auto"
+ tpe: Literal["tpe"] = "tpe"
+ nsgaii: Literal["nsgaii"] = "nsgaii"
+ nsgaiii: Literal["nsgaiii"] = "nsgaiii"
+
+
class QuickAdapterRegressorV3(BaseRegressionModel):
"""
The following freqaimodel is released to sponsors of the non-profit FreqAI open-source project.
optuna.study.StudyDirection.MAXIMIZE,
) * _OPTUNA_LABEL_N_OBJECTIVES
_OPTUNA_STORAGE_BACKENDS: Final[tuple[str, ...]] = ("file", "sqlite")
- _OPTUNA_SAMPLERS: Final[tuple[OptunaSampler, ...]] = (
- "tpe",
- "auto",
- "nsgaii",
- "nsgaiii",
- )
- _OPTUNA_HPO_SAMPLERS: Final[tuple[OptunaSampler, ...]] = _OPTUNA_SAMPLERS[:2]
+ _OPTUNA_SAMPLERS: Final[_OptunaSamplers] = _OptunaSamplers()
+ _OPTUNA_HPO_SAMPLERS: Final[_OptunaHpoSamplers] = _OptunaHpoSamplers()
_OPTUNA_HPO_SAMPLERS_SET: Final[frozenset[OptunaSampler]] = frozenset(
_OPTUNA_HPO_SAMPLERS
)
- _OPTUNA_LABEL_SAMPLERS: Final[tuple[OptunaSampler, ...]] = (
- _OPTUNA_SAMPLERS[1], # "auto"
- _OPTUNA_SAMPLERS[0], # "tpe"
- _OPTUNA_SAMPLERS[2], # "nsgaii"
- _OPTUNA_SAMPLERS[3], # "nsgaiii"
- )
+ _OPTUNA_LABEL_SAMPLERS: Final[_OptunaLabelSamplers] = _OptunaLabelSamplers()
_OPTUNA_LABEL_SAMPLERS_SET: Final[frozenset[OptunaSampler]] = frozenset(
_OPTUNA_LABEL_SAMPLERS
)
.get("n_jobs", QuickAdapterRegressorV3.OPTUNA_N_JOBS_DEFAULT),
max(int(self.max_system_threads / 4), 1),
),
- "sampler": QuickAdapterRegressorV3._OPTUNA_HPO_SAMPLERS[0], # "tpe"
+ "sampler": QuickAdapterRegressorV3._OPTUNA_HPO_SAMPLERS.tpe,
"storage": QuickAdapterRegressorV3._OPTUNA_STORAGE_BACKENDS[0], # "file"
"continuous": True,
"warm_start": True,
"n_startup_trials": QuickAdapterRegressorV3.OPTUNA_N_STARTUP_TRIALS_DEFAULT,
"n_trials": QuickAdapterRegressorV3.OPTUNA_N_TRIALS_DEFAULT,
"timeout": QuickAdapterRegressorV3.OPTUNA_TIMEOUT_DEFAULT,
- "label_sampler": QuickAdapterRegressorV3._OPTUNA_LABEL_SAMPLERS[
- 0
- ], # "auto"
+ "label_sampler": QuickAdapterRegressorV3._OPTUNA_LABEL_SAMPLERS.auto,
"label_candles_step": QuickAdapterRegressorV3.OPTUNA_LABEL_CANDLES_STEP_DEFAULT,
"space_reduction": QuickAdapterRegressorV3.OPTUNA_SPACE_REDUCTION_DEFAULT,
"space_fraction": QuickAdapterRegressorV3.OPTUNA_SPACE_FRACTION_DEFAULT,
sampler = self._optuna_config.get(
"sampler",
)
- if sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS[0]: # "tpe"
+ if sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.tpe:
return optuna.samplers.TPESampler(
n_startup_trials=self._optuna_config.get(
"n_startup_trials",
"seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
),
)
- elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS[1]: # "auto"
+ elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.auto:
return optunahub.load_module("samplers/auto_sampler").AutoSampler(
seed=self._optuna_config.get(
"seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
)
)
- elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS[2]: # "nsgaii"
+ elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaii:
return optuna.samplers.NSGAIISampler(
seed=self._optuna_config.get(
"seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
),
)
- elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS[3]: # "nsgaiii"
+ elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaiii:
return optuna.samplers.NSGAIIISampler(
seed=self._optuna_config.get(
"seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
return (
QuickAdapterRegressorV3._OPTUNA_HPO_SAMPLERS_SET,
self._optuna_config.get(
- "sampler", QuickAdapterRegressorV3._OPTUNA_HPO_SAMPLERS[0]
+ "sampler", QuickAdapterRegressorV3._OPTUNA_HPO_SAMPLERS.tpe
),
)
elif namespace == _OPTUNA_NAMESPACES.label:
return (
QuickAdapterRegressorV3._OPTUNA_LABEL_SAMPLERS_SET,
self._optuna_config.get(
- "label_sampler", QuickAdapterRegressorV3._OPTUNA_LABEL_SAMPLERS[0]
+ "label_sampler", QuickAdapterRegressorV3._OPTUNA_LABEL_SAMPLERS.auto
),
)
else: