"harmonic_mean",
"power_mean",
"weighted_sum",
- "d1",
- "d2",
+ "knn-d1",
+ "knn-d2-mean",
+ "knn-d2-median",
+ "knn-d2-max",
}
if label_metric not in metrics:
raise ValueError(
)
if np_weights.size != normalized_matrix.shape[1]:
raise ValueError("label_weights length must match number of objectives")
+ label_knn_metric = self.ft_params.get("label_knn_metric", "euclidean")
+ knn_kwargs = {}
+ if label_knn_metric == "minkowski" and isinstance(p_order, float):
+ knn_kwargs["p"] = p_order
ideal_point = np.ones(normalized_matrix.shape[1])
if metric in {
) - sp.stats.pmean(normalized_matrix, p=p, weights=np_weights, axis=1)
elif metric == "weighted_sum":
return np.sum(np_weights * (ideal_point - normalized_matrix), axis=1)
- elif metric == "d1":
+ elif metric == "knn-d1":
if normalized_matrix.shape[0] < 2:
return np.full(normalized_matrix.shape[0], np.inf)
- nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=2).fit(
- normalized_matrix
- )
+ nbrs = sklearn.neighbors.NearestNeighbors(
+ n_neighbors=2, metric=label_knn_metric, **knn_kwargs
+ ).fit(normalized_matrix)
distances, _ = nbrs.kneighbors(normalized_matrix)
return distances[:, 1]
- elif metric == "d2":
+ elif metric in {"knn-d2-mean", "knn-d2-median", "knn-d2-max"}:
if normalized_matrix.shape[0] < 2:
return np.full(normalized_matrix.shape[0], np.inf)
- k = min(4, normalized_matrix.shape[0] - 1) + 1
- nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k).fit(
- normalized_matrix
+ n_neighbors = (
+ min(
+ int(self.ft_params.get("label_knn_d2_min_n_neighbors", "4")),
+ normalized_matrix.shape[0] - 1,
+ )
+ + 1
)
+ nbrs = sklearn.neighbors.NearestNeighbors(
+ n_neighbors=n_neighbors, metric=label_knn_metric, **knn_kwargs
+ ).fit(normalized_matrix)
distances, _ = nbrs.kneighbors(normalized_matrix)
- return np.mean(distances[:, 1:], axis=1)
+ if metric == "knn-d2-mean":
+ return np.mean(distances[:, 1:], axis=1)
+ elif metric == "knn-d2-median":
+ return np.median(distances[:, 1:], axis=1)
+ elif metric == "knn-d2-max":
+ return np.max(distances[:, 1:], axis=1)
else:
raise ValueError(f"Unsupported distance metric: {metric}")