]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(quickadapter): dispatch optuna_create_sampler via match + assert_never ...
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 22 Jun 2026 10:09:19 +0000 (12:09 +0200)
committerGitHub <noreply@github.com>
Mon, 22 Jun 2026 10:09:19 +0000 (12:09 +0200)
Follow-up to #101. Converts the `optuna_create_sampler` if/elif/else dispatch chain to a `match`/`case` statement using value patterns (`QuickAdapterRegressorV3._OPTUNA_SAMPLERS.<name>`) -- the per-field singleton `Literal[...]` typing introduced by #101 unlocks pyright/mypy exhaustiveness narrowing, so the final `case _: assert_never(sampler)` type-checks as `Never` and catches any future extension of `OptunaSampler` that forgets to add a corresponding match branch.

The `case None:` branch is structurally required (not stylistic): without it, after the four `Literal[...]` value patterns, pyright/mypy would narrow `sampler` to `None`, not `Never`, and `assert_never(sampler)` would fail to type-check. Its presence is what makes the "5th sampler addition -> type error at assert_never" claim effective. The error message inside `case None:` is preserved verbatim from the prior `else: raise ValueError(...)` for user-facing wire compatibility.

Behavior delta is confined to non-`Literal` string inputs (e.g. dynamically-injected `"garbage"`): prior code raised `ValueError` with the supported-values list; new code raises `AssertionError` with the standard `typing.assert_never` message. In practice this path is unreachable from every current in-repo call site -- the sole caller `optuna_create_study` already validates `sampler not in samplers` and raises `ValueError` with the supported-values list before dispatching, so misconfigured `_optuna_config["sampler"]` values surface the old error shape from the upstream gate. The new `AssertionError` would only manifest via a hypothetical future caller that bypasses `optuna_create_study`, which does not exist today.

Pattern parity: the same `assert_never` exhaustiveness idiom is already used in this file for `support_policy` dispatch; value-pattern syntax matches the 8 call sites migrated by #101.

Closes the deferred follow-up from PR #101 (issue #88).

quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py

index 261fbd3553a5d1606965e0afd35927528999bfaf..be1741316f3fe08a0153d8982f52da15330a98f5 100644 (file)
@@ -4243,45 +4243,48 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             sampler = self._optuna_config.get(
                 "sampler",
             )
-        if sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.tpe:
-            return optuna.samplers.TPESampler(
-                n_startup_trials=self._optuna_config.get(
-                    "n_startup_trials",
-                    QuickAdapterRegressorV3.OPTUNA_N_STARTUP_TRIALS_DEFAULT,
-                ),
-                multivariate=True,
-                group=True,
-                constant_liar=self._optuna_config.get(
-                    "n_jobs", QuickAdapterRegressorV3.OPTUNA_N_JOBS_DEFAULT
+        match sampler:
+            case None:
+                raise ValueError(
+                    f"Invalid optuna sampler value {sampler!r}: "
+                    f"supported values are {', '.join(QuickAdapterRegressorV3._OPTUNA_SAMPLERS)}"
                 )
-                > 1,
-                seed=self._optuna_config.get(
-                    "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
-                ),
-            )
-        elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.auto:
-            return optunahub.load_module("samplers/auto_sampler").AutoSampler(
-                seed=self._optuna_config.get(
-                    "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+            case QuickAdapterRegressorV3._OPTUNA_SAMPLERS.tpe:
+                return optuna.samplers.TPESampler(
+                    n_startup_trials=self._optuna_config.get(
+                        "n_startup_trials",
+                        QuickAdapterRegressorV3.OPTUNA_N_STARTUP_TRIALS_DEFAULT,
+                    ),
+                    multivariate=True,
+                    group=True,
+                    constant_liar=self._optuna_config.get(
+                        "n_jobs", QuickAdapterRegressorV3.OPTUNA_N_JOBS_DEFAULT
+                    )
+                    > 1,
+                    seed=self._optuna_config.get(
+                        "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+                    ),
                 )
-            )
-        elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaii:
-            return optuna.samplers.NSGAIISampler(
-                seed=self._optuna_config.get(
-                    "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
-                ),
-            )
-        elif sampler == QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaiii:
-            return optuna.samplers.NSGAIIISampler(
-                seed=self._optuna_config.get(
-                    "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
-                ),
-            )
-        else:
-            raise ValueError(
-                f"Invalid optuna sampler value {sampler!r}: "
-                f"supported values are {', '.join(QuickAdapterRegressorV3._OPTUNA_SAMPLERS)}"
-            )
+            case QuickAdapterRegressorV3._OPTUNA_SAMPLERS.auto:
+                return optunahub.load_module("samplers/auto_sampler").AutoSampler(
+                    seed=self._optuna_config.get(
+                        "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+                    )
+                )
+            case QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaii:
+                return optuna.samplers.NSGAIISampler(
+                    seed=self._optuna_config.get(
+                        "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+                    ),
+                )
+            case QuickAdapterRegressorV3._OPTUNA_SAMPLERS.nsgaiii:
+                return optuna.samplers.NSGAIIISampler(
+                    seed=self._optuna_config.get(
+                        "seed", QuickAdapterRegressorV3.OPTUNA_SEED_DEFAULT
+                    ),
+                )
+            case _:
+                assert_never(sampler)
 
     @lru_cache(maxsize=8)
     def optuna_samplers_by_namespace(