From f3850c66e6d794f70b7b35c2691f3b8837185f41 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 10 Apr 2025 17:05:23 +0200 Subject: [PATCH] refactor(qav3): cleanup multi objective best trial heuristic MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../freqaimodels/QuickAdapterRegressorV3.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 7015783..0141b16 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -389,31 +389,39 @@ class QuickAdapterRegressorV3(BaseRegressionModel): return None best_trials = study.best_trials if namespace == "label": - range_sizes = [trial.values[1] for trial in best_trials] - median_range_size = np.median(range_sizes) + peaks_sizes = [trial.values[1] for trial in best_trials] + median_peaks_size = np.median(peaks_sizes) equal_median_trials = [ trial for trial in best_trials - if np.isclose(trial.values[1], median_range_size) + if np.isclose(trial.values[1], median_peaks_size) ] if equal_median_trials: return max(equal_median_trials, key=lambda trial: trial.values[0]) - nearest_above_median = (np.inf, -np.inf, None) - nearest_below_median = (-np.inf, -np.inf, None) + nearest_above_median = ( + np.inf, + -np.inf, + None, + ) # (trial_peaks_size, trial_peaks_range, trial_index) + nearest_below_median = ( + -np.inf, + -np.inf, + None, + ) # (trial_peaks_size, trial_peaks_range, trial_index) for idx, trial in enumerate(best_trials): - range_size = trial.values[1] - if range_size >= median_range_size: - if range_size < nearest_above_median[0] or ( - range_size == nearest_above_median[0] - and trial.values[0] > nearest_above_median[1] + peaks_size = trial.values[1] + if peaks_size >= median_peaks_size: + if peaks_size < nearest_above_median[0] or ( + peaks_size == nearest_above_median[0] + and trial.values[1] > nearest_above_median[1] ): - nearest_above_median = (range_size, trial.values[0], idx) - if range_size <= median_range_size: - if range_size > nearest_below_median[0] or ( - range_size == nearest_below_median[0] - and trial.values[0] > nearest_below_median[1] + nearest_above_median = (peaks_size, trial.values[1], idx) + if peaks_size <= median_peaks_size: + if peaks_size > nearest_below_median[0] or ( + peaks_size == nearest_below_median[0] + and trial.values[1] > nearest_below_median[1] ): - nearest_below_median = (range_size, trial.values[0], idx) + nearest_below_median = (peaks_size, trial.values[1], idx) if nearest_above_median[2] is None or nearest_below_median[2] is None: return None above_median_trial = best_trials[nearest_above_median[2]] -- 2.43.0