]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
feat(qav3): add standardized hellinger distance support to MO trial
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 28 May 2025 18:08:03 +0000 (20:08 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 28 May 2025 18:08:03 +0000 (20:08 +0200)
selection

Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index b94572c38b9f1dcaf5f1f676da7eb15ceff99578..2392359eb5a9a0b0468d80d5aaf3b09c0c25d911 100644 (file)
@@ -45,7 +45,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     https://github.com/sponsors/robcaulk
     """
 
-    version = "3.7.72"
+    version = "3.7.73"
 
     @cached_property
     def _optuna_config(self) -> dict:
@@ -432,6 +432,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             "sqeuclidean",
             "yule",
             "hellinger",
+            "shellinger",
             "geometric_mean",
             "harmonic_mean",
             "power_mean",
@@ -517,7 +518,17 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                         * (np.sqrt(normalized_matrix) - np.sqrt(ideal_point)) ** 2,
                         axis=1,
                     )
-                )
+                ) / np.sqrt(2.0)
+            elif metric == "shellinger":
+                np_sqrt_normalized_matrix = np.sqrt(normalized_matrix)
+                np_weights = 1 / np.var(np_sqrt_normalized_matrix, axis=0, ddof=1)
+                return np.sqrt(
+                    np.sum(
+                        np_weights
+                        * (np_sqrt_normalized_matrix - np.sqrt(ideal_point)) ** 2,
+                        axis=1,
+                    )
+                ) / np.sqrt(2.0)
             elif metric in {"geometric_mean", "harmonic_mean", "power_mean"}:
                 p = {
                     "geometric_mean": 0.0,
index aa47db635f174e6ea78163f9c005028e81c87050..e7afb108be230ddb958150a78a5d1f03d1a783b9 100644 (file)
@@ -60,7 +60,7 @@ class QuickAdapterV3(IStrategy):
     INTERFACE_VERSION = 3
 
     def version(self) -> str:
-        return "3.3.75"
+        return "3.3.76"
 
     timeframe = "5m"