From 00b07d25043176584217ac4b4eb6895a9be9d7e3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Wed, 17 Sep 2025 16:34:49 +0200 Subject: [PATCH] refactor(reforcexy): align variable name MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 96b25b8..ee4cf36 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -1659,7 +1659,9 @@ class InfoMetricsCallback(TensorboardCallback): tensorboard_metrics_list = [] aggregated_tensorboard_metrics: Dict[str, Dict[str, Any]] = defaultdict(dict) - aggregate_tensorboard_counts: Dict[str, Dict[str, int]] = defaultdict(dict) + aggregated_tensorboard_metric_counts: Dict[str, Dict[str, int]] = defaultdict( + dict + ) for env_metrics in tensorboard_metrics_list or []: if not isinstance(env_metrics, dict): continue @@ -1667,7 +1669,7 @@ class InfoMetricsCallback(TensorboardCallback): if not isinstance(metrics, dict): continue cat_dict = aggregated_tensorboard_metrics.setdefault(category, {}) - cnt_dict = aggregate_tensorboard_counts.setdefault(category, {}) + cnt_dict = aggregated_tensorboard_metric_counts.setdefault(category, {}) for metric, value in metrics.items(): if _is_finite_number(value): v = float(value) @@ -1679,7 +1681,7 @@ class InfoMetricsCallback(TensorboardCallback): cnt_dict[metric] = cnt_dict.get(metric, 0) + 1 else: if ( - aggregate_tensorboard_counts.get(category, {}).get( + aggregated_tensorboard_metric_counts.get(category, {}).get( metric, 0 ) == 0 @@ -1748,17 +1750,17 @@ class InfoMetricsCallback(TensorboardCallback): except Exception: pass try: - count = aggregate_tensorboard_counts.get(category, {}).get( - metric - ) + count = aggregated_tensorboard_metric_counts.get( + category, {} + ).get(metric) if isinstance(value, (int, float)) and count and count > 0: mean = float(value) / float(count) self.logger.record(f"{category}/{metric}_mean", mean) except Exception: try: - count = aggregate_tensorboard_counts.get(category, {}).get( - metric - ) + count = aggregated_tensorboard_metric_counts.get( + category, {} + ).get(metric) if isinstance(value, (int, float)) and count and count > 0: mean = float(value) / float(count) self.logger.record( -- 2.43.0