From bb13044bc0555c545e86ce2a0fd59d9ee0c0baf2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Wed, 18 Jun 2025 17:10:19 +0200 Subject: [PATCH] refactor(qav3): refine typing MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 34 +++++++++++-------- .../user_data/strategies/RLAgentStrategy.py | 15 ++++---- .../freqaimodels/QuickAdapterRegressorV3.py | 30 ++++++++++------ .../user_data/strategies/QuickAdapterV3.py | 18 +++++----- quickadapter/user_data/strategies/Utils.py | 14 ++++---- 5 files changed, 65 insertions(+), 46 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index d693304..81e77ca 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -215,11 +215,11 @@ class ReforceXY(BaseReinforcementLearningModel): raise ValueError("Frame stacking requires predefined observation shape") self.eval_env = VecMonitor(eval_env) - def get_model_params(self) -> Dict: + def get_model_params(self) -> Dict[str, Any]: """ Get model parameters """ - model_params: Dict = copy.deepcopy(self.model_training_parameters) + model_params: Dict[str, Any] = copy.deepcopy(self.model_training_parameters) if self.lr_schedule: _lr = model_params.get("learning_rate", 0.0003) @@ -311,7 +311,9 @@ class ReforceXY(BaseReinforcementLearningModel): callbacks.append(self.optuna_callback) return callbacks - def fit(self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs): + def fit( + self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs + ) -> Any: """ User customizable fit method :param data_dictionary: dict = common data dictionary containing all train/test @@ -525,7 +527,7 @@ class ReforceXY(BaseReinforcementLearningModel): def study( self, train_df: DataFrame, total_timesteps: int, dk: FreqaiDataKitchen - ) -> Optional[Dict]: + ) -> Optional[Dict[str, Any]]: """ Runs hyperparameter optimization using Optuna and returns the best hyperparameters found merged with the user defined parameters @@ -736,7 +738,7 @@ class ReforceXY(BaseReinforcementLearningModel): return self.optuna_callback.best_mean_reward - def close_envs(self): + def close_envs(self) -> None: """ Closes the training and evaluation environments if they are open """ @@ -775,9 +777,9 @@ class ReforceXY(BaseReinforcementLearningModel): df: DataFrame, prices: DataFrame, window_size: int, - reward_kwargs: dict, + reward_kwargs: Dict[str, Any], starting_point=True, - ): + ) -> None: """ Resets the environment when the agent fails """ @@ -794,7 +796,7 @@ class ReforceXY(BaseReinforcementLearningModel): low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32 ) - def reset(self, seed=None, **kwargs): + def reset(self, seed=None, **kwargs) -> Tuple[np.ndarray, Dict[str, Any]]: """ Reset is called at the beginning of every episode """ @@ -947,7 +949,7 @@ class ReforceXY(BaseReinforcementLearningModel): return 0.0 - def _get_observation(self): + def _get_observation(self) -> np.ndarray: """ This may or may not be independent of action types, user can inherit this in their custom "MyRLEnv" @@ -999,11 +1001,11 @@ class ReforceXY(BaseReinforcementLearningModel): Actions.Short_enter.value: Positions.Short, }[action] - def _enter_trade(self, action): + def _enter_trade(self, action: int) -> None: self._position = self._get_new_position(action) self._last_trade_tick = self._current_tick - def _exit_trade(self): + def _exit_trade(self) -> None: self._update_total_profit() self._last_closed_position = self._position self._position = Positions.Neutral @@ -1036,7 +1038,9 @@ class ReforceXY(BaseReinforcementLearningModel): self._exit_trade() self.append_trade_history(f"{self._last_closed_position.name}_exit") - def step(self, action: int): + def step( + self, action: int + ) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]: """ Take a step in the environment based on the provided action """ @@ -1070,7 +1074,7 @@ class ReforceXY(BaseReinforcementLearningModel): info, ) - def append_trade_history(self, trade_type: str): + def append_trade_history(self, trade_type: str) -> None: self.trade_history.append( { "tick": self._current_tick, @@ -1228,7 +1232,7 @@ class ReforceXY(BaseReinforcementLearningModel): ) return history - def get_env_plot(self): + def get_env_plot(self) -> plt.Figure: """ Plot trades and environment data """ @@ -1405,7 +1409,7 @@ class RolloutPlotCallback(BaseCallback): Tensorboard plot callback """ - def record_env(self): + def record_env(self) -> bool: figures = self.training_env.env_method("get_env_plot") for i, fig in enumerate(figures): figure = Figure(fig, close=True) diff --git a/ReforceXY/user_data/strategies/RLAgentStrategy.py b/ReforceXY/user_data/strategies/RLAgentStrategy.py index f307556..15581fd 100644 --- a/ReforceXY/user_data/strategies/RLAgentStrategy.py +++ b/ReforceXY/user_data/strategies/RLAgentStrategy.py @@ -1,5 +1,6 @@ import logging from functools import cached_property, reduce +# from typing import Any # import talib.abstract as ta from pandas import DataFrame @@ -35,7 +36,7 @@ class RLAgentStrategy(IStrategy): startup_candle_count: int = 300 # @cached_property - # def protections(self): + # def protections(self) -> list[dict[str, Any]]: # fit_live_predictions_candles = self.freqai_info.get( # "fit_live_predictions_candles", 100 # ) @@ -58,19 +59,19 @@ class RLAgentStrategy(IStrategy): # ] @cached_property - def can_short(self): + def can_short(self) -> bool: return self.is_short_allowed() # def feature_engineering_expand_all( # self, dataframe: DataFrame, period: int, metadata: dict, **kwargs - # ): + # ) -> DataFrame: # dataframe["%-rsi-period"] = ta.RSI(dataframe, timeperiod=period) # return dataframe def feature_engineering_expand_basic( self, dataframe: DataFrame, metadata: dict, **kwargs - ): + ) -> DataFrame: dataframe["%-close_pct_change"] = dataframe.get("close").pct_change() dataframe["%-raw_volume"] = dataframe.get("volume") @@ -78,7 +79,7 @@ class RLAgentStrategy(IStrategy): def feature_engineering_standard( self, dataframe: DataFrame, metadata: dict, **kwargs - ): + ) -> DataFrame: dates = dataframe.get("date") dataframe["%-day_of_week"] = (dates.dt.dayofweek + 1) / 7 dataframe["%-hour_of_day"] = (dates.dt.hour + 1) / 25 @@ -90,7 +91,9 @@ class RLAgentStrategy(IStrategy): return dataframe - def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs): + def set_freqai_targets( + self, dataframe: DataFrame, metadata: dict, **kwargs + ) -> DataFrame: dataframe[ACTION_COLUMN] = 0 return dataframe diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index 2e7935f..3469f94 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -147,7 +147,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): raise ValueError(f"Invalid namespace: {namespace}") return params - def set_optuna_params(self, pair: str, namespace: str, params: dict) -> None: + def set_optuna_params( + self, pair: str, namespace: str, params: dict[str, Any] + ) -> None: if namespace == "hp": self._optuna_hp_params[pair] = params elif namespace == "train": @@ -237,7 +239,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): self._optuna_label_candle_pool.extend(optuna_label_available_candles) random.shuffle(self._optuna_label_candle_pool) - def fit(self, data_dictionary: dict, dk: FreqaiDataKitchen, **kwargs) -> Any: + def fit( + self, data_dictionary: dict[str, Any], dk: FreqaiDataKitchen, **kwargs + ) -> Any: """ User sets up the training and test data to fit their desired model here :param data_dictionary: the dictionary constructed by DataHandler to hold @@ -473,7 +477,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): def eval_set_and_weights( self, X_test: pd.DataFrame, y_test: pd.DataFrame, test_weights: np.ndarray - ) -> tuple[list[tuple] | None, list[np.ndarray] | None]: + ) -> tuple[ + Optional[list[tuple[pd.DataFrame, pd.DataFrame]]], Optional[list[np.ndarray]] + ]: if self.data_split_parameters.get("test_size", TEST_SIZE) == 0: eval_set = None eval_weights = None @@ -955,7 +961,9 @@ class QuickAdapterRegressorV3(BaseRegressionModel): ) raise - def optuna_load_best_params(self, pair: str, namespace: str) -> Optional[dict]: + def optuna_load_best_params( + self, pair: str, namespace: str + ) -> Optional[dict[str, Any]]: best_params_path = Path( self.full_path / f"optuna-{namespace}-best-params-{pair.split('/')[0]}.json" ) @@ -1029,11 +1037,11 @@ def fit_regressor( X: pd.DataFrame, y: pd.DataFrame, train_weights: np.ndarray, - eval_set: Optional[list[tuple]], + eval_set: Optional[list[tuple[pd.DataFrame, pd.DataFrame]]], eval_weights: Optional[list[np.ndarray]], - model_training_parameters: dict, + model_training_parameters: dict[str, Any], init_model: Any = None, - callbacks: list[Callable] = None, + callbacks: Optional[list[Callable]] = None, ) -> Any: if regressor == "xgboost": from xgboost import XGBRegressor @@ -1083,7 +1091,7 @@ def train_objective( test_size: float, fit_live_predictions_candles: int, candles_step: int, - model_training_parameters: dict, + model_training_parameters: dict[str, Any], ) -> float: def calculate_min_extrema( length: int, fit_live_predictions_candles: int, min_extrema: int = 2 @@ -1184,7 +1192,7 @@ def train_objective( def get_optuna_study_model_parameters( trial: optuna.trial.Trial, regressor: str -) -> dict: +) -> dict[str, Any]: study_model_parameters = { "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True), "min_child_weight": trial.suggest_float( @@ -1224,7 +1232,7 @@ def hp_objective( X_test: pd.DataFrame, y_test: pd.DataFrame, test_weights: np.ndarray, - model_training_parameters: dict, + model_training_parameters: dict[str, Any], ) -> float: study_model_parameters = get_optuna_study_model_parameters(trial, regressor) model_training_parameters = {**model_training_parameters, **study_model_parameters} @@ -1343,7 +1351,7 @@ def zigzag( last_pivot_pos = pos reset_candidate_pivot() - slope_ok_cache: dict[tuple[int, int, int, float]] = {} + slope_ok_cache: dict[tuple[int, int, int, float], bool] = {} def get_slope_ok( pos: int, diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index c803451..16794dc 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -6,7 +6,7 @@ import math from pathlib import Path import talib.abstract as ta from pandas import DataFrame, Series, isna -from typing import Callable, Optional +from typing import Any, Callable, Optional from freqtrade.exchange import timeframe_to_minutes, timeframe_to_prev_date from freqtrade.strategy.interface import IStrategy from freqtrade.strategy import stoploss_from_absolute @@ -120,7 +120,7 @@ class QuickAdapterV3(IStrategy): } @cached_property - def protections(self) -> list[dict]: + def protections(self) -> list[dict[str, Any]]: fit_live_predictions_candles = self.freqai_info.get( "fit_live_predictions_candles", 100 ) @@ -195,7 +195,7 @@ class QuickAdapterV3(IStrategy): def feature_engineering_expand_all( self, dataframe: DataFrame, period: int, metadata: dict, **kwargs - ): + ) -> DataFrame: highs = dataframe.get("high") lows = dataframe.get("low") closes = dataframe.get("close") @@ -235,7 +235,7 @@ class QuickAdapterV3(IStrategy): def feature_engineering_expand_basic( self, dataframe: DataFrame, metadata: dict, **kwargs - ): + ) -> DataFrame: highs = dataframe.get("high") lows = dataframe.get("low") opens = dataframe.get("open") @@ -350,7 +350,7 @@ class QuickAdapterV3(IStrategy): dataframe["%-raw_high"] = highs return dataframe - def feature_engineering_standard(self, dataframe: DataFrame, **kwargs): + def feature_engineering_standard(self, dataframe: DataFrame, **kwargs) -> DataFrame: dates = dataframe.get("date") dataframe["%-day_of_week"] = (dates.dt.dayofweek + 1) / 7 @@ -365,7 +365,7 @@ class QuickAdapterV3(IStrategy): return label_period_candles return self.freqai_info["feature_parameters"].get("label_period_candles", 50) - def set_label_period_candles(self, pair: str, label_period_candles: int): + def set_label_period_candles(self, pair: str, label_period_candles: int) -> None: if isinstance(label_period_candles, int): self._label_params[pair]["label_period_candles"] = label_period_candles @@ -377,7 +377,7 @@ class QuickAdapterV3(IStrategy): self.freqai_info["feature_parameters"].get("label_natr_ratio", 6.0) ) - def set_label_natr_ratio(self, pair: str, label_natr_ratio: float): + def set_label_natr_ratio(self, pair: str, label_natr_ratio: float) -> None: if isinstance(label_natr_ratio, float) and np.isfinite(label_natr_ratio): self._label_params[pair]["label_natr_ratio"] = label_natr_ratio @@ -406,7 +406,9 @@ class QuickAdapterV3(IStrategy): except (KeyError, ValueError) as e: raise ValueError(f"Invalid pattern '{pattern}': {e}") - def set_freqai_targets(self, dataframe: DataFrame, metadata: dict, **kwargs): + def set_freqai_targets( + self, dataframe: DataFrame, metadata: dict, **kwargs + ) -> DataFrame: pair = str(metadata.get("pair")) label_period_candles = self.get_label_period_candles(pair) label_natr_ratio = self.get_label_natr_ratio(pair) diff --git a/quickadapter/user_data/strategies/Utils.py b/quickadapter/user_data/strategies/Utils.py index 0afca65..8e31e37 100644 --- a/quickadapter/user_data/strategies/Utils.py +++ b/quickadapter/user_data/strategies/Utils.py @@ -5,13 +5,13 @@ import numpy as np import pandas as pd import scipy as sp import talib.abstract as ta -from typing import Callable, Literal, Union +from typing import Callable, Literal, TypeVar from technical import qtpylib +T = TypeVar("T", pd.Series, float) -def get_distance( - p1: Union[pd.Series, float], p2: Union[pd.Series, float] -) -> Union[pd.Series, float]: + +def get_distance(p1: T, p2: T) -> T: return abs(p1 - p2) @@ -141,7 +141,9 @@ def price_retracement_percent(dataframe: pd.DataFrame, period: int) -> pd.Series # VWAP bands -def vwapb(dataframe: pd.DataFrame, window: int = 20, std_factor: float = 1.0) -> tuple: +def vwapb( + dataframe: pd.DataFrame, window: int = 20, std_factor: float = 1.0 +) -> tuple[pd.Series, pd.Series, pd.Series]: vwap = qtpylib.rolling_vwap(dataframe, window=window) rolling_std = vwap.rolling(window=window, min_periods=window).std() vwap_low = vwap - (rolling_std * std_factor) @@ -471,7 +473,7 @@ def zigzag( last_pivot_pos = pos reset_candidate_pivot() - slope_ok_cache: dict[tuple[int, int, int, float]] = {} + slope_ok_cache: dict[tuple[int, int, int, float], bool] = {} def get_slope_ok( pos: int, -- 2.43.0