]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(qav3): trivial cleanup in pivots labeling optimization
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 3 Jun 2025 17:21:23 +0000 (19:21 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 3 Jun 2025 17:21:23 +0000 (19:21 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py

index 8ad20e9605af19ad51b9320556e9eeb048b04dfd..45386a2fd553bb579bcf59f2c553ac76e16e4d38 100644 (file)
@@ -492,6 +492,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 knn_kwargs["p"] = label_p_order
 
             ideal_point = np.ones(n_objectives)
+            ideal_point_2d = ideal_point.reshape(1, -1)
 
             if metric in {
                 "braycurtis",
@@ -528,7 +529,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                     cdist_kwargs["p"] = label_p_order
                 return sp.spatial.distance.cdist(
                     normalized_matrix,
-                    ideal_point.reshape(1, -1),  # reshape ideal_point to 2D
+                    ideal_point_2d,
                     metric=metric,
                     **cdist_kwargs,
                 ).flatten()
@@ -568,7 +569,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 if n_samples == 1:
                     return sp.spatial.distance.cdist(
                         normalized_matrix,
-                        ideal_point.reshape(1, -1),
+                        ideal_point_2d,
                         metric=label_kmeans_metric,
                         **cdist_kwargs,
                     ).flatten()
@@ -580,7 +581,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
                 cluster_centers = kmeans.cluster_centers_
                 cluster_distances_to_ideal = sp.spatial.distance.cdist(
                     cluster_centers,
-                    ideal_point.reshape(1, -1),
+                    ideal_point_2d,
                     metric=label_kmeans_metric,
                     **cdist_kwargs,
                 ).flatten()