]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(qav3): refine typing
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 18 Jun 2025 15:10:19 +0000 (17:10 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 18 Jun 2025 15:10:19 +0000 (17:10 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
ReforceXY/user_data/strategies/RLAgentStrategy.py
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py
quickadapter/user_data/strategies/Utils.py

index d6933042984311fe844931a1ad960be04b8ea3c8..81e77ca1881517a1080610103f52aaaef362b008 100644 (file)
@@ -215,11 +215,11 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise ValueError("Frame stacking requires predefined observation shape")
         self.eval_env = VecMonitor(eval_env)
 
-    def get_model_params(self) -> Dict:
+    def get_model_params(self) -> Dict[str, Any]:
         """
         Get model parameters
         """
-        model_params: Dict = copy.deepcopy(self.model_training_parameters)
+        model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters)
 
         if self.lr_schedule:
             _lr = model_params.get("learning_rate", 0.0003)
@@ -311,7 +311,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         callbacks.append(self.optuna_callback)
         return callbacks
 
-    def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs):
+    def fit(
+        self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs
+    ) -> Any:
         """
         User customizable fit method
         :param data_dictionary: dict = common data dictionary containing all train/test
@@ -525,7 +527,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def study(
         self, train_df: DataFrame, total_timesteps: int, dk: FreqaiDataKitchen
-    ) -> Optional[Dict]:
+    ) -> Optional[Dict[str, Any]]:
         """
         Runs hyperparameter optimization using Optuna and
         returns the best hyperparameters found merged with the user defined parameters
@@ -736,7 +738,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         return self.optuna_callback.best_mean_reward
 
-    def close_envs(self):
+    def close_envs(self) -> None:
         """
         Closes the training and evaluation environments if they are open
         """
@@ -775,9 +777,9 @@ class ReforceXY(BaseReinforcementLearningModel):
             df: DataFrame,
             prices: DataFrame,
             window_size: int,
-            reward_kwargs: dict,
+            reward_kwargs: Dict[str, Any],
             starting_point=True,
-        ):
+        ) -> None:
             """
             Resets the environment when the agent fails
             """
@@ -794,7 +796,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
             )
 
-        def reset(self, seed=None, **kwargs):
+        def reset(self, seed=None, **kwargs) -> Tuple[np.ndarray, Dict[str, Any]]:
             """
             Reset is called at the beginning of every episode
             """
@@ -947,7 +949,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             return 0.0
 
-        def _get_observation(self):
+        def _get_observation(self) -> np.ndarray:
             """
             This may or may not be independent of action types, user can inherit
             this in their custom "MyRLEnv"
@@ -999,11 +1001,11 @@ class ReforceXY(BaseReinforcementLearningModel):
                 Actions.Short_enter.value: Positions.Short,
             }[action]
 
-        def _enter_trade(self, action):
+        def _enter_trade(self, action: int) -> None:
             self._position = self._get_new_position(action)
             self._last_trade_tick = self._current_tick
 
-        def _exit_trade(self):
+        def _exit_trade(self) -> None:
             self._update_total_profit()
             self._last_closed_position = self._position
             self._position = Positions.Neutral
@@ -1036,7 +1038,9 @@ class ReforceXY(BaseReinforcementLearningModel):
                 self._exit_trade()
                 self.append_trade_history(f"{self._last_closed_position.name}_exit")
 
-        def step(self, action: int):
+        def step(
+            self, action: int
+        ) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
             """
             Take a step in the environment based on the provided action
             """
@@ -1070,7 +1074,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 info,
             )
 
-        def append_trade_history(self, trade_type: str):
+        def append_trade_history(self, trade_type: str) -> None:
             self.trade_history.append(
                 {
                     "tick": self._current_tick,
@@ -1228,7 +1232,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
             return history
 
-        def get_env_plot(self):
+        def get_env_plot(self) -> plt.Figure:
             """
             Plot trades and environment data
             """
@@ -1405,7 +1409,7 @@ class RolloutPlotCallback(BaseCallback):
     Tensorboard plot callback
     """
 
-    def record_env(self):
+    def record_env(self) -> bool:
         figures = self.training_env.env_method("get_env_plot")
         for i, fig in enumerate(figures):
             figure = Figure(fig, close=True)
index f3075568a49e35d764209ac9739cbb07ed07bbce..15581fd722768c7bffbc09e6b884074412afec3c 100644 (file)
@@ -1,5 +1,6 @@
 import logging
 from functools import cached_property, reduce
+# from typing import Any
 
 # import talib.abstract as ta
 from pandas import DataFrame
@@ -35,7 +36,7 @@ class RLAgentStrategy(IStrategy):
     startup_candle_count: int = 300
 
     # @cached_property
-    # def protections(self):
+    # def protections(self) -> list[dict[str, Any]]:
     #     fit_live_predictions_candles = self.freqai_info.get(
     #         "fit_live_predictions_candles", 100
     #     )
@@ -58,19 +59,19 @@ class RLAgentStrategy(IStrategy):
     #     ]
 
     @cached_property
-    def can_short(self):
+    def can_short(self) -> bool:
         return self.is_short_allowed()
 
     # def feature_engineering_expand_all(
     #     self, dataframe: DataFrame, period: int, metadata: dict, **kwargs
-    # ):
+    # ) -> DataFrame:
     #     dataframe["%-rsi-period"] = ta.RSI(dataframe, timeperiod=period)
 
     #     return dataframe
 
     def feature_engineering_expand_basic(
         self, dataframe: DataFrame, metadata: dict, **kwargs
-    ):
+    ) -> DataFrame:
         dataframe["%-close_pct_change"] = dataframe.get("close").pct_change()
         dataframe["%-raw_volume"] = dataframe.get("volume")
 
@@ -78,7 +79,7 @@ class RLAgentStrategy(IStrategy):
 
     def feature_engineering_standard(
         self, dataframe: DataFrame, metadata: dict, **kwargs
-    ):
+    ) -> DataFrame:
         dates = dataframe.get("date")
         dataframe["%-day_of_week"] = (dates.dt.dayofweek + 1) / 7
         dataframe["%-hour_of_day"] = (dates.dt.hour + 1) / 25
@@ -90,7 +91,9 @@ class RLAgentStrategy(IStrategy):
 
         return dataframe
 
-    def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs):
+    def set_freqai_targets(
+        self, dataframe: DataFrame, metadata: dict, **kwargs
+    ) -> DataFrame:
         dataframe[ACTION_COLUMN] = 0
 
         return dataframe
index 2e7935f78b70fc81fde2619cb4dde3ed22a84976..3469f946c31b459bd1dc171493833d47c3e3d8f4 100644 (file)
@@ -147,7 +147,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             raise ValueError(f"Invalid namespace: {namespace}")
         return params
 
-    def set_optuna_params(self, pair: str, namespace: str, params: dict) -> None:
+    def set_optuna_params(
+        self, pair: str, namespace: str, params: dict[str, Any]
+    ) -> None:
         if namespace == "hp":
             self._optuna_hp_params[pair] = params
         elif namespace == "train":
@@ -237,7 +239,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             self._optuna_label_candle_pool.extend(optuna_label_available_candles)
             random.shuffle(self._optuna_label_candle_pool)
 
-    def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any:
+    def fit(
+        self, data_dictionary: dict[str, Any], dk: FreqaiDataKitchen, **kwargs
+    ) -> Any:
         """
         User sets up the training and test data to fit their desired model here
         :param data_dictionary: the dictionary constructed by DataHandler to hold
@@ -473,7 +477,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
 
     def eval_set_and_weights(
         self, X_test: pd.DataFrame, y_test: pd.DataFrame, test_weights: np.ndarray
-    ) -> tuple[list[tuple] | None, list[np.ndarray] | None]:
+    ) -> tuple[
+        Optional[list[tuple[pd.DataFrame, pd.DataFrame]]], Optional[list[np.ndarray]]
+    ]:
         if self.data_split_parameters.get("test_size", TEST_SIZE) == 0:
             eval_set = None
             eval_weights = None
@@ -955,7 +961,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
             )
             raise
 
-    def optuna_load_best_params(self, pair: str, namespace: str) -> Optional[dict]:
+    def optuna_load_best_params(
+        self, pair: str, namespace: str
+    ) -> Optional[dict[str, Any]]:
         best_params_path = Path(
             self.full_path / f"optuna-{namespace}-best-params-{pair.split('/')[0]}.json"
         )
@@ -1029,11 +1037,11 @@ def fit_regressor(
     X: pd.DataFrame,
     y: pd.DataFrame,
     train_weights: np.ndarray,
-    eval_set: Optional[list[tuple]],
+    eval_set: Optional[list[tuple[pd.DataFrame, pd.DataFrame]]],
     eval_weights: Optional[list[np.ndarray]],
-    model_training_parameters: dict,
+    model_training_parameters: dict[str, Any],
     init_model: Any = None,
-    callbacks: list[Callable] = None,
+    callbacks: Optional[list[Callable]] = None,
 ) -> Any:
     if regressor == "xgboost":
         from xgboost import XGBRegressor
@@ -1083,7 +1091,7 @@ def train_objective(
     test_size: float,
     fit_live_predictions_candles: int,
     candles_step: int,
-    model_training_parameters: dict,
+    model_training_parameters: dict[str, Any],
 ) -> float:
     def calculate_min_extrema(
         length: int, fit_live_predictions_candles: int, min_extrema: int = 2
@@ -1184,7 +1192,7 @@ def train_objective(
 
 def get_optuna_study_model_parameters(
     trial: optuna.trial.Trial, regressor: str
-) -> dict:
+) -> dict[str, Any]:
     study_model_parameters = {
         "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
         "min_child_weight": trial.suggest_float(
@@ -1224,7 +1232,7 @@ def hp_objective(
     X_test: pd.DataFrame,
     y_test: pd.DataFrame,
     test_weights: np.ndarray,
-    model_training_parameters: dict,
+    model_training_parameters: dict[str, Any],
 ) -> float:
     study_model_parameters = get_optuna_study_model_parameters(trial, regressor)
     model_training_parameters = {**model_training_parameters, **study_model_parameters}
@@ -1343,7 +1351,7 @@ def zigzag(
         last_pivot_pos = pos
         reset_candidate_pivot()
 
-    slope_ok_cache: dict[tuple[int, int, int, float]] = {}
+    slope_ok_cache: dict[tuple[int, int, int, float], bool] = {}
 
     def get_slope_ok(
         pos: int,
index c80345188a780fe3c2ad191dce572ade277dc5ed..16794dcc60418dbe8e41a4ed061a32c2461659e5 100644 (file)
@@ -6,7 +6,7 @@ import math
 from pathlib import Path
 import talib.abstract as ta
 from pandas import DataFrame, Series, isna
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 from freqtrade.exchange import timeframe_to_minutes, timeframe_to_prev_date
 from freqtrade.strategy.interface import IStrategy
 from freqtrade.strategy import stoploss_from_absolute
@@ -120,7 +120,7 @@ class QuickAdapterV3(IStrategy):
         }
 
     @cached_property
-    def protections(self) -> list[dict]:
+    def protections(self) -> list[dict[str, Any]]:
         fit_live_predictions_candles = self.freqai_info.get(
             "fit_live_predictions_candles", 100
         )
@@ -195,7 +195,7 @@ class QuickAdapterV3(IStrategy):
 
     def feature_engineering_expand_all(
         self, dataframe: DataFrame, period: int, metadata: dict, **kwargs
-    ):
+    ) -> DataFrame:
         highs = dataframe.get("high")
         lows = dataframe.get("low")
         closes = dataframe.get("close")
@@ -235,7 +235,7 @@ class QuickAdapterV3(IStrategy):
 
     def feature_engineering_expand_basic(
         self, dataframe: DataFrame, metadata: dict, **kwargs
-    ):
+    ) -> DataFrame:
         highs = dataframe.get("high")
         lows = dataframe.get("low")
         opens = dataframe.get("open")
@@ -350,7 +350,7 @@ class QuickAdapterV3(IStrategy):
         dataframe["%-raw_high"] = highs
         return dataframe
 
-    def feature_engineering_standard(self, dataframe: DataFrame, **kwargs):
+    def feature_engineering_standard(self, dataframe: DataFrame, **kwargs) -> DataFrame:
         dates = dataframe.get("date")
 
         dataframe["%-day_of_week"] = (dates.dt.dayofweek + 1) / 7
@@ -365,7 +365,7 @@ class QuickAdapterV3(IStrategy):
             return label_period_candles
         return self.freqai_info["feature_parameters"].get("label_period_candles", 50)
 
-    def set_label_period_candles(self, pair: str, label_period_candles: int):
+    def set_label_period_candles(self, pair: str, label_period_candles: int) -> None:
         if isinstance(label_period_candles, int):
             self._label_params[pair]["label_period_candles"] = label_period_candles
 
@@ -377,7 +377,7 @@ class QuickAdapterV3(IStrategy):
             self.freqai_info["feature_parameters"].get("label_natr_ratio", 6.0)
         )
 
-    def set_label_natr_ratio(self, pair: str, label_natr_ratio: float):
+    def set_label_natr_ratio(self, pair: str, label_natr_ratio: float) -> None:
         if isinstance(label_natr_ratio, float) and np.isfinite(label_natr_ratio):
             self._label_params[pair]["label_natr_ratio"] = label_natr_ratio
 
@@ -406,7 +406,9 @@ class QuickAdapterV3(IStrategy):
         except (KeyError, ValueError) as e:
             raise ValueError(f"Invalid pattern '{pattern}': {e}")
 
-    def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs):
+    def set_freqai_targets(
+        self, dataframe: DataFrame, metadata: dict, **kwargs
+    ) -> DataFrame:
         pair = str(metadata.get("pair"))
         label_period_candles = self.get_label_period_candles(pair)
         label_natr_ratio = self.get_label_natr_ratio(pair)
index 0afca65ce3f74ebd935d001fac95b21ac0e77f80..8e31e37eacd3b232f40c5c19fb6641c99176f293 100644 (file)
@@ -5,13 +5,13 @@ import numpy as np
 import pandas as pd
 import scipy as sp
 import talib.abstract as ta
-from typing import Callable, Literal, Union
+from typing import Callable, Literal, TypeVar
 from technical import qtpylib
 
+T = TypeVar("T", pd.Series, float)
 
-def get_distance(
-    p1: Union[pd.Series, float], p2: Union[pd.Series, float]
-) -> Union[pd.Series, float]:
+
+def get_distance(p1: T, p2: T) -> T:
     return abs(p1 - p2)
 
 
@@ -141,7 +141,9 @@ def price_retracement_percent(dataframe: pd.DataFrame, period: int) -> pd.Series
 
 
 # VWAP bands
-def vwapb(dataframe: pd.DataFrame, window: int = 20, std_factor: float = 1.0) -> tuple:
+def vwapb(
+    dataframe: pd.DataFrame, window: int = 20, std_factor: float = 1.0
+) -> tuple[pd.Series, pd.Series, pd.Series]:
     vwap = qtpylib.rolling_vwap(dataframe, window=window)
     rolling_std = vwap.rolling(window=window, min_periods=window).std()
     vwap_low = vwap - (rolling_std * std_factor)
@@ -471,7 +473,7 @@ def zigzag(
         last_pivot_pos = pos
         reset_candidate_pivot()
 
-    slope_ok_cache: dict[tuple[int, int, int, float]] = {}
+    slope_ok_cache: dict[tuple[int, int, int, float], bool] = {}
 
     def get_slope_ok(
         pos: int,