From 56ef81edef4412f79df2ee129e7c1ee7a703b3d2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 29 May 2025 11:17:48 +0200 Subject: [PATCH] refactor(qav3): add tunables for knn based trials selection MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../freqaimodels/QuickAdapterRegressorV3.py | 39 +++++++++++++------ 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 8980184..40d5094 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -437,8 +437,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel): "harmonic_mean", "power_mean", "weighted_sum", - "d1", - "d2", + "knn-d1", + "knn-d2-mean", + "knn-d2-median", + "knn-d2-max", } if label_metric not in metrics: raise ValueError( @@ -470,6 +472,10 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) if np_weights.size != normalized_matrix.shape[1]: raise ValueError("label_weights length must match number of objectives") + label_knn_metric = self.ft_params.get("label_knn_metric", "euclidean") + knn_kwargs = {} + if label_knn_metric == "minkowski" and isinstance(p_order, float): + knn_kwargs["p"] = p_order ideal_point = np.ones(normalized_matrix.shape[1]) if metric in { @@ -533,23 +539,34 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) - sp.stats.pmean(normalized_matrix, p=p, weights=np_weights, axis=1) elif metric == "weighted_sum": return np.sum(np_weights * (ideal_point - normalized_matrix), axis=1) - elif metric == "d1": + elif metric == "knn-d1": if normalized_matrix.shape[0] < 2: return np.full(normalized_matrix.shape[0], np.inf) - nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=2).fit( - normalized_matrix - ) + nbrs = sklearn.neighbors.NearestNeighbors( + n_neighbors=2, metric=label_knn_metric, **knn_kwargs + ).fit(normalized_matrix) distances, _ = nbrs.kneighbors(normalized_matrix) return distances[:, 1] - elif metric == "d2": + elif metric in {"knn-d2-mean", "knn-d2-median", "knn-d2-max"}: if normalized_matrix.shape[0] < 2: return np.full(normalized_matrix.shape[0], np.inf) - k = min(4, normalized_matrix.shape[0] - 1) + 1 - nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k).fit( - normalized_matrix + n_neighbors = ( + min( + int(self.ft_params.get("label_knn_d2_min_n_neighbors", "4")), + normalized_matrix.shape[0] - 1, + ) + + 1 ) + nbrs = sklearn.neighbors.NearestNeighbors( + n_neighbors=n_neighbors, metric=label_knn_metric, **knn_kwargs + ).fit(normalized_matrix) distances, _ = nbrs.kneighbors(normalized_matrix) - return np.mean(distances[:, 1:], axis=1) + if metric == "knn-d2-mean": + return np.mean(distances[:, 1:], axis=1) + elif metric == "knn-d2-median": + return np.median(distances[:, 1:], axis=1) + elif metric == "knn-d2-max": + return np.max(distances[:, 1:], axis=1) else: raise ValueError(f"Unsupported distance metric: {metric}") -- 2.43.0