From 843a616ed7fb4a28cdff944bade0aaa416996024 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Wed, 18 Jun 2025 20:00:15 +0200 Subject: [PATCH] refactor: refine typing MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 6 ++++-- .../user_data/strategies/RLAgentStrategy.py | 20 +++++++++++-------- .../freqaimodels/QuickAdapterRegressorV3.py | 8 ++++---- .../user_data/strategies/QuickAdapterV3.py | 18 ++++++++++------- quickadapter/user_data/strategies/Utils.py | 2 +- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 81e77ca..c2e3734 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -120,7 +120,9 @@ class ReforceXY(BaseReinforcementLearningModel): self.check_envs: bool = self.rl_config.get("check_envs", True) self.progressbar_callback: Optional[ProgressBarCallback] = None # Optuna hyperopt - self.rl_config_optuna: dict = self.freqai_info.get("rl_config_optuna", {}) + self.rl_config_optuna: Dict[str, Any] = self.freqai_info.get( + "rl_config_optuna", {} + ) self.hyperopt: bool = ( self.freqai_info.get("enabled", False) and self.rl_config_optuna.get("enabled", False) @@ -459,7 +461,7 @@ class ReforceXY(BaseReinforcementLearningModel): def _predict(window): observation: DataFrame = dataframe.iloc[window.index] - action_masks_param: dict = {} + action_masks_param: Dict[str, Any] = {} if self.live and self.rl_config.get("add_state_info", False): position, pnl, trade_duration = self.get_state_info(dk.pair) diff --git a/ReforceXY/user_data/strategies/RLAgentStrategy.py b/ReforceXY/user_data/strategies/RLAgentStrategy.py index 15581fd..6903c66 100644 --- a/ReforceXY/user_data/strategies/RLAgentStrategy.py +++ b/ReforceXY/user_data/strategies/RLAgentStrategy.py @@ -1,6 +1,6 @@ import logging from functools import cached_property, reduce -# from typing import Any +from typing import Any # import talib.abstract as ta from pandas import DataFrame @@ -63,14 +63,14 @@ class RLAgentStrategy(IStrategy): return self.is_short_allowed() # def feature_engineering_expand_all( - # self, dataframe: DataFrame, period: int, metadata: dict, **kwargs + # self, dataframe: DataFrame, period: int, metadata: dict[str, Any], **kwargs # ) -> DataFrame: # dataframe["%-rsi-period"] = ta.RSI(dataframe, timeperiod=period) # return dataframe def feature_engineering_expand_basic( - self, dataframe: DataFrame, metadata: dict, **kwargs + self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs ) -> DataFrame: dataframe["%-close_pct_change"] = dataframe.get("close").pct_change() dataframe["%-raw_volume"] = dataframe.get("volume") @@ -78,7 +78,7 @@ class RLAgentStrategy(IStrategy): return dataframe def feature_engineering_standard( - self, dataframe: DataFrame, metadata: dict, **kwargs + self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs ) -> DataFrame: dates = dataframe.get("date") dataframe["%-day_of_week"] = (dates.dt.dayofweek + 1) / 7 @@ -92,18 +92,22 @@ class RLAgentStrategy(IStrategy): return dataframe def set_freqai_targets( - self, dataframe: DataFrame, metadata: dict, **kwargs + self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs ) -> DataFrame: dataframe[ACTION_COLUMN] = 0 return dataframe - def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame: + def populate_indicators( + self, dataframe: DataFrame, metadata: dict[str, Any] + ) -> DataFrame: dataframe = self.freqai.start(dataframe, metadata, self) return dataframe - def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame: + def populate_entry_trend( + self, df: DataFrame, metadata: dict[str, Any] + ) -> DataFrame: enter_long_conditions = [df.get("do_predict") == 1, df.get(ACTION_COLUMN) == 1] df.loc[ @@ -120,7 +124,7 @@ class RLAgentStrategy(IStrategy): return df - def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame: + def populate_exit_trend(self, df: DataFrame, metadata: dict[str, Any]) -> DataFrame: exit_long_conditions = [df.get("do_predict") == 1, df.get(ACTION_COLUMN) == 2] df.loc[reduce(lambda x, y: x & y, exit_long_conditions), "exit_long"] = 1 diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 3469f94..8d9f367 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -53,7 +53,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel): version = "3.7.94" @cached_property - def _optuna_config(self) -> dict: + def _optuna_config(self) -> dict[str, Any]: optuna_default_config = { "enabled": False, "n_jobs": min( @@ -96,9 +96,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): self._optuna_hp_value: dict[str, float] = {} self._optuna_train_value: dict[str, float] = {} self._optuna_label_values: dict[str, list] = {} - self._optuna_hp_params: dict[str, dict] = {} - self._optuna_train_params: dict[str, dict] = {} - self._optuna_label_params: dict[str, dict] = {} + self._optuna_hp_params: dict[str, dict[str, Any]] = {} + self._optuna_train_params: dict[str, dict[str, Any]] = {} + self._optuna_label_params: dict[str, dict[str, Any]] = {} self.init_optuna_label_candle_pool() self._optuna_label_candle: dict[str, int] = {} self._optuna_label_candles: dict[str, int] = {} diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index 16794dc..a2fe46d 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -167,7 +167,7 @@ class QuickAdapterV3(IStrategy): / "models" / self.freqai_info.get("identifier") ) - self._label_params: dict[str, dict] = {} + self._label_params: dict[str, dict[str, Any]] = {} for pair in self.pairs: self._label_params[pair] = ( self.optuna_load_best_params(pair, "label") @@ -194,7 +194,7 @@ class QuickAdapterV3(IStrategy): ) def feature_engineering_expand_all( - self, dataframe: DataFrame, period: int, metadata: dict, **kwargs + self, dataframe: DataFrame, period: int, metadata: dict[str, Any], **kwargs ) -> DataFrame: highs = dataframe.get("high") lows = dataframe.get("low") @@ -234,7 +234,7 @@ class QuickAdapterV3(IStrategy): return dataframe def feature_engineering_expand_basic( - self, dataframe: DataFrame, metadata: dict, **kwargs + self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs ) -> DataFrame: highs = dataframe.get("high") lows = dataframe.get("low") @@ -407,7 +407,7 @@ class QuickAdapterV3(IStrategy): raise ValueError(f"Invalid pattern '{pattern}': {e}") def set_freqai_targets( - self, dataframe: DataFrame, metadata: dict, **kwargs + self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs ) -> DataFrame: pair = str(metadata.get("pair")) label_period_candles = self.get_label_period_candles(pair) @@ -445,7 +445,9 @@ class QuickAdapterV3(IStrategy): logger.info(f"{n_extrema=}") return dataframe - def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame: + def populate_indicators( + self, dataframe: DataFrame, metadata: dict[str, Any] + ) -> DataFrame: dataframe = self.freqai.start(dataframe, metadata, self) dataframe["DI_catch"] = np.where( @@ -470,7 +472,9 @@ class QuickAdapterV3(IStrategy): return dataframe - def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame: + def populate_entry_trend( + self, df: DataFrame, metadata: dict[str, Any] + ) -> DataFrame: enter_long_conditions = [ df.get("do_predict") == 1, df.get("DI_catch") == 1, @@ -495,7 +499,7 @@ class QuickAdapterV3(IStrategy): return df - def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame: + def populate_exit_trend(self, df: DataFrame, metadata: dict[str, Any]) -> DataFrame: return df def get_trade_entry_date(self, trade: Trade) -> datetime.datetime: diff --git a/quickadapter/user_data/strategies/Utils.py b/quickadapter/user_data/strategies/Utils.py index 8e31e37..40b5611 100644 --- a/quickadapter/user_data/strategies/Utils.py +++ b/quickadapter/user_data/strategies/Utils.py @@ -160,7 +160,7 @@ def calculate_zero_lag(series: pd.Series, period: int) -> pd.Series: def get_ma_fn(mamode: str) -> Callable[[pd.Series, int], np.ndarray]: - mamodes: dict = { + mamodes: dict[str, Callable[[pd.Series, int], np.ndarray]] = { "sma": ta.SMA, "ema": ta.EMA, "wma": ta.WMA, -- 2.43.0