From b5b6a961e2fe2a8905efc28a9ac5418f05009831 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sun, 23 Nov 2025 22:59:46 +0100 Subject: [PATCH] refactor: improve type hints MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/.devcontainer/requirements-dev.txt | 3 +++ ReforceXY/user_data/freqaimodels/ReforceXY.py | 12 +++++++--- .../.devcontainer/requirements-dev.txt | 2 ++ .../freqaimodels/QuickAdapterRegressorV3.py | 5 +++-- quickadapter/user_data/strategies/Utils.py | 22 +++++++++++++++++-- 5 files changed, 37 insertions(+), 7 deletions(-) diff --git a/ReforceXY/.devcontainer/requirements-dev.txt b/ReforceXY/.devcontainer/requirements-dev.txt index af3ee57..73e7736 100644 --- a/ReforceXY/.devcontainer/requirements-dev.txt +++ b/ReforceXY/.devcontainer/requirements-dev.txt @@ -1 +1,4 @@ +pandas-stubs +scipy-stubs +uv ruff diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 78c31b6..3f22e9c 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -52,6 +52,7 @@ from optuna.storages import ( ) from optuna.storages.journal import JournalFileBackend from optuna.study import Study, StudyDirection +from optuna.study.study import ObjectiveFuncType from pandas import DataFrame, merge from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback from sb3_contrib.common.maskable.utils import is_masking_supported @@ -1159,8 +1160,11 @@ class ReforceXY(BaseReinforcementLearningModel): hyperopt_failed = False start_time = time.time() try: + objective: ObjectiveFuncType = lambda trial: self.objective( + trial, dk, total_timesteps + ) study.optimize( - lambda trial: self.objective(trial, dk, total_timesteps), + objective, n_trials=self.optuna_n_trials, timeout=( hours_to_seconds(self.optuna_timeout_hours) @@ -1285,7 +1289,7 @@ class ReforceXY(BaseReinforcementLearningModel): prices_train, prices_test = self.build_ohlc_price_dataframes( dk.data_dictionary, dk.pair, dk ) - seed = self.get_model_params().get("seed", 42) if seed is None else seed + seed: int = self.get_model_params().get("seed", 42) if seed is None else seed if trial is not None: seed += trial.number set_random_seed(seed) @@ -1549,7 +1553,9 @@ class MyRLEnv(Base5ActionRLEnv): self._total_entry_additive: float = 0.0 self._last_exit_additive: float = 0.0 self._total_exit_additive: float = 0.0 - model_reward_parameters = self.rl_config.get("model_reward_parameters", {}) + model_reward_parameters: Dict[str, Any] = self.rl_config.get( + "model_reward_parameters", {} + ) self.max_trade_duration_candles: int = int( model_reward_parameters.get( "max_trade_duration_candles", diff --git a/quickadapter/.devcontainer/requirements-dev.txt b/quickadapter/.devcontainer/requirements-dev.txt index af3ee57..4c0d70d 100644 --- a/quickadapter/.devcontainer/requirements-dev.txt +++ b/quickadapter/.devcontainer/requirements-dev.txt @@ -1 +1,3 @@ +pandas-stubs +scipy-stubs ruff diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 4b5481b..65f7c66 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -18,6 +18,7 @@ import sklearn from freqtrade.freqai.base_models.BaseRegressionModel import BaseRegressionModel from freqtrade.freqai.data_kitchen import FreqaiDataKitchen from numpy.typing import NDArray +from optuna.study.study import ObjectiveFuncType from sklearn_extra.cluster import KMedoids from Utils import ( @@ -560,7 +561,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): self, pair: str, namespace: str, - callback: Callable[[], None], + callback: Callable[[], Optional[optuna.study.Study]], ) -> None: if namespace not in { QuickAdapterRegressorV3._OPTUNA_NAMESPACES[2] @@ -1637,7 +1638,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): self, pair: str, namespace: str, - objective: Callable[[optuna.trial.Trial], float], + objective: ObjectiveFuncType, direction: Optional[optuna.study.StudyDirection] = None, directions: Optional[list[optuna.study.StudyDirection]] = None, ) -> Optional[optuna.study.Study]: diff --git a/quickadapter/user_data/strategies/Utils.py b/quickadapter/user_data/strategies/Utils.py index 3a6ede8..5c3bb12 100644 --- a/quickadapter/user_data/strategies/Utils.py +++ b/quickadapter/user_data/strategies/Utils.py @@ -1144,7 +1144,18 @@ REGRESSORS: Final[tuple[Regressor, ...]] = ("xgboost", "lightgbm") def get_optuna_callbacks( trial: optuna.trial.Trial, regressor: Regressor -) -> list[Callable[[optuna.trial.Trial, str], None]]: +) -> list[ + Union[ + optuna.integration.XGBoostPruningCallback, + optuna.integration.LightGBMPruningCallback, + ] +]: + callbacks: list[ + Union[ + optuna.integration.XGBoostPruningCallback, + optuna.integration.LightGBMPruningCallback, + ] + ] if regressor == REGRESSORS[0]: # "xgboost" callbacks = [ optuna.integration.XGBoostPruningCallback(trial, "validation_0-rmse") @@ -1167,7 +1178,14 @@ def fit_regressor( eval_weights: Optional[list[NDArray[np.floating]]], model_training_parameters: dict[str, Any], init_model: Any = None, - callbacks: Optional[list[Callable[[optuna.trial.Trial, str], None]]] = None, + callbacks: Optional[ + list[ + Union[ + optuna.integration.XGBoostPruningCallback, + optuna.integration.LightGBMPruningCallback, + ] + ] + ] = None, trial: Optional[optuna.trial.Trial] = None, ) -> Any: if regressor == REGRESSORS[0]: # "xgboost" -- 2.43.0