From 0a9a3ab6eafe461ec2cb3a65a70c6a5082505a6b Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sun, 21 Sep 2025 16:37:54 +0200 Subject: [PATCH] refactor(reforcexy): factor out safe logger helper MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 228 ++++++++---------- quickadapter/user_data/strategies/Utils.py | 6 +- 2 files changed, 99 insertions(+), 135 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index bcad2a8..197039d 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -170,10 +170,10 @@ class ReforceXY(BaseReinforcementLearningModel): @staticmethod def get_action_masks( position: Positions, force_action: Optional[ForceActions] = None - ) -> NDArray[bool]: + ) -> NDArray[np.bool_]: position = ReforceXY._normalize_position(position) - action_masks = np.zeros(len(Actions), dtype=bool) + action_masks = np.zeros(len(Actions), dtype=np.bool_) if force_action is not None and position in (Positions.Long, Positions.Short): if position == Positions.Long: @@ -422,8 +422,8 @@ class ReforceXY(BaseReinforcementLearningModel): if self.activate_tensorboard: info_callback = InfoMetricsCallback( actions=Actions, - verbose=verbose, throttle=self.rl_config.get("tensorboard_throttle", 1), + verbose=verbose, ) callbacks.append(info_callback) if self.plot_new_best: @@ -590,7 +590,7 @@ class ReforceXY(BaseReinforcementLearningModel): return Positions.Neutral return position - frame_buffer: list[np.ndarray] = [] + frame_buffer: list[NDArray[np.float32]] = [] def _predict(window) -> int: observation: DataFrame = dataframe.iloc[window.index] @@ -611,7 +611,7 @@ class ReforceXY(BaseReinforcementLearningModel): np_observation = observation.to_numpy(dtype=np.float32) - fb: list[np.ndarray] = frame_buffer + fb: list[NDArray[np.float32]] = frame_buffer frame_stacking = self.frame_stacking if frame_stacking and frame_stacking > 1: fb.append(np_observation.copy()) @@ -1449,7 +1449,7 @@ class MyRLEnv(Base5ActionRLEnv): ) ) - def action_masks(self) -> NDArray[bool]: + def action_masks(self) -> NDArray[np.bool_]: return ReforceXY.get_action_masks(self._position, self._force_action) def get_feature_value( @@ -1784,6 +1784,28 @@ class InfoMetricsCallback(TensorboardCallback): super().__init__(*args, **kwargs) self.throttle = 1 if throttle < 1 else throttle + def _safe_logger_record( + self, key: str, value: Any, exclude: Optional[Tuple[str, ...]] = None + ) -> None: + try: + self.logger.record(key, value, exclude=exclude) + except Exception as e: + logger.warning("logger.record failed at %r: %r", key, e) + if exclude is None: + exclude = ("tensorboard",) + else: + exclude_set = set(exclude) + exclude_set.add("tensorboard") + exclude_set.discard("stdout") + exclude = tuple(exclude_set) + try: + self.logger.record(key, value, exclude=exclude) + except Exception as e: + logger.error( + "logger.record retry on stdout failed again at %r: %r", key, e + ) + pass + def _on_training_start(self) -> None: lr_schedule = "unknown" lr_iv = np.nan @@ -1923,7 +1945,7 @@ class InfoMetricsCallback(TensorboardCallback): "train/exploration_rate": 0.0, } ) - self.logger.record( + self._safe_logger_record( "hparams", HParam(hparam_dict, metric_dict), exclude=("stdout", "log", "json", "csv"), @@ -2008,28 +2030,15 @@ class InfoMetricsCallback(TensorboardCallback): else: aggregated_info[f"{k}_mode"] = "mixed" - try: - self.logger.record("info/n_envs", int(len(infos_list))) - except Exception: - try: - self.logger.record( - "info/n_envs", int(len(infos_list)), exclude=("tensorboard",) - ) - except Exception: - pass + logger_exclude = ("stdout", "log", "json", "csv") + self._safe_logger_record( + "info/n_envs", int(len(infos_list)), exclude=logger_exclude + ) if filtered_values > 0: - try: - self.logger.record("info/filtered_values", int(filtered_values)) - except Exception: - try: - self.logger.record( - "info/filtered_values", - int(filtered_values), - exclude=("tensorboard",), - ) - except Exception: - pass + self._safe_logger_record( + "info/filtered_values", int(filtered_values), exclude=logger_exclude + ) if self.training_env is None: return True @@ -2065,15 +2074,7 @@ class InfoMetricsCallback(TensorboardCallback): cat_dict[metric] = value for metric, value in aggregated_info.items(): - try: - self.logger.record(f"info/{metric}", value) - except Exception: - try: - self.logger.record( - f"info/{metric}", value, exclude=("tensorboard",) - ) - except Exception: - pass + self._safe_logger_record(f"info/{metric}", value, exclude=logger_exclude) if isinstance(infos_list, list) and infos_list: cat_keys = ("force_action", "action", "position") @@ -2093,61 +2094,34 @@ class InfoMetricsCallback(TensorboardCallback): for k, counts in cat_counts.items(): cat_total = max(1, int(cat_totals.get(k, 0))) for name, cnt in counts.items(): - try: - self.logger.record(f"info/{k}/{name}_count", int(cnt)) - self.logger.record( - f"info/{k}/{name}_ratio", float(cnt) / float(cat_total) - ) - except Exception: - try: - self.logger.record( - f"info/{k}/{name}_count", - int(cnt), - exclude=("tensorboard",), - ) - self.logger.record( - f"info/{k}/{name}_ratio", - float(cnt) / float(cat_total), - exclude=("tensorboard",), - ) - except Exception: - pass + self._safe_logger_record( + f"info/{k}/{name}_count", int(cnt), exclude=logger_exclude + ) + self._safe_logger_record( + f"info/{k}/{name}_ratio", + float(cnt) / float(cat_total), + exclude=logger_exclude, + ) for category, metrics in aggregated_tensorboard_metrics.items(): if isinstance(metrics, dict): for metric, value in metrics.items(): - try: - self.logger.record(f"{category}/{metric}_sum", value) - except Exception: - try: - self.logger.record( - f"{category}/{metric}_sum", - value, - exclude=("tensorboard",), - ) - except Exception: - pass - try: - count = aggregated_tensorboard_metric_counts.get( - category, {} - ).get(metric) - if isinstance(value, (int, float)) and count and count > 0: - self.logger.record( - f"{category}/{metric}_mean", float(value) / float(count) - ) - except Exception: - try: - count = aggregated_tensorboard_metric_counts.get( - category, {} - ).get(metric) - if isinstance(value, (int, float)) and count and count > 0: - self.logger.record( - f"{category}/{metric}_mean", - float(value) / float(count), - exclude=("tensorboard",), - ) - except Exception: - pass + self._safe_logger_record( + f"{category}/{metric}_sum", value, exclude=logger_exclude + ) + count = aggregated_tensorboard_metric_counts.get(category, {}).get( + metric + ) + if ( + _is_finite_number(value) + and isinstance(count, int) + and count > 0 + ): + self._safe_logger_record( + f"{category}/{metric}_mean", + float(value) / float(count), + exclude=logger_exclude, + ) try: total_timesteps = getattr(self.model, "_total_timesteps", None) @@ -2157,23 +2131,12 @@ class InfoMetricsCallback(TensorboardCallback): else: progress_done = 0.0 progress_remaining = 1.0 - progress_done - try: - self.logger.record("train/progress_done", progress_done) - self.logger.record("train/progress_remaining", progress_remaining) - except Exception: - try: - self.logger.record( - "train/progress_done", - progress_done, - exclude=("tensorboard",), - ) - self.logger.record( - "train/progress_remaining", - progress_remaining, - exclude=("tensorboard",), - ) - except Exception: - pass + self._safe_logger_record( + "train/progress_done", progress_done, exclude=logger_exclude + ) + self._safe_logger_record( + "train/progress_remaining", progress_remaining, exclude=logger_exclude + ) except Exception: progress_remaining = 1.0 @@ -2181,18 +2144,10 @@ class InfoMetricsCallback(TensorboardCallback): lr = getattr(self.model, "learning_rate", None) if callable(lr): lr = lr(progress_remaining) - if isinstance(lr, (int, float)) and np.isfinite(lr): - try: - self.logger.record("train/learning_rate", float(lr)) - except Exception: - try: - self.logger.record( - "train/learning_rate", - float(lr), - exclude=("tensorboard",), - ) - except Exception: - pass + if _is_finite_number(lr): + self._safe_logger_record( + "train/learning_rate", float(lr), exclude=logger_exclude + ) except Exception: pass @@ -2201,18 +2156,10 @@ class InfoMetricsCallback(TensorboardCallback): cr = getattr(self.model, "clip_range", None) if callable(cr): cr = cr(progress_remaining) - if isinstance(cr, (int, float)) and np.isfinite(cr): - try: - self.logger.record("train/clip_range", float(cr)) - except Exception: - try: - self.logger.record( - "train/clip_range", - float(cr), - exclude=("tensorboard",), - ) - except Exception: - pass + if _is_finite_number(cr): + self._safe_logger_record( + "train/clip_range", float(cr), exclude=logger_exclude + ) except Exception: pass @@ -2228,9 +2175,15 @@ class RolloutPlotCallback(BaseCallback): figures = self.training_env.env_method("get_env_plot") for i, fig in enumerate(figures): figure = Figure(fig, close=True) - self.logger.record( - f"best/train_env{i}", figure, exclude=("stdout", "log", "json", "csv") - ) + try: + self.logger.record( + f"best/train_env{i}", + figure, + exclude=("stdout", "log", "json", "csv"), + ) + except Exception as e: + logger.error("logger.record failed at %r: %r", f"best/train_env{i}", e) + pass return True def _on_step(self) -> bool: @@ -2314,8 +2267,17 @@ class MaskableTrialEvalCallback(MaskableEvalCallback): ) if np.isfinite(best_mean_reward): try: - self.logger.record("eval/best_mean_reward", best_mean_reward) - except Exception: + self.logger.record( + "eval/best_mean_reward", + best_mean_reward, + exclude=("stdout", "log", "json", "csv"), + ) + except Exception as e: + logger.error( + "Optuna: logger.record failed at %r: %r", + "eval/best_mean_reward", + e, + ) pass else: logger.warning( diff --git a/quickadapter/user_data/strategies/Utils.py b/quickadapter/user_data/strategies/Utils.py index 47c2376..692f524 100644 --- a/quickadapter/user_data/strategies/Utils.py +++ b/quickadapter/user_data/strategies/Utils.py @@ -718,7 +718,9 @@ def zigzag( regressors = {"xgboost", "lightgbm"} -def get_optuna_callbacks(trial: optuna.trial.Trial, regressor: str) -> list[Callable]: +def get_optuna_callbacks( + trial: optuna.trial.Trial, regressor: str +) -> list[Callable[[optuna.trial.Trial, str], None]]: if regressor == "xgboost": callbacks = [ optuna.integration.XGBoostPruningCallback(trial, "validation_0-rmse") @@ -741,7 +743,7 @@ def fit_regressor( eval_weights: Optional[list[NDArray[np.floating]]], model_training_parameters: dict[str, Any], init_model: Any = None, - callbacks: Optional[list[Callable]] = None, + callbacks: Optional[list[Callable[[optuna.trial.Trial, str], None]]] = None, trial: Optional[optuna.trial.Trial] = None, ) -> Any: if regressor == "xgboost": -- 2.43.0