From 5e2ea5dd3347a7a6766943590a971bd05aff37fd Mon Sep 17 00:00:00 2001
From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?=
Date: Mon, 22 Sep 2025 00:31:36 +0200
Subject: [PATCH] refactor(reforcexy): cleanup training schedule handling code
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jérôme Benoit
---
 ReforceXY/user_data/freqaimodels/ReforceXY.py | 131 ++++++++++--------
 1 file changed, 76 insertions(+), 55 deletions(-)

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index e241019..5609f60 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -47,7 +47,7 @@ from stable_baselines3.common.callbacks import (
 )
 from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.logger import Figure, HParam
-from stable_baselines3.common.utils import set_random_seed
+from stable_baselines3.common.utils import ConstantSchedule, set_random_seed
 from stable_baselines3.common.vec_env import (
     DummyVecEnv,
     SubprocVecEnv,
@@ -310,7 +310,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             lr = model_params.get("learning_rate", 0.0003)
             if isinstance(lr, (int, float)):
                 lr = float(lr)
-            model_params["learning_rate"] = linear_schedule(lr)
+            model_params["learning_rate"] = get_schedule("linear", lr)
             logger.info(
                 "Learning rate linear schedule enabled, initial value: %s", lr
             )
@@ -319,7 +319,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             cr = model_params.get("clip_range", 0.2)
             if isinstance(cr, (int, float)):
                 cr = float(cr)
-            model_params["clip_range"] = linear_schedule(cr)
+            model_params["clip_range"] = get_schedule("linear", cr)
             logger.info("Clip range linear schedule enabled, initial value: %s", cr)

         if "DQN" in self.model_type:
@@ -1369,8 +1369,8 @@ class MyRLEnv(Base5ActionRLEnv):
         self._update_total_profit()
         self._last_closed_position = self._position
         self._position = Positions.Neutral
-        self._last_closed_trade_tick = self._current_tick
         self._last_trade_tick = None
+        self._last_closed_trade_tick = self._current_tick

     def execute_trade(self, action: int) -> Optional[str]:
         """
@@ -1831,24 +1831,8 @@ class InfoMetricsCallback(TensorboardCallback):
         pass

     def _on_training_start(self) -> None:
-        lr_schedule = "unknown"
-        lr_iv = np.nan
-        lr_fv = np.nan
         lr = getattr(self.model, "learning_rate", None)
-        if callable(lr):
-            lr_schedule = "linear"
-            try:
-                lr_iv = lr(1.0)
-            except Exception:
-                lr_iv = np.nan
-            try:
-                lr_fv = lr(0.0)
-            except Exception:
-                lr_fv = np.nan
-        elif isinstance(lr, (int, float)):
-            lr_schedule = "constant"
-            lr_iv = float(lr)
-            lr_fv = float(lr)
+        lr_schedule, lr_iv, lr_fv = get_schedule_type(lr)
         n_stack = 1
         env = getattr(self, "training_env", None)
         while env is not None:
@@ -1870,24 +1854,8 @@ class InfoMetricsCallback(TensorboardCallback):
             "batch_size": int(self.model.batch_size),
         }
         if "PPO" in self.model.__class__.__name__:
-            cr_schedule = "unknown"
-            cr_iv = np.nan
-            cr_fv = np.nan
             cr = getattr(self.model, "clip_range", None)
-            if callable(cr):
-                cr_schedule = "linear"
-                try:
-                    cr_iv = cr(1.0)
-                except Exception:
-                    cr_iv = np.nan
-                try:
-                    cr_fv = cr(0.0)
-                except Exception:
-                    cr_fv = np.nan
-            elif isinstance(cr, (int, float)):
-                cr_schedule = "constant"
-                cr_iv = float(cr)
-                cr_fv = float(cr)
+            cr_schedule, cr_iv, cr_fv = get_schedule_type(cr)
             hparam_dict.update(
                 {
                     "cr_schedule": cr_schedule,
@@ -2164,10 +2132,23 @@ class InfoMetricsCallback(TensorboardCallback):
         except Exception:
             progress_remaining = 1.0

+        def _eval_schedule(schedule: Any) -> float | None:
+            schedule_type, _, _ = get_schedule_type(schedule)
+            try:
+                if schedule_type == "linear":
+                    return float(schedule(progress_remaining))
+                if schedule_type == "constant":
+                    if callable(schedule):
+                        return float(schedule(0.0))
+                    if isinstance(schedule, (int, float)):
+                        return float(schedule)
+                return None
+            except Exception:
+                return None
+
         try:
             lr = getattr(self.model, "learning_rate", None)
-            if callable(lr):
-                lr = lr(progress_remaining)
+            lr = _eval_schedule(lr)
             if _is_finite_number(lr):
                 self._safe_logger_record(
                     "train/learning_rate", float(lr), exclude=logger_exclude
@@ -2178,8 +2159,7 @@ class InfoMetricsCallback(TensorboardCallback):
         if "PPO" in self.model.__class__.__name__:
             try:
                 cr = getattr(self.model, "clip_range", None)
-                if callable(cr):
-                    cr = cr(progress_remaining)
+                cr = _eval_schedule(cr)
                 if _is_finite_number(cr):
                     self._safe_logger_record(
                         "train/clip_range", float(cr), exclude=logger_exclude
@@ -2326,6 +2306,25 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
         return True


+class SimpleLinearSchedule:
+    """
+    Linear schedule (from initial value to zero),
+    simpler than sb3 LinearSchedule.
+
+    :param initial_value: (float or str) The initial value for the schedule
+    """
+
+    def __init__(self, initial_value: Union[float, str]) -> None:
+        # Force conversion to float
+        self.initial_value = float(initial_value)
+
+    def __call__(self, progress_remaining: float) -> float:
+        return progress_remaining * self.initial_value
+
+    def __repr__(self) -> str:
+        return f"SimpleLinearSchedule(initial_value={self.initial_value})"
+
+
 def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     """Recursively merge two dicts without mutating inputs"""
     dst_copy = copy.deepcopy(dst)
@@ -2341,13 +2340,6 @@ def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     return dst_copy


-def linear_schedule(initial_value: float) -> Callable[[float], float]:
-    def func(progress_remaining: float) -> float:
-        return progress_remaining * initial_value
-
-    return func
-
-
 def _compute_gradient_steps(tf: int, ss: int) -> int:
     if tf > 0 and ss > 0:
         return min(tf, max(tf // ss, 1))
@@ -2385,6 +2377,39 @@ def steps_to_days(steps: int, timeframe: str) -> float:
     return round(days, 1)


+def get_schedule_type(
+    schedule: Any,
+) -> Tuple[Literal["constant", "linear", "unknown"], float, float]:
+    if isinstance(schedule, (int, float)):
+        try:
+            schedule = float(schedule)
+            return "constant", schedule, schedule
+        except Exception:
+            return "constant", np.nan, np.nan
+    elif isinstance(schedule, ConstantSchedule):
+        try:
+            return "constant", schedule(1.0), schedule(0.0)
+        except Exception:
+            return "constant", np.nan, np.nan
+    elif isinstance(schedule, SimpleLinearSchedule):
+        try:
+            return "linear", schedule(1.0), schedule(0.0)
+        except Exception:
+            return "linear", np.nan, np.nan
+
+    return "unknown", np.nan, np.nan
+
+
+def get_schedule(
+    schedule_type: Literal["linear", "constant"],
+    initial_value: float,
+) -> Callable[[float], float]:
+    if schedule_type == "linear":
+        return SimpleLinearSchedule(initial_value)
+    elif schedule_type == "constant":
+        return ConstantSchedule(initial_value)
+
+
 def get_net_arch(
     model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
 ) -> Union[list[int], Dict[str, list[int]]]:
@@ -2441,17 +2466,13 @@ def convert_optuna_params_to_model_params(
     lr = optuna_params.get("learning_rate")
     if lr is None:
         raise ValueError(f"missing 'learning_rate' in optuna params for {model_type}")
-    lr: float | Callable[[float], float] = float(lr)
-    if optuna_params.get("lr_schedule") == "linear":
-        lr = linear_schedule(lr)
+    lr = get_schedule(optuna_params.get("lr_schedule", "constant"), float(lr))

     if "PPO" in model_type:
         cr = optuna_params.get("clip_range")
         if cr is None:
             raise ValueError(f"missing 'clip_range' in optuna params for {model_type}")
-        cr: float | Callable[[float], float] = float(cr)
-        if optuna_params.get("cr_schedule") == "linear":
-            cr = linear_schedule(cr)
+        cr = get_schedule(optuna_params.get("cr_schedule", "constant"), float(cr))

     model_params.update(
         {
--
2.43.0
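
For reference, the schedule helpers introduced above can be exercised in
isolation. The sketch below mirrors the patch; ConstantSchedule here is a
stand-in stub (the patch imports the real class from
stable_baselines3.common.utils), so only its call behaviour is assumed:

    from typing import Callable, Literal, Union


    class ConstantSchedule:
        # Stub standing in for stable_baselines3.common.utils.ConstantSchedule:
        # assumed to return the same value regardless of training progress.
        def __init__(self, value: float) -> None:
            self.value = float(value)

        def __call__(self, progress_remaining: float) -> float:
            return self.value


    class SimpleLinearSchedule:
        # As in the patch: decays linearly from initial_value at the start of
        # training (progress_remaining=1.0) to zero at the end (0.0).
        def __init__(self, initial_value: Union[float, str]) -> None:
            self.initial_value = float(initial_value)

        def __call__(self, progress_remaining: float) -> float:
            return progress_remaining * self.initial_value


    def get_schedule(
        schedule_type: Literal["linear", "constant"], initial_value: float
    ) -> Callable[[float], float]:
        # As in the patch: map a schedule name to a callable schedule object.
        if schedule_type == "linear":
            return SimpleLinearSchedule(initial_value)
        return ConstantSchedule(initial_value)


    lr = get_schedule("linear", 3e-4)
    assert lr(1.0) == 3e-4    # start of training: full learning rate
    assert lr(0.5) == 1.5e-4  # halfway: decayed linearly
    assert lr(0.0) == 0.0     # end of training: zero

Routing both the config path and the Optuna path through get_schedule means
every schedule the model sees is one of these two classes, which is what
makes the isinstance dispatch in get_schedule_type reliable.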
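
That dispatch also changes what gets logged for schedules the code did not
build itself: the old callbacks labelled any callable as "linear", while
get_schedule_type reports unrecognized callables as "unknown". Continuing
the sketch above (math.nan standing in for np.nan):

    import math
    from typing import Any, Tuple


    def get_schedule_type(schedule: Any) -> Tuple[str, float, float]:
        # As in the patch: classify a schedule and report its initial
        # (progress_remaining=1.0) and final (0.0) values for hparam logging.
        if isinstance(schedule, (int, float)):
            value = float(schedule)
            return "constant", value, value
        if isinstance(schedule, ConstantSchedule):
            return "constant", schedule(1.0), schedule(0.0)
        if isinstance(schedule, SimpleLinearSchedule):
            return "linear", schedule(1.0), schedule(0.0)
        return "unknown", math.nan, math.nan


    assert get_schedule_type(0.0003) == ("constant", 0.0003, 0.0003)
    assert get_schedule_type(SimpleLinearSchedule(0.2)) == ("linear", 0.2, 0.0)
    # A bare user-supplied callable is no longer assumed to be linear:
    assert get_schedule_type(lambda progress_remaining: 0.1)[0] == "unknown"

This is also why _eval_schedule in the callback only calls schedules it has
classified and returns None otherwise, letting the _is_finite_number guard
skip the TensorBoard record for unknown schedules.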