From a64b6c68b15c205996ea50b59dd622baf94b1475 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Tue, 3 Jun 2025 17:22:12 +0200 Subject: [PATCH] refactor(qav3): ensure label_weights are L1 normalized MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../user_data/freqaimodels/QuickAdapterRegressorV3.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 27a48e6..8ad20e9 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -480,6 +480,12 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) if np_weights.size != n_objectives: raise ValueError("label_weights length must match number of objectives") + if np.any(np_weights < 0): + raise ValueError("label_weights values must be non-negative") + label_weights_sum = np.sum(np_weights) + if np.isclose(label_weights_sum, 0): + raise ValueError("label_weights sum cannot be zero") + np_weights = np_weights / label_weights_sum knn_kwargs = {} label_knn_metric = self.ft_params.get("label_knn_metric", "euclidean") if label_knn_metric == "minkowski" and isinstance(label_p_order, float): -- 2.43.0