From 217683c562efdb321f2615fafe424181c41cdc09 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Mon, 10 Feb 2025 13:02:31 +0100 Subject: [PATCH] perf(qav3): optimize optuna memory usage MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../LightGBMRegressorQuickAdapterV35.py | 14 +++++++++++++- .../XGBoostRegressorQuickAdapterV35.py | 14 +++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py index d2251fa..5b4a9fb 100644 --- a/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py @@ -67,8 +67,19 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): start = time.time() if self.__optuna_hyperopt: + storage_dir, study_name = str(dk.full_path).rsplit("/", 1) pruner = optuna.pruners.HyperbandPruner() - study = optuna.create_study(pruner=pruner, direction="minimize") + study = optuna.create_study( + study_name=study_name, + sampler=optuna.samplers.TPESampler( + multivariate=True, + group=True, + ), + pruner=pruner, + direction=optuna.study.StudyDirection.MINIMIZE, + storage=f"sqlite:///{storage_dir}/optuna-lgbm.sqlite", + load_if_exists=True, + ) study.optimize( lambda trial: objective( trial, @@ -86,6 +97,7 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel): n_trials=self.__optuna_config.get("n_trials", N_TRIALS), n_jobs=self.__optuna_config.get("n_jobs", 1), timeout=self.__optuna_config.get("timeout", 3600), + gc_after_trial=True, ) self.__optuna_hp = study.best_params diff --git a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py index 9eb323d..0f00dff 100644 --- a/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py +++ b/quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py @@ -67,8 +67,19 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): start = time.time() if self.__optuna_hyperopt: + storage_dir, study_name = str(dk.full_path).rsplit("/", 1) pruner = optuna.pruners.HyperbandPruner() - study = optuna.create_study(pruner=pruner, direction="minimize") + study = optuna.create_study( + study_name=study_name, + sampler=optuna.samplers.TPESampler( + multivariate=True, + group=True, + ), + pruner=pruner, + direction=optuna.study.StudyDirection.MINIMIZE, + storage=f"sqlite:///{storage_dir}/optuna-xgboost.sqlite", + load_if_exists=True, + ) study.optimize( lambda trial: objective( trial, @@ -86,6 +97,7 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel): n_trials=self.__optuna_config.get("n_trials", N_TRIALS), n_jobs=self.__optuna_config.get("n_jobs", 1), timeout=self.__optuna_config.get("timeout", 3600), + gc_after_trial=True, ) self.__optuna_hp = study.best_params -- 2.43.0