From ae4fb46bc4557a3009972523b37cb1a04452e450 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 26 May 2025 19:21:58 +0200 Subject: [PATCH] feat(qav3): add a few metrics for trials selection MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../freqaimodels/QuickAdapterRegressorV3.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 6b2a394..c7f77ee 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -406,7 +406,16 @@ class QuickAdapterRegressorV3(BaseRegressionModel): n_objectives = len(study.directions) label_metric = self.ft_params.get("label_metric", "euclidean") - metrics = {"euclidean", "chebyshev", "manhattan", "minkowski"} + metrics = { + "euclidean", + "chebyshev", + "manhattan", + "minkowski", + "canberra", + "braycurtis", + "hellinger", + "geometric_mean", + } if label_metric not in metrics: raise ValueError( f"Unsupported label metric: {label_metric}. Supported metrics are {', '.join(metrics)}" @@ -451,6 +460,31 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ), 1.0 / p_order, ) + elif metric == "canberra": + return np.sum( + np.abs(normalized_matrix - ideal_point) + / (np.abs(normalized_matrix) + np.abs(ideal_point)), + axis=1, + ) + elif metric == "braycurtis": + return np.divide( + np.sum(np.abs(normalized_matrix - ideal_point), axis=1), + np.sum(normalized_matrix + ideal_point, axis=1), + out=np.zeros(normalized_matrix.shape[0], dtype=float), + where=(np.sum(normalized_matrix + ideal_point, axis=1) != 0), + ) + elif metric == "hellinger": + return np.sqrt(np.sum((np.sqrt(normalized_matrix) - 1.0) ** 2, axis=1)) + elif metric == "geometric_mean": + return ( + np.array([]) + if normalized_matrix.shape[1] == 0 + else 1.0 + - ( + np.prod(normalized_matrix, axis=1) + ** (1.0 / normalized_matrix.shape[1]) + ) + ) else: raise ValueError(f"Unsupported distance metric: {metric}") -- 2.43.0