From 482a123043f44ad8423d46d656ddf389302077a5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 26 May 2025 22:51:22 +0200 Subject: [PATCH] refactor(qav3): more explicit code at MO trial selection MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../freqaimodels/QuickAdapterRegressorV3.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 4c3db58..fc3ff6e 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -465,7 +465,11 @@ class QuickAdapterRegressorV3(BaseRegressionModel): normalized_matrix - ideal_point, ord=order, axis=1 ) elif metric == "hellinger": - return np.sqrt(np.sum((np.sqrt(normalized_matrix) - 1.0) ** 2, axis=1)) + return np.sqrt( + np.sum( + (np.sqrt(normalized_matrix) - np.sqrt(ideal_point)) ** 2, axis=1 + ) + ) elif metric == "geometric_mean": return 1.0 - np.prod(normalized_matrix, axis=1) ** ( 1.0 / normalized_matrix.shape[1] @@ -480,7 +484,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): raise ValueError( "label_weights length must match number of objectives" ) - return np.sum(np.array(weights) * (1.0 - normalized_matrix), axis=1) + return np.sum( + np.array(weights) * (ideal_point - normalized_matrix), axis=1 + ) elif metric == "tchebichev": weights = self.ft_params.get( "label_weights", [1.0] * normalized_matrix.shape[1] @@ -489,7 +495,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): raise ValueError( "label_weights length must match number of objectives" ) - return np.max(np.array(weights) * (1.0 - normalized_matrix), axis=1) + return np.max( + np.array(weights) * (ideal_point - normalized_matrix), axis=1 + ) elif metric == "mahalanobis": if normalized_matrix.shape[0] < 2: return np.linalg.norm( -- 2.43.0