]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor: refine typing
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 18 Jun 2025 18:00:15 +0000 (20:00 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 18 Jun 2025 18:00:15 +0000 (20:00 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
ReforceXY/user_data/strategies/RLAgentStrategy.py
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py
quickadapter/user_data/strategies/Utils.py

index 81e77ca1881517a1080610103f52aaaef362b008..c2e37346852ea65e2cda6efe33866d0287ff017c 100644 (file)
@@ -120,7 +120,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         self.check_envs: bool = self.rl_config.get("check_envs", True)
         self.progressbar_callback: Optional[ProgressBarCallback] = None
         # Optuna hyperopt
-        self.rl_config_optuna: dict = self.freqai_info.get("rl_config_optuna", {})
+        self.rl_config_optuna: Dict[str, Any] = self.freqai_info.get(
+            "rl_config_optuna", {}
+        )
         self.hyperopt: bool = (
             self.freqai_info.get("enabled", False)
             and self.rl_config_optuna.get("enabled", False)
@@ -459,7 +461,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         def _predict(window):
             observation: DataFrame = dataframe.iloc[window.index]
-            action_masks_param: dict = {}
+            action_masks_param: Dict[str, Any] = {}
 
             if self.live and self.rl_config.get("add_state_info", False):
                 position, pnl, trade_duration = self.get_state_info(dk.pair)
index 15581fd722768c7bffbc09e6b884074412afec3c..6903c6613a54662a0ac1b332c6ddb609b85551b6 100644 (file)
@@ -1,6 +1,6 @@
 import logging
 from functools import cached_property, reduce
-from typing import Any
+from typing import Any
 
 # import talib.abstract as ta
 from pandas import DataFrame
@@ -63,14 +63,14 @@ class RLAgentStrategy(IStrategy):
         return self.is_short_allowed()
 
     # def feature_engineering_expand_all(
-    #     self, dataframe: DataFrame, period: int, metadata: dict, **kwargs
+    #     self, dataframe: DataFrame, period: int, metadata: dict[str, Any], **kwargs
     # ) -> DataFrame:
     #     dataframe["%-rsi-period"] = ta.RSI(dataframe, timeperiod=period)
 
     #     return dataframe
 
     def feature_engineering_expand_basic(
-        self, dataframe: DataFrame, metadata: dict, **kwargs
+        self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs
     ) -> DataFrame:
         dataframe["%-close_pct_change"] = dataframe.get("close").pct_change()
         dataframe["%-raw_volume"] = dataframe.get("volume")
@@ -78,7 +78,7 @@ class RLAgentStrategy(IStrategy):
         return dataframe
 
     def feature_engineering_standard(
-        self, dataframe: DataFrame, metadata: dict, **kwargs
+        self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs
     ) -> DataFrame:
         dates = dataframe.get("date")
         dataframe["%-day_of_week"] = (dates.dt.dayofweek + 1) / 7
@@ -92,18 +92,22 @@ class RLAgentStrategy(IStrategy):
         return dataframe
 
     def set_freqai_targets(
-        self, dataframe: DataFrame, metadata: dict, **kwargs
+        self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs
     ) -> DataFrame:
         dataframe[ACTION_COLUMN] = 0
 
         return dataframe
 
-    def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
+    def populate_indicators(
+        self, dataframe: DataFrame, metadata: dict[str, Any]
+    ) -> DataFrame:
         dataframe = self.freqai.start(dataframe, metadata, self)
 
         return dataframe
 
-    def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
+    def populate_entry_trend(
+        self, df: DataFrame, metadata: dict[str, Any]
+    ) -> DataFrame:
         enter_long_conditions = [df.get("do_predict") == 1, df.get(ACTION_COLUMN) == 1]
 
         df.loc[
@@ -120,7 +124,7 @@ class RLAgentStrategy(IStrategy):
 
         return df
 
-    def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
+    def populate_exit_trend(self, df: DataFrame, metadata: dict[str, Any]) -> DataFrame:
         exit_long_conditions = [df.get("do_predict") == 1, df.get(ACTION_COLUMN) == 2]
         df.loc[reduce(lambda x, y: x & y, exit_long_conditions), "exit_long"] = 1
 
index 3469f946c31b459bd1dc171493833d47c3e3d8f4..8d9f3674b79768c7c02ac8be84d357106a90ead0 100644 (file)
@@ -53,7 +53,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     version = "3.7.94"
 
     @cached_property
-    def _optuna_config(self) -> dict:
+    def _optuna_config(self) -> dict[str, Any]:
         optuna_default_config = {
             "enabled": False,
             "n_jobs": min(
@@ -96,9 +96,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
         self._optuna_hp_value: dict[str, float] = {}
         self._optuna_train_value: dict[str, float] = {}
         self._optuna_label_values: dict[str, list] = {}
-        self._optuna_hp_params: dict[str, dict] = {}
-        self._optuna_train_params: dict[str, dict] = {}
-        self._optuna_label_params: dict[str, dict] = {}
+        self._optuna_hp_params: dict[str, dict[str, Any]] = {}
+        self._optuna_train_params: dict[str, dict[str, Any]] = {}
+        self._optuna_label_params: dict[str, dict[str, Any]] = {}
         self.init_optuna_label_candle_pool()
         self._optuna_label_candle: dict[str, int] = {}
         self._optuna_label_candles: dict[str, int] = {}
index 16794dcc60418dbe8e41a4ed061a32c2461659e5..a2fe46d52f6f3d410cff7d5bdf962bab329ff88e 100644 (file)
@@ -167,7 +167,7 @@ class QuickAdapterV3(IStrategy):
             / "models"
             / self.freqai_info.get("identifier")
         )
-        self._label_params: dict[str, dict] = {}
+        self._label_params: dict[str, dict[str, Any]] = {}
         for pair in self.pairs:
             self._label_params[pair] = (
                 self.optuna_load_best_params(pair, "label")
@@ -194,7 +194,7 @@ class QuickAdapterV3(IStrategy):
         )
 
     def feature_engineering_expand_all(
-        self, dataframe: DataFrame, period: int, metadata: dict, **kwargs
+        self, dataframe: DataFrame, period: int, metadata: dict[str, Any], **kwargs
     ) -> DataFrame:
         highs = dataframe.get("high")
         lows = dataframe.get("low")
@@ -234,7 +234,7 @@ class QuickAdapterV3(IStrategy):
         return dataframe
 
     def feature_engineering_expand_basic(
-        self, dataframe: DataFrame, metadata: dict, **kwargs
+        self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs
     ) -> DataFrame:
         highs = dataframe.get("high")
         lows = dataframe.get("low")
@@ -407,7 +407,7 @@ class QuickAdapterV3(IStrategy):
             raise ValueError(f"Invalid pattern '{pattern}': {e}")
 
     def set_freqai_targets(
-        self, dataframe: DataFrame, metadata: dict, **kwargs
+        self, dataframe: DataFrame, metadata: dict[str, Any], **kwargs
     ) -> DataFrame:
         pair = str(metadata.get("pair"))
         label_period_candles = self.get_label_period_candles(pair)
@@ -445,7 +445,9 @@ class QuickAdapterV3(IStrategy):
             logger.info(f"{n_extrema=}")
         return dataframe
 
-    def populate_indicators(self, dataframe: DataFrame, metadata: dict) -> DataFrame:
+    def populate_indicators(
+        self, dataframe: DataFrame, metadata: dict[str, Any]
+    ) -> DataFrame:
         dataframe = self.freqai.start(dataframe, metadata, self)
 
         dataframe["DI_catch"] = np.where(
@@ -470,7 +472,9 @@ class QuickAdapterV3(IStrategy):
 
         return dataframe
 
-    def populate_entry_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
+    def populate_entry_trend(
+        self, df: DataFrame, metadata: dict[str, Any]
+    ) -> DataFrame:
         enter_long_conditions = [
             df.get("do_predict") == 1,
             df.get("DI_catch") == 1,
@@ -495,7 +499,7 @@ class QuickAdapterV3(IStrategy):
 
         return df
 
-    def populate_exit_trend(self, df: DataFrame, metadata: dict) -> DataFrame:
+    def populate_exit_trend(self, df: DataFrame, metadata: dict[str, Any]) -> DataFrame:
         return df
 
     def get_trade_entry_date(self, trade: Trade) -> datetime.datetime:
index 8e31e37eacd3b232f40c5c19fb6641c99176f293..40b5611ac78c1d6a397f0dcd1c4483716fe3237f 100644 (file)
@@ -160,7 +160,7 @@ def calculate_zero_lag(series: pd.Series, period: int) -> pd.Series:
 
 
 def get_ma_fn(mamode: str) -> Callable[[pd.Series, int], np.ndarray]:
-    mamodes: dict = {
+    mamodes: dict[str, Callable[[pd.Series, int], np.ndarray]] = {
         "sma": ta.SMA,
         "ema": ta.EMA,
         "wma": ta.WMA,