From: Jérôme Benoit Date: Tue, 23 Sep 2025 20:50:43 +0000 (+0200) Subject: refactor(qav3): factor out pairwise distance sums computation X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=dead0a59307573394ab018e30d85a2b682c66fa9;p=freqai-strategies.git refactor(qav3): factor out pairwise distance sums computation Signed-off-by: Jérôme Benoit --- diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 2468a7f..8d2c468 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -704,6 +704,50 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) return np.median(values) + @staticmethod + def _pairwise_distance_sums( + matrix: NDArray[np.floating], + metric: str, + *, + weights: Optional[NDArray[np.floating]] = None, + p: Optional[float] = None, + ) -> NDArray[np.floating]: + """Return, for each sample, the sum of its distances to all other samples. + + Typical usage: representative (e.g., medoid) selection by taking argmin of + the returned vector. Function is generic and not tied to the medoid concept. + + Parameters: + - matrix: 2D array (n_samples, n_features) assumed already normalized. + - metric: distance metric name accepted by scipy.spatial.distance.cdist. + - weights: optional weight vector (broadcast via cdist 'w' parameter) for metrics supporting it. + - p: optional Minkowski order when metric == 'minkowski'. + + Notes: + - Caller must validate metric compatibility (e.g. exclude mahalanobis / seuclidean / jensenshannon here). + - Behavior for n_samples in {0,1} is handled to preserve previous semantics. + """ + if matrix.ndim != 2: + raise ValueError("matrix must be 2-dimensional") + if matrix.shape[1] == 0: + raise ValueError("matrix must have at least one feature") + if matrix.shape[0] == 1: + return np.array([0.0]) + if matrix.shape[0] < 2: + return np.full(matrix.shape[0], np.inf) + cdist_kwargs: dict[str, Any] = {} + if weights is not None: + cdist_kwargs["w"] = weights + if metric == "minkowski" and p is not None: + cdist_kwargs["p"] = p + pairwise_distances = sp.spatial.distance.cdist( + matrix, + matrix, + metric=metric, + **cdist_kwargs, + ) + return np.sum(pairwise_distances, axis=1) + def get_multi_objective_study_best_trial( self, namespace: str, study: optuna.study.Study ) -> Optional[optuna.trial.FrozenTrial]: @@ -888,18 +932,17 @@ class QuickAdapterRegressorV3(BaseRegressionModel): raise ValueError( f"Unsupported label_medoid_metric: {label_medoid_metric}. Supported are euclidean/minkowski/cityblock/chebyshev/..." ) - cdist_kwargs: dict[str, Any] = {"w": np_weights} - if label_medoid_metric == "minkowski": - cdist_kwargs["p"] = ( - label_p_order if label_p_order is not None else 2.0 - ) - pairwise_distances = sp.spatial.distance.cdist( - normalized_matrix, + return self._pairwise_distance_sums( normalized_matrix, - metric=label_medoid_metric, - **cdist_kwargs, + label_medoid_metric, + weights=np_weights, + p=( + label_p_order + if label_medoid_metric == "minkowski" + and label_p_order is not None + else None + ), ) - return np.sum(pairwise_distances, axis=1) elif metric in {"kmeans", "kmeans2"}: if n_samples == 1: return np.array([0.0]) @@ -952,14 +995,17 @@ class QuickAdapterRegressorV3(BaseRegressionModel): if best_cluster_indices is not None and best_cluster_indices.size > 0: if label_kmeans_selection == "medoid": best_cluster_matrix = normalized_matrix[best_cluster_indices] - pairwise_distances = sp.spatial.distance.cdist( - best_cluster_matrix, - best_cluster_matrix, - metric=label_kmeans_metric, - **cdist_kwargs, - ) - trial_distances[best_cluster_indices] = np.sum( - pairwise_distances, axis=1 + trial_distances[best_cluster_indices] = ( + self._pairwise_distance_sums( + best_cluster_matrix, + label_kmeans_metric, + p=( + label_p_order + if label_kmeans_metric == "minkowski" + and label_p_order is not None + else None + ), + ) ) elif label_kmeans_selection == "min": trial_distances[best_cluster_indices] = (