]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): cleanup training schedule handling code
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 21 Sep 2025 22:31:36 +0000 (00:31 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 21 Sep 2025 22:31:36 +0000 (00:31 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index e2410194a0f7a2e6e0364595181fb73b83aba795..5609f60ed64b1d2542f28651fed54bc08618c2ea 100644 (file)
@@ -47,7 +47,7 @@ from stable_baselines3.common.callbacks import (
 )
 from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.logger import Figure, HParam
-from stable_baselines3.common.utils import set_random_seed
+from stable_baselines3.common.utils import ConstantSchedule, set_random_seed
 from stable_baselines3.common.vec_env import (
     DummyVecEnv,
     SubprocVecEnv,
@@ -310,7 +310,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             lr = model_params.get("learning_rate", 0.0003)
             if isinstance(lr, (int, float)):
                 lr = float(lr)
-                model_params["learning_rate"] = linear_schedule(lr)
+                model_params["learning_rate"] = get_schedule("linear", lr)
                 logger.info(
                     "Learning rate linear schedule enabled, initial value: %s", lr
                 )
@@ -319,7 +319,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             cr = model_params.get("clip_range", 0.2)
             if isinstance(cr, (int, float)):
                 cr = float(cr)
-                model_params["clip_range"] = linear_schedule(cr)
+                model_params["clip_range"] = get_schedule("linear", cr)
                 logger.info("Clip range linear schedule enabled, initial value: %s", cr)
 
         if "DQN" in self.model_type:
@@ -1369,8 +1369,8 @@ class MyRLEnv(Base5ActionRLEnv):
         self._update_total_profit()
         self._last_closed_position = self._position
         self._position = Positions.Neutral
-        self._last_closed_trade_tick = self._current_tick
         self._last_trade_tick = None
+        self._last_closed_trade_tick = self._current_tick
 
     def execute_trade(self, action: int) -> Optional[str]:
         """
@@ -1831,24 +1831,8 @@ class InfoMetricsCallback(TensorboardCallback):
                 pass
 
     def _on_training_start(self) -> None:
-        lr_schedule = "unknown"
-        lr_iv = np.nan
-        lr_fv = np.nan
         lr = getattr(self.model, "learning_rate", None)
-        if callable(lr):
-            lr_schedule = "linear"
-            try:
-                lr_iv = lr(1.0)
-            except Exception:
-                lr_iv = np.nan
-            try:
-                lr_fv = lr(0.0)
-            except Exception:
-                lr_fv = np.nan
-        elif isinstance(lr, (int, float)):
-            lr_schedule = "constant"
-            lr_iv = float(lr)
-            lr_fv = float(lr)
+        lr_schedule, lr_iv, lr_fv = get_schedule_type(lr)
         n_stack = 1
         env = getattr(self, "training_env", None)
         while env is not None:
@@ -1870,24 +1854,8 @@ class InfoMetricsCallback(TensorboardCallback):
             "batch_size": int(self.model.batch_size),
         }
         if "PPO" in self.model.__class__.__name__:
-            cr_schedule = "unknown"
-            cr_iv = np.nan
-            cr_fv = np.nan
             cr = getattr(self.model, "clip_range", None)
-            if callable(cr):
-                cr_schedule = "linear"
-                try:
-                    cr_iv = cr(1.0)
-                except Exception:
-                    cr_iv = np.nan
-                try:
-                    cr_fv = cr(0.0)
-                except Exception:
-                    cr_fv = np.nan
-            elif isinstance(cr, (int, float)):
-                cr_schedule = "constant"
-                cr_iv = float(cr)
-                cr_fv = float(cr)
+            cr_schedule, cr_iv, cr_fv = get_schedule_type(cr)
             hparam_dict.update(
                 {
                     "cr_schedule": cr_schedule,
@@ -2164,10 +2132,23 @@ class InfoMetricsCallback(TensorboardCallback):
         except Exception:
             progress_remaining = 1.0
 
+        def _eval_schedule(schedule: Any) -> float | None:
+            schedule_type, _, _ = get_schedule_type(schedule)
+            try:
+                if schedule_type == "linear":
+                    return float(schedule(progress_remaining))
+                if schedule_type == "constant":
+                    if callable(schedule):
+                        return float(schedule(0.0))
+                    if isinstance(schedule, (int, float)):
+                        return float(schedule)
+                return None
+            except Exception:
+                return None
+
         try:
             lr = getattr(self.model, "learning_rate", None)
-            if callable(lr):
-                lr = lr(progress_remaining)
+            lr = _eval_schedule(lr)
             if _is_finite_number(lr):
                 self._safe_logger_record(
                     "train/learning_rate", float(lr), exclude=logger_exclude
@@ -2178,8 +2159,7 @@ class InfoMetricsCallback(TensorboardCallback):
         if "PPO" in self.model.__class__.__name__:
             try:
                 cr = getattr(self.model, "clip_range", None)
-                if callable(cr):
-                    cr = cr(progress_remaining)
+                cr = _eval_schedule(cr)
                 if _is_finite_number(cr):
                     self._safe_logger_record(
                         "train/clip_range", float(cr), exclude=logger_exclude
@@ -2326,6 +2306,25 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
         return True
 
 
+class SimpleLinearSchedule:
+    """
+    Linear schedule (from initial value to zero),
+    simpler than sb3 LinearSchedule.
+
+    :param initial_value: (float or str) The initial value for the schedule
+    """
+
+    def __init__(self, initial_value: Union[float, str]) -> None:
+        # Force conversion to float
+        self.initial_value = float(initial_value)
+
+    def __call__(self, progress_remaining: float) -> float:
+        return progress_remaining * self.initial_value
+
+    def __repr__(self) -> str:
+        return f"SimpleLinearSchedule(initial_value={self.initial_value})"
+
+
 def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     """Recursively merge two dicts without mutating inputs"""
     dst_copy = copy.deepcopy(dst)
@@ -2341,13 +2340,6 @@ def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     return dst_copy
 
 
-def linear_schedule(initial_value: float) -> Callable[[float], float]:
-    def func(progress_remaining: float) -> float:
-        return progress_remaining * initial_value
-
-    return func
-
-
 def _compute_gradient_steps(tf: int, ss: int) -> int:
     if tf > 0 and ss > 0:
         return min(tf, max(tf // ss, 1))
@@ -2385,6 +2377,39 @@ def steps_to_days(steps: int, timeframe: str) -> float:
     return round(days, 1)
 
 
+def get_schedule_type(
+    schedule: Any,
+) -> Tuple[Literal["constant", "linear", "unknown"], float, float]:
+    if isinstance(schedule, (int, float)):
+        try:
+            schedule = float(schedule)
+            return "constant", schedule, schedule
+        except Exception:
+            return "constant", np.nan, np.nan
+    elif isinstance(schedule, ConstantSchedule):
+        try:
+            return "constant", schedule(1.0), schedule(0.0)
+        except Exception:
+            return "constant", np.nan, np.nan
+    elif isinstance(schedule, SimpleLinearSchedule):
+        try:
+            return "linear", schedule(1.0), schedule(0.0)
+        except Exception:
+            return "linear", np.nan, np.nan
+
+    return "unknown", np.nan, np.nan
+
+
+def get_schedule(
+    schedule_type: Literal["linear", "constant"],
+    initial_value: float,
+) -> Callable[[float], float]:
+    if schedule_type == "linear":
+        return SimpleLinearSchedule(initial_value)
+    elif schedule_type == "constant":
+        return ConstantSchedule(initial_value)
+
+
 def get_net_arch(
     model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
 ) -> Union[list[int], Dict[str, list[int]]]:
@@ -2441,17 +2466,13 @@ def convert_optuna_params_to_model_params(
     lr = optuna_params.get("learning_rate")
     if lr is None:
         raise ValueError(f"missing 'learning_rate' in optuna params for {model_type}")
-    lr: float | Callable[[float], float] = float(lr)
-    if optuna_params.get("lr_schedule") == "linear":
-        lr = linear_schedule(lr)
+    lr = get_schedule(optuna_params.get("lr_schedule", "constant"), float(lr))
 
     if "PPO" in model_type:
         cr = optuna_params.get("clip_range")
         if cr is None:
             raise ValueError(f"missing 'clip_range' in optuna params for {model_type}")
-        cr: float | Callable[[float], float] = float(cr)
-        if optuna_params.get("cr_schedule") == "linear":
-            cr = linear_schedule(cr)
+        cr = get_schedule(optuna_params.get("cr_schedule", "constant"), float(cr))
 
         model_params.update(
             {