From: Jérôme Benoit Date: Mon, 22 Sep 2025 11:35:12 +0000 (+0200) Subject: refactor(reforcexy): factor out DQN train_freq handling X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=80c20532199fa47259ce03a042aa2c19590cf0db;p=freqai-strategies.git refactor(reforcexy): factor out DQN train_freq handling Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 886e858..68a14f3 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -8,7 +8,7 @@ from collections import defaultdict from collections.abc import Mapping from enum import IntEnum from pathlib import Path -from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union import matplotlib import matplotlib.pyplot as plt @@ -114,7 +114,7 @@ class ReforceXY(BaseReinforcementLearningModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.pairs: list[str] = self.config.get("exchange", {}).get("pair_whitelist") + self.pairs: List[str] = self.config.get("exchange", {}).get("pair_whitelist") if not self.pairs: raise ValueError( "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration" @@ -333,10 +333,10 @@ class ReforceXY(BaseReinforcementLearningModel): if not model_params.get("policy_kwargs"): model_params["policy_kwargs"] = {} - default_net_arch: list[int] = [128, 128] + default_net_arch: List[int] = [128, 128] net_arch: Union[ - list[int], - Dict[str, list[int]], + List[int], + Dict[str, List[int]], Literal["small", "medium", "large", "extra_large"], ] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch) @@ -356,8 +356,8 @@ class ReforceXY(BaseReinforcementLearningModel): "vf": net_arch, } elif isinstance(net_arch, dict): - pi: Optional[list[int]] = net_arch.get("pi") - vf: Optional[list[int]] = net_arch.get("vf") + pi: Optional[List[int]] = net_arch.get("pi") + vf: Optional[List[int]] = net_arch.get("vf") if not isinstance(pi, list) or not isinstance(vf, list): model_params["policy_kwargs"]["net_arch"] = { "pi": pi @@ -426,11 +426,11 @@ class ReforceXY(BaseReinforcementLearningModel): eval_freq: int, data_path: str, trial: Optional[Trial] = None, - ) -> list[BaseCallback]: + ) -> List[BaseCallback]: """ Get the model specific callbacks """ - callbacks: list[BaseCallback] = [] + callbacks: List[BaseCallback] = [] no_improvement_callback = None rollout_plot_callback = None verbose = int(self.get_model_params().get("verbose", 0)) @@ -540,11 +540,12 @@ class ReforceXY(BaseReinforcementLearningModel): if "PPO" in self.model_type: min_timesteps = 2 * model_params.get("n_steps", 0) * self.n_envs - if total_timesteps < min_timesteps: + if total_timesteps <= min_timesteps: logger.warning( - "total_timesteps=%s is less than 2*n_steps*n_envs=%s. This may lead to suboptimal training results", + "total_timesteps=%s is less than or equal to 2*n_steps*n_envs=%s. This may lead to suboptimal training results for model %s", total_timesteps, min_timesteps, + self.model_type, ) if self.activate_tensorboard: @@ -624,7 +625,7 @@ class ReforceXY(BaseReinforcementLearningModel): return Positions.Neutral return position - frame_buffer: list[NDArray[np.float32]] = [] + frame_buffer: List[NDArray[np.float32]] = [] def _predict(window) -> int: observation: DataFrame = dataframe.iloc[window.index] @@ -645,7 +646,7 @@ class ReforceXY(BaseReinforcementLearningModel): np_observation = observation.to_numpy(dtype=np.float32) - fb: list[NDArray[np.float32]] = frame_buffer + fb: List[NDArray[np.float32]] = frame_buffer frame_stacking = self.frame_stacking if frame_stacking and frame_stacking > 1: fb.append(np_observation.copy()) @@ -672,7 +673,7 @@ class ReforceXY(BaseReinforcementLearningModel): ) return int(action) - predicted_actions: list[int] = [] + predicted_actions: List[int] = [] for window_end in range(self.CONV_WIDTH, len(dataframe) + 1): window = dataframe.iloc[window_end - self.CONV_WIDTH : window_end] action = _predict(window) @@ -1830,6 +1831,25 @@ class InfoMetricsCallback(TensorboardCallback): logger.error("logger.record retry on stdout failed at %r: %r", key, e) pass + @staticmethod + def build_train_freq( + train_freq: Optional[Union[int, Tuple[int, ...], List[int]]], + ) -> Optional[int]: + train_freq_val: Optional[int] = None + try: + if isinstance(train_freq, int): + train_freq_val = train_freq + elif isinstance(train_freq, (tuple, list)) and train_freq: + if isinstance(train_freq[0], int): + train_freq_val = train_freq[0] + elif hasattr(train_freq, "freq"): + freq = getattr(train_freq, "freq") + if isinstance(freq, int): + train_freq_val = freq + except Exception: + train_freq_val = None + return train_freq_val + def _on_training_start(self) -> None: lr = getattr(self.model, "learning_rate", None) lr_schedule, lr_iv, lr_fv = get_schedule_type(lr) @@ -1886,22 +1906,11 @@ class InfoMetricsCallback(TensorboardCallback): "exploration_rate": float(self.model.exploration_rate), } ) - train_freq = getattr(self.model, "train_freq", None) - train_freq_val: int | None = None - try: - if isinstance(train_freq, int): - train_freq_val = train_freq - elif isinstance(train_freq, (tuple, list)) and train_freq: - if isinstance(train_freq[0], int): - train_freq_val = train_freq[0] - elif hasattr(train_freq, "freq"): - freq = getattr(train_freq, "freq") - if isinstance(freq, int): - train_freq_val = freq - except Exception: - train_freq_val = None - if train_freq_val is not None: - hparam_dict.update({"train_freq": train_freq_val}) + train_freq = InfoMetricsCallback.build_train_freq( + getattr(self.model, "train_freq", None) + ) + if train_freq is not None: + hparam_dict.update({"train_freq": train_freq}) if "QRDQN" in self.model.__class__.__name__: hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)}) metric_dict: dict[str, float | int] = { @@ -1960,11 +1969,11 @@ class InfoMetricsCallback(TensorboardCallback): except Exception: return False - infos_list: list[Dict[str, Any]] | None = self.locals.get("infos") + infos_list: List[Dict[str, Any]] | None = self.locals.get("infos") aggregated_info: Dict[str, Any] = {} if isinstance(infos_list, list) and infos_list: - numeric_acc: Dict[str, list[float]] = defaultdict(list) + numeric_acc: Dict[str, List[float]] = defaultdict(list) non_numeric_counts: Dict[str, Dict[Any, int]] = defaultdict( lambda: defaultdict(int) ) @@ -2412,7 +2421,7 @@ def get_schedule( def get_net_arch( model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"] -) -> Union[list[int], Dict[str, list[int]]]: +) -> Union[List[int], Dict[str, List[int]]]: """ Get network architecture """