]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): factor out safe logger helper
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 21 Sep 2025 14:37:54 +0000 (16:37 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 21 Sep 2025 14:37:54 +0000 (16:37 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
quickadapter/user_data/strategies/Utils.py

index bcad2a8cb3adcfeafbcbdd4710675c3e3780eaff..197039de0a1e9128dfe0367d69fd89d5da75d8b5 100644 (file)
@@ -170,10 +170,10 @@ class ReforceXY(BaseReinforcementLearningModel):
     @staticmethod
     def get_action_masks(
         position: Positions, force_action: Optional[ForceActions] = None
-    ) -> NDArray[bool]:
+    ) -> NDArray[np.bool_]:
         position = ReforceXY._normalize_position(position)
 
-        action_masks = np.zeros(len(Actions), dtype=bool)
+        action_masks = np.zeros(len(Actions), dtype=np.bool_)
 
         if force_action is not None and position in (Positions.Long, Positions.Short):
             if position == Positions.Long:
@@ -422,8 +422,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         if self.activate_tensorboard:
             info_callback = InfoMetricsCallback(
                 actions=Actions,
-                verbose=verbose,
                 throttle=self.rl_config.get("tensorboard_throttle", 1),
+                verbose=verbose,
             )
             callbacks.append(info_callback)
             if self.plot_new_best:
@@ -590,7 +590,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 return Positions.Neutral
             return position
 
-        frame_buffer: list[np.ndarray] = []
+        frame_buffer: list[NDArray[np.float32]] = []
 
         def _predict(window) -> int:
             observation: DataFrame = dataframe.iloc[window.index]
@@ -611,7 +611,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             np_observation = observation.to_numpy(dtype=np.float32)
 
-            fb: list[np.ndarray] = frame_buffer
+            fb: list[NDArray[np.float32]] = frame_buffer
             frame_stacking = self.frame_stacking
             if frame_stacking and frame_stacking > 1:
                 fb.append(np_observation.copy())
@@ -1449,7 +1449,7 @@ class MyRLEnv(Base5ActionRLEnv):
             )
         )
 
-    def action_masks(self) -> NDArray[bool]:
+    def action_masks(self) -> NDArray[np.bool_]:
         return ReforceXY.get_action_masks(self._position, self._force_action)
 
     def get_feature_value(
@@ -1784,6 +1784,28 @@ class InfoMetricsCallback(TensorboardCallback):
         super().__init__(*args, **kwargs)
         self.throttle = 1 if throttle < 1 else throttle
 
+    def _safe_logger_record(
+        self, key: str, value: Any, exclude: Optional[Tuple[str, ...]] = None
+    ) -> None:
+        try:
+            self.logger.record(key, value, exclude=exclude)
+        except Exception as e:
+            logger.warning("logger.record failed at %r: %r", key, e)
+            if exclude is None:
+                exclude = ("tensorboard",)
+            else:
+                exclude_set = set(exclude)
+                exclude_set.add("tensorboard")
+                exclude_set.discard("stdout")
+                exclude = tuple(exclude_set)
+            try:
+                self.logger.record(key, value, exclude=exclude)
+            except Exception as e:
+                logger.error(
+                    "logger.record retry on stdout failed again at %r: %r", key, e
+                )
+                pass
+
     def _on_training_start(self) -> None:
         lr_schedule = "unknown"
         lr_iv = np.nan
@@ -1923,7 +1945,7 @@ class InfoMetricsCallback(TensorboardCallback):
                     "train/exploration_rate": 0.0,
                 }
             )
-        self.logger.record(
+        self._safe_logger_record(
             "hparams",
             HParam(hparam_dict, metric_dict),
             exclude=("stdout", "log", "json", "csv"),
@@ -2008,28 +2030,15 @@ class InfoMetricsCallback(TensorboardCallback):
                 else:
                     aggregated_info[f"{k}_mode"] = "mixed"
 
-            try:
-                self.logger.record("info/n_envs", int(len(infos_list)))
-            except Exception:
-                try:
-                    self.logger.record(
-                        "info/n_envs", int(len(infos_list)), exclude=("tensorboard",)
-                    )
-                except Exception:
-                    pass
+            logger_exclude = ("stdout", "log", "json", "csv")
+            self._safe_logger_record(
+                "info/n_envs", int(len(infos_list)), exclude=logger_exclude
+            )
 
             if filtered_values > 0:
-                try:
-                    self.logger.record("info/filtered_values", int(filtered_values))
-                except Exception:
-                    try:
-                        self.logger.record(
-                            "info/filtered_values",
-                            int(filtered_values),
-                            exclude=("tensorboard",),
-                        )
-                    except Exception:
-                        pass
+                self._safe_logger_record(
+                    "info/filtered_values", int(filtered_values), exclude=logger_exclude
+                )
 
         if self.training_env is None:
             return True
@@ -2065,15 +2074,7 @@ class InfoMetricsCallback(TensorboardCallback):
                             cat_dict[metric] = value
 
         for metric, value in aggregated_info.items():
-            try:
-                self.logger.record(f"info/{metric}", value)
-            except Exception:
-                try:
-                    self.logger.record(
-                        f"info/{metric}", value, exclude=("tensorboard",)
-                    )
-                except Exception:
-                    pass
+            self._safe_logger_record(f"info/{metric}", value, exclude=logger_exclude)
 
         if isinstance(infos_list, list) and infos_list:
             cat_keys = ("force_action", "action", "position")
@@ -2093,61 +2094,34 @@ class InfoMetricsCallback(TensorboardCallback):
             for k, counts in cat_counts.items():
                 cat_total = max(1, int(cat_totals.get(k, 0)))
                 for name, cnt in counts.items():
-                    try:
-                        self.logger.record(f"info/{k}/{name}_count", int(cnt))
-                        self.logger.record(
-                            f"info/{k}/{name}_ratio", float(cnt) / float(cat_total)
-                        )
-                    except Exception:
-                        try:
-                            self.logger.record(
-                                f"info/{k}/{name}_count",
-                                int(cnt),
-                                exclude=("tensorboard",),
-                            )
-                            self.logger.record(
-                                f"info/{k}/{name}_ratio",
-                                float(cnt) / float(cat_total),
-                                exclude=("tensorboard",),
-                            )
-                        except Exception:
-                            pass
+                    self._safe_logger_record(
+                        f"info/{k}/{name}_count", int(cnt), exclude=logger_exclude
+                    )
+                    self._safe_logger_record(
+                        f"info/{k}/{name}_ratio",
+                        float(cnt) / float(cat_total),
+                        exclude=logger_exclude,
+                    )
 
         for category, metrics in aggregated_tensorboard_metrics.items():
             if isinstance(metrics, dict):
                 for metric, value in metrics.items():
-                    try:
-                        self.logger.record(f"{category}/{metric}_sum", value)
-                    except Exception:
-                        try:
-                            self.logger.record(
-                                f"{category}/{metric}_sum",
-                                value,
-                                exclude=("tensorboard",),
-                            )
-                        except Exception:
-                            pass
-                    try:
-                        count = aggregated_tensorboard_metric_counts.get(
-                            category, {}
-                        ).get(metric)
-                        if isinstance(value, (int, float)) and count and count > 0:
-                            self.logger.record(
-                                f"{category}/{metric}_mean", float(value) / float(count)
-                            )
-                    except Exception:
-                        try:
-                            count = aggregated_tensorboard_metric_counts.get(
-                                category, {}
-                            ).get(metric)
-                            if isinstance(value, (int, float)) and count and count > 0:
-                                self.logger.record(
-                                    f"{category}/{metric}_mean",
-                                    float(value) / float(count),
-                                    exclude=("tensorboard",),
-                                )
-                        except Exception:
-                            pass
+                    self._safe_logger_record(
+                        f"{category}/{metric}_sum", value, exclude=logger_exclude
+                    )
+                    count = aggregated_tensorboard_metric_counts.get(category, {}).get(
+                        metric
+                    )
+                    if (
+                        _is_finite_number(value)
+                        and isinstance(count, int)
+                        and count > 0
+                    ):
+                        self._safe_logger_record(
+                            f"{category}/{metric}_mean",
+                            float(value) / float(count),
+                            exclude=logger_exclude,
+                        )
 
         try:
             total_timesteps = getattr(self.model, "_total_timesteps", None)
@@ -2157,23 +2131,12 @@ class InfoMetricsCallback(TensorboardCallback):
             else:
                 progress_done = 0.0
             progress_remaining = 1.0 - progress_done
-            try:
-                self.logger.record("train/progress_done", progress_done)
-                self.logger.record("train/progress_remaining", progress_remaining)
-            except Exception:
-                try:
-                    self.logger.record(
-                        "train/progress_done",
-                        progress_done,
-                        exclude=("tensorboard",),
-                    )
-                    self.logger.record(
-                        "train/progress_remaining",
-                        progress_remaining,
-                        exclude=("tensorboard",),
-                    )
-                except Exception:
-                    pass
+            self._safe_logger_record(
+                "train/progress_done", progress_done, exclude=logger_exclude
+            )
+            self._safe_logger_record(
+                "train/progress_remaining", progress_remaining, exclude=logger_exclude
+            )
         except Exception:
             progress_remaining = 1.0
 
@@ -2181,18 +2144,10 @@ class InfoMetricsCallback(TensorboardCallback):
             lr = getattr(self.model, "learning_rate", None)
             if callable(lr):
                 lr = lr(progress_remaining)
-            if isinstance(lr, (int, float)) and np.isfinite(lr):
-                try:
-                    self.logger.record("train/learning_rate", float(lr))
-                except Exception:
-                    try:
-                        self.logger.record(
-                            "train/learning_rate",
-                            float(lr),
-                            exclude=("tensorboard",),
-                        )
-                    except Exception:
-                        pass
+            if _is_finite_number(lr):
+                self._safe_logger_record(
+                    "train/learning_rate", float(lr), exclude=logger_exclude
+                )
         except Exception:
             pass
 
@@ -2201,18 +2156,10 @@ class InfoMetricsCallback(TensorboardCallback):
                 cr = getattr(self.model, "clip_range", None)
                 if callable(cr):
                     cr = cr(progress_remaining)
-                if isinstance(cr, (int, float)) and np.isfinite(cr):
-                    try:
-                        self.logger.record("train/clip_range", float(cr))
-                    except Exception:
-                        try:
-                            self.logger.record(
-                                "train/clip_range",
-                                float(cr),
-                                exclude=("tensorboard",),
-                            )
-                        except Exception:
-                            pass
+                if _is_finite_number(cr):
+                    self._safe_logger_record(
+                        "train/clip_range", float(cr), exclude=logger_exclude
+                    )
             except Exception:
                 pass
 
@@ -2228,9 +2175,15 @@ class RolloutPlotCallback(BaseCallback):
         figures = self.training_env.env_method("get_env_plot")
         for i, fig in enumerate(figures):
             figure = Figure(fig, close=True)
-            self.logger.record(
-                f"best/train_env{i}", figure, exclude=("stdout", "log", "json", "csv")
-            )
+            try:
+                self.logger.record(
+                    f"best/train_env{i}",
+                    figure,
+                    exclude=("stdout", "log", "json", "csv"),
+                )
+            except Exception as e:
+                logger.error("logger.record failed at %r: %r", f"best/train_env{i}", e)
+                pass
         return True
 
     def _on_step(self) -> bool:
@@ -2314,8 +2267,17 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
                 )
             if np.isfinite(best_mean_reward):
                 try:
-                    self.logger.record("eval/best_mean_reward", best_mean_reward)
-                except Exception:
+                    self.logger.record(
+                        "eval/best_mean_reward",
+                        best_mean_reward,
+                        exclude=("stdout", "log", "json", "csv"),
+                    )
+                except Exception as e:
+                    logger.error(
+                        "Optuna: logger.record failed at %r: %r",
+                        "eval/best_mean_reward",
+                        e,
+                    )
                     pass
             else:
                 logger.warning(
index 47c2376628869f645c471ac7064789fcc29f12dd..692f5244880ac483a54d307acb3e69108681f70c 100644 (file)
@@ -718,7 +718,9 @@ def zigzag(
 regressors = {"xgboost", "lightgbm"}
 
 
-def get_optuna_callbacks(trial: optuna.trial.Trial, regressor: str) -> list[Callable]:
+def get_optuna_callbacks(
+    trial: optuna.trial.Trial, regressor: str
+) -> list[Callable[[optuna.trial.Trial, str], None]]:
     if regressor == "xgboost":
         callbacks = [
             optuna.integration.XGBoostPruningCallback(trial, "validation_0-rmse")
@@ -741,7 +743,7 @@ def fit_regressor(
     eval_weights: Optional[list[NDArray[np.floating]]],
     model_training_parameters: dict[str, Any],
     init_model: Any = None,
-    callbacks: Optional[list[Callable]] = None,
+    callbacks: Optional[list[Callable[[optuna.trial.Trial, str], None]]] = None,
     trial: Optional[optuna.trial.Trial] = None,
 ) -> Any:
     if regressor == "xgboost":