]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
perf(qav3): improve MO optimization support
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 28 Sep 2025 18:31:38 +0000 (20:31 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 28 Sep 2025 18:31:38 +0000 (20:31 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 4abea1dc0723f08ef596302cf5f75122ee8fb6f0..c77abe966d5a81503179efbad9a8a6cf2ee81911 100644 (file)
@@ -61,7 +61,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     https://github.com/sponsors/robcaulk
     """
 
-    version = "3.7.114"
+    version = "3.7.115"
 
     @cached_property
     def _optuna_config(self) -> dict[str, Any]:
@@ -395,6 +395,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     ) -> None:
         if namespace != "label":
             raise ValueError(f"Invalid namespace: {namespace}")
+        if not callable(callback):
+            raise ValueError("callback must be callable")
         self._optuna_label_candles[pair] += 1
         if pair not in self._optuna_label_incremented_pairs:
             self._optuna_label_incremented_pairs.append(pair)
@@ -845,6 +847,26 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             ideal_point = np.ones(n_objectives)
             ideal_point_2d = ideal_point.reshape(1, -1)
 
+            def _get_n_clusters(
+                matrix: NDArray[np.floating],
+                *,
+                min_n_clusters: int = 2,
+                max_n_clusters: int = 10,
+            ) -> int:
+                n_samples = matrix.shape[0]
+                if n_samples <= 1:
+                    return 1
+                n_uniques = np.unique(matrix, axis=0).shape[0]
+                upper_bound = max(1, min(max_n_clusters, n_uniques, n_samples))
+                lower_bound = max(2, min(min_n_clusters, upper_bound))
+                if upper_bound < 2:
+                    return 1
+                try:
+                    n_clusters = int(round(np.log2(max(n_samples, 2))))
+                except Exception:
+                    n_clusters = min_n_clusters
+                return max(lower_bound, min(n_clusters, upper_bound))
+
             if metric in {
                 # "braycurtis",
                 # "canberra",
@@ -950,7 +972,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                     return np.array([0.0])
                 if n_samples < 2:
                     return np.full(n_samples, np.inf)
-                n_clusters = min(max(2, int(np.sqrt(n_samples / 2))), 10, n_samples)
+                n_clusters = _get_n_clusters(normalized_matrix)
                 if metric == "kmeans":
                     kmeans = sklearn.cluster.KMeans(
                         n_clusters=n_clusters, random_state=42, n_init=10
@@ -1028,7 +1050,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                     return np.array([0.0])
                 if n_samples < 2:
                     return np.full(n_samples, np.inf)
-                n_clusters = min(max(2, int(np.sqrt(n_samples / 2))), 10, n_samples)
+                n_clusters = _get_n_clusters(normalized_matrix)
                 label_kmedoids_metric = self.ft_params.get(
                     "label_kmedoids_metric", "euclidean"
                 )
@@ -1274,7 +1296,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             metric_log_msg = ""
         else:
             try:
-                best_trial = self.get_multi_objective_study_best_trial("label", study)
+                best_trial = self.get_multi_objective_study_best_trial(namespace, study)
             except Exception as e:
                 logger.error(
                     f"Optuna {pair} {namespace} {objective_type} objective hyperopt failed ({time_spent:.2f} secs): {repr(e)}",
@@ -1363,6 +1385,19 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         if continuous:
             QuickAdapterRegressorV3.optuna_study_delete(study_name, storage)
 
+        is_study_single_objective = direction is not None and directions is None
+        if (
+            not is_study_single_objective
+            and isinstance(directions, list)
+            and len(directions) < 2
+        ):
+            raise ValueError(
+                "Multi-objective study must have at least 2 directions specified"
+            )
+        if is_study_single_objective:
+            pruner = optuna.pruners.HyperbandPruner(min_resource=3)
+        else:
+            pruner = optuna.pruners.NopPruner()
         try:
             return optuna.create_study(
                 study_name=study_name,
@@ -1372,7 +1407,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                     group=True,
                     seed=self._optuna_config.get("seed"),
                 ),
-                pruner=optuna.pruners.HyperbandPruner(min_resource=3),
+                pruner=pruner,
                 direction=direction,
                 directions=directions,
                 storage=storage,
@@ -1504,7 +1539,7 @@ def train_objective(
         logger.warning(
             f"Insufficient test data: {test_length} < {min_test_period_candles}"
         )
-        test_ok = False
+        return np.inf
     max_test_period_candles: int = test_length
     test_period_candles: int = trial.suggest_int(
         "test_period_candles",
@@ -1543,7 +1578,7 @@ def train_objective(
         logger.warning(
             f"Insufficient train data: {train_length} < {min_train_period_candles}"
         )
-        train_ok = False
+        return np.inf
     max_train_period_candles: int = train_length
     train_period_candles: int = trial.suggest_int(
         "train_period_candles",
index 6c43055c7a37d229b4150cc13e72d4293ae63c26..5b050aaab393d454a979f7a7122301f9e160c40a 100644 (file)
@@ -788,6 +788,8 @@ class QuickAdapterV3(IStrategy):
         current_time: datetime.datetime,
         callback: Callable[[], None],
     ) -> None:
+        if not callable(callback):
+            raise ValueError("callback must be callable")
         timestamp = int(current_time.timestamp())
         candle_start_secs = timestamp - (timestamp % self._candle_duration_secs)
         if candle_start_secs != self.last_candle_start_secs.get(pair):