From 0df47c96092f80054971c3d664c3857b7a099776 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Tue, 3 Jun 2025 19:21:23 +0200 Subject: [PATCH] refactor(qav3): trivial cleanup in pivots labeling optimization MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../user_data/freqaimodels/QuickAdapterRegressorV3.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 8ad20e9..45386a2 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -492,6 +492,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): knn_kwargs["p"] = label_p_order ideal_point = np.ones(n_objectives) + ideal_point_2d = ideal_point.reshape(1, -1) if metric in { "braycurtis", @@ -528,7 +529,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): cdist_kwargs["p"] = label_p_order return sp.spatial.distance.cdist( normalized_matrix, - ideal_point.reshape(1, -1), # reshape ideal_point to 2D + ideal_point_2d, metric=metric, **cdist_kwargs, ).flatten() @@ -568,7 +569,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): if n_samples == 1: return sp.spatial.distance.cdist( normalized_matrix, - ideal_point.reshape(1, -1), + ideal_point_2d, metric=label_kmeans_metric, **cdist_kwargs, ).flatten() @@ -580,7 +581,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): cluster_centers = kmeans.cluster_centers_ cluster_distances_to_ideal = sp.spatial.distance.cdist( cluster_centers, - ideal_point.reshape(1, -1), + ideal_point_2d, metric=label_kmeans_metric, **cdist_kwargs, ).flatten() -- 2.43.0