From f22d9787188b0cda293cc5423065a1660072efe1 Mon Sep 17 00:00:00 2001
From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?=
Date: Tue, 9 Sep 2025 15:15:27 +0200
Subject: [PATCH] feat(reforcexy): make tensorboard multi envs aware
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jérôme Benoit
---
 ReforceXY/user_data/freqaimodels/ReforceXY.py | 119 +++++++++++++-----
 1 file changed, 89 insertions(+), 30 deletions(-)

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 9bee654..9a7a3b6 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -4,10 +4,12 @@ import json
 import logging
 import time
 import warnings
+from collections import defaultdict
 from collections.abc import Mapping
 from enum import IntEnum
 from functools import lru_cache
 from pathlib import Path
+from statistics import stdev
 from typing import Any, Callable, Dict, Optional, Tuple, Type
 
 import matplotlib
@@ -174,16 +176,20 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         if self.check_envs:
             logger.info("Checking environments...")
-            check_env(
-                self.MyRLEnv(
-                    id="train_env_check", df=train_df, prices=prices_train, **env_dict
-                )
+            _train_env_check = self.MyRLEnv(
+                id="train_env_check", df=train_df, prices=prices_train, **env_dict
             )
-            check_env(
-                self.MyRLEnv(
-                    id="eval_env_check", df=test_df, prices=prices_test, **env_dict
-                )
+            try:
+                check_env(_train_env_check)
+            finally:
+                _train_env_check.close()
+            _eval_env_check = self.MyRLEnv(
+                id="eval_env_check", df=test_df, prices=prices_test, **env_dict
             )
+            try:
+                check_env(_eval_env_check)
+            finally:
+                _eval_env_check.close()
 
         logger.info("Populating environments: %s", self.n_envs)
         train_env = DummyVecEnv(
@@ -214,6 +220,10 @@ class ReforceXY(BaseReinforcementLearningModel):
                 for i in range(self.n_envs)
             ]
         )
+        if self.frame_stacking == 1:
+            logger.warning(
+                "frame_stacking=1 is equivalent to no stacking; use >=2 or 0"
+            )
         if self.frame_stacking:
             logger.info("Frame stacking: %s", self.frame_stacking)
             train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
@@ -376,7 +386,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             best_trial_params = self.study(train_df, total_timesteps, dk)
             if best_trial_params is None:
                 logger.error(
-                    "Hyperopt failed. Using default configured model params instead."
+                    "Hyperopt failed. Using default configured model params instead"
                 )
                 best_trial_params = self.get_model_params()
             model_params = best_trial_params
@@ -400,7 +410,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
         else:
             logger.info(
-                "Continual training activated - starting training from previously trained agent."
+                "Continual training activated - starting training from previously trained agent"
             )
             model = self.dd.model_dictionary[dk.pair]
             model.set_env(self.train_env)
@@ -422,7 +432,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         model_filename = dk.model_filename if dk.model_filename else "best"
         model_path = Path(dk.data_path / f"{model_filename}_model.zip")
         if model_path.is_file():
-            logger.info(f"Callback found a best model: {model_path}.")
+            logger.info(f"Callback found a best model: {model_path}")
             try:
                 best_model = self.MODELCLASS.load(
                     dk.data_path / f"{model_filename}_model"
@@ -431,7 +441,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             except Exception as e:
                 logger.error(f"Error loading best model: {repr(e)}", exc_info=True)
 
-        logger.info("Couldn't find best model, using final model instead.")
+        logger.info("Couldn't find best model, using final model instead")
 
         return model
 
@@ -1251,7 +1261,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         Get environment data from the first to the last trade
        """
         if not self.history or not self.trade_history:
-            logger.warning("History or trade history is empty.")
+            logger.warning("history or trade_history is empty")
             return DataFrame()
 
         _history_df = DataFrame.from_dict(self.history)
@@ -1261,9 +1271,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             "tick" not in _history_df.columns
             or "tick" not in _trade_history_df.columns
         ):
-            logger.warning(
-                "'tick' column is missing from history or trade history."
-            )
+            logger.warning("'tick' column is missing from history or trade history")
             return DataFrame()
 
         _rollout_history = merge(
@@ -1404,8 +1412,11 @@ class InfoMetricsCallback(TensorboardCallback):
                     "n_epochs": self.model.n_epochs,
                     "ent_coef": self.model.ent_coef,
                     "vf_coef": self.model.vf_coef,
+                    "max_grad_norm": self.model.max_grad_norm,
                 }
             )
+        if getattr(self.model, "target_kl", None) is not None:
+            hparam_dict["target_kl"] = self.model.target_kl
         if "DQN" in self.model.__class__.__name__:
             hparam_dict.update(
                 {
@@ -1434,19 +1445,56 @@ class InfoMetricsCallback(TensorboardCallback):
         )
 
     def _on_step(self) -> bool:
-        local_info = self.locals.get("infos", [{}])[0]
+        infos_list: list[dict[str, Any]] | None = self.locals.get("infos")
+        aggregated_info: Dict[str, Any] = {}
+        if infos_list:
+            numeric_acc: dict[str, list[float]] = defaultdict(list)
+            non_numeric_acc: dict[str, set[Any]] = defaultdict(set)
+            for info_dict in infos_list:
+                if not isinstance(info_dict, dict):
+                    continue
+                for k, v in info_dict.items():
+                    if k in ("episode", "terminal_observation", "TimeLimit.truncated"):
+                        continue
+                    if isinstance(v, (int, float)) and not isinstance(v, bool):
+                        numeric_acc[k].append(float(v))
+                    else:
+                        non_numeric_acc[k].add(v)
+            for k, values in numeric_acc.items():
+                if not values:
+                    continue
+                mean_v = sum(values) / len(values)
+                aggregated_info[k] = mean_v
+                if len(values) > 1:
+                    try:
+                        aggregated_info[f"{k}_std"] = stdev(values)
+                    except Exception:
+                        pass
+            for k, values in non_numeric_acc.items():
+                if len(values) == 1:
+                    aggregated_info[k] = next(iter(values))
+                else:
+                    aggregated_info[k] = "mixed"
+
         if self.training_env is None:
             return True
-        tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
-        for metric in local_info:
-            if metric not in ["episode", "terminal_observation", "TimeLimit.truncated"]:
-                self.logger.record(f"info/{metric}", local_info.get(metric))
-        for category in tensorboard_metrics:
-            for metric in tensorboard_metrics.get(category, {}):
-                self.logger.record(
-                    f"{category}/{metric}",
-                    tensorboard_metrics.get(category, {}).get(metric),
-                )
+
+        try:
+            tensorboard_metrics_list = self.training_env.get_attr("tensorboard_metrics")
+        except Exception:
+            tensorboard_metrics_list = []
+        tensorboard_metrics = (
+            tensorboard_metrics_list[0] if tensorboard_metrics_list else {}
+        )
+
+        for metric, value in aggregated_info.items():
+            self.logger.record(f"info/{metric}", value)
+
+        for category, metrics in tensorboard_metrics.items():
+            if not isinstance(metrics, dict):
+                continue
+            for metric, value in metrics.items():
+                self.logger.record(f"{category}/{metric}", value)
         return True
 
 
@@ -1498,13 +1546,24 @@ class MaskableTrialEvalCallback(MaskableEvalCallback):
         self.is_pruned = False
 
     def _on_step(self) -> bool:
+        if self.is_pruned:
+            return False
+
         if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
             super()._on_step()
             self.eval_idx += 1
+            last_mean_reward = getattr(self, "last_mean_reward", np.nan)
+            if not isinstance(last_mean_reward, (int, float)) or not np.isfinite(
+                float(last_mean_reward)
+            ):
+                self.is_pruned = True
+                return False
             if hasattr(self.trial, "report"):
-                self.trial.report(self.last_mean_reward, self.eval_idx)
-
-            # Prune trial if needed
+                try:
+                    self.trial.report(last_mean_reward, self.eval_idx)
+                except Exception:
+                    self.is_pruned = True
+                    return False
             if hasattr(self.trial, "should_prune") and self.trial.should_prune():
                 self.is_pruned = True
                 return False
-- 
2.43.0
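
For reference, the aggregation rule the reworked InfoMetricsCallback._on_step applies to the per-environment `infos` list before logging: numeric values are averaged across the vectorized environments, a "<key>_std" entry is recorded when more than one environment reports the key, and non-numeric values are passed through only when all environments agree (otherwise "mixed" is logged). Below is a minimal standalone sketch of that rule; the helper name aggregate_infos is illustrative and not part of the patch.

    from collections import defaultdict
    from statistics import stdev
    from typing import Any, Dict, List

    def aggregate_infos(infos: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Mirrors the per-key aggregation across vectorized envs used in the patch.
        skip = ("episode", "terminal_observation", "TimeLimit.truncated")
        numeric: Dict[str, List[float]] = defaultdict(list)
        other: Dict[str, set] = defaultdict(set)
        for info in infos:
            for k, v in info.items():
                if k in skip:
                    continue
                if isinstance(v, (int, float)) and not isinstance(v, bool):
                    numeric[k].append(float(v))
                else:
                    other[k].add(v)
        out: Dict[str, Any] = {}
        for k, values in numeric.items():
            out[k] = sum(values) / len(values)      # mean across envs
            if len(values) > 1:
                out[f"{k}_std"] = stdev(values)     # spread across envs
        for k, values in other.items():
            out[k] = next(iter(values)) if len(values) == 1 else "mixed"
        return out

    # Two envs reporting the same metric:
    print(aggregate_infos([{"trade_duration": 10}, {"trade_duration": 14}]))
    # -> {'trade_duration': 12.0, 'trade_duration_std': 2.8284...}

The "episode", "terminal_observation" and "TimeLimit.truncated" keys are skipped because they are bookkeeping entries injected by the vectorized/monitor wrappers rather than strategy metrics.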