From 59b94bedd8a37bd7972dfe9875017b39b1326299 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 10 Feb 2025 21:21:40 +0100 Subject: [PATCH] refactor(qav3): fit label mean and std MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 13 +++++++++---- .../freqaimodels/XGBoostRegressorQuickAdapterV35.py | 13 +++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index c1af399..cdf4a34 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -199,10 +199,15 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): ] dk.data["labels_mean"], dk.data["labels_std"] = {}, {} - for ft in dk.label_list: - # f = spy.stats.norm.fit(pred_df_full[ft]) - dk.data["labels_std"][ft] = 0 # f[1] - dk.data["labels_mean"][ft] = 0 # f[0] + for label in dk.label_list + dk.unique_class_list: + if pred_df_full[label].dtype == object: + continue + if not warmed_up: + f = [0, 0] + else: + f = spy.stats.norm.fit(pred_df_full[label]) + dk.data["labels_mean"][label] = f[0] + dk.data["labels_std"][label] = f[1] # fit the DI_threshold if not warmed_up: diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index f2c28f8..9b8d629 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -199,10 +199,15 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): ] dk.data["labels_mean"], dk.data["labels_std"] = {}, {} - for ft in dk.label_list: - # f = spy.stats.norm.fit(pred_df_full[ft]) - dk.data["labels_std"][ft] = 0 # f[1] - dk.data["labels_mean"][ft] = 0 # f[0] + for label in dk.label_list + dk.unique_class_list: + if pred_df_full[label].dtype == object: + continue + if not warmed_up: + f = [0, 0] + else: + f = spy.stats.norm.fit(pred_df_full[label]) + dk.data["labels_mean"][label] = f[0] + dk.data["labels_std"][label] = f[1] # fit the DI_threshold if not warmed_up: -- 2.43.0