"""
Set training and evaluation environments
"""
- self.close_envs()
+ if self.train_env is not None or self.eval_env is not None:
+ logger.info("Closing environments")
+ self.close_envs()
train_df = data_dictionary.get("train_features")
test_df = data_dictionary.get("test_features")
logger.warning("Optuna encountered NaN (AssertionError)")
nan_encountered = True
except ValueError as e:
- if "NaN" in str(e):
+ if "nan" in str(e).lower():
logger.warning("Optuna encountered NaN (ValueError)")
nan_encountered = True
else:
raise
+ except FloatingPointError as e:
+ logger.warning("Optuna encountered NaN/Inf (FloatingPointError): %s", e)
+ nan_encountered = True
+ except RuntimeError as e:
+ if "nan" in str(e).lower() or "inf" in str(e).lower():
+ logger.warning("Optuna encountered NaN/Inf (RuntimeError): %s", e)
+ nan_encountered = True
+ else:
+ raise
finally:
if self.progressbar_callback:
self.progressbar_callback.on_training_end()
"""
self._current_tick += 1
self._update_unrealized_total_profit()
+ pnl = self.get_unrealized_profit()
+ self._update_portfolio_log_returns()
self._force_action = self._get_force_action()
reward = self.calculate_reward(action)
self.total_reward += reward
"force_action": (
self._force_action.name if self._force_action else None
),
- "pnl": self.get_unrealized_profit(),
+ "pnl": round(pnl, 5),
"reward": round(reward, 5),
"total_reward": round(self.total_reward, 5),
"total_profit": round(self._total_profit, 5),
"idle_duration": self.get_idle_duration(),
"trade_duration": self.get_trade_duration(),
- "trade_count": len(self.trade_history),
+ "trade_count": int(len(self.trade_history) // 2),
}
self._update_history(info)
return (
return np.log(previous_price) - np.log(current_price)
return 0.0
- def update_portfolio_log_returns(self):
+ def _update_portfolio_log_returns(self):
self.portfolio_log_returns[self._current_tick] = (
self.get_most_recent_return()
)
fig.suptitle(
f"Total Reward: {self.total_reward:.2f} ~ "
+ f"Total Profit: {self._total_profit:.2f} ~ "
- + f"Trades: {len(self.trade_history)}"
+ + f"Trades: {int(len(self.trade_history) // 2)}",
)
fig.tight_layout()
return fig
if isinstance(infos_list, list) and infos_list:
numeric_acc: Dict[str, list[float]] = defaultdict(list)
- non_numeric_acc: Dict[str, set[Any]] = defaultdict(set)
+ non_numeric_counts: Dict[str, Dict[Any, int]] = defaultdict(
+ lambda: defaultdict(int)
+ )
+ filtered_values: int = 0
+
+ def _is_numeric_non_bool(x: Any) -> bool:
+ return isinstance(
+ x, (int, float, np.integer, np.floating)
+ ) and not isinstance(x, bool)
+
+ def _is_finite_number(x: Any) -> bool:
+ if not _is_numeric_non_bool(x):
+ return False
+ try:
+ return np.isfinite(float(x))
+ except Exception:
+ return False
for info_dict in infos_list:
if not isinstance(info_dict, dict):
for k, v in info_dict.items():
if k in {"episode", "terminal_observation", "TimeLimit.truncated"}:
continue
- if isinstance(v, (int, float)) and not isinstance(v, bool):
+ if _is_finite_number(v):
numeric_acc[k].append(float(v))
+ elif _is_numeric_non_bool(v):
+ filtered_values += 1
else:
- non_numeric_acc[k].add(v)
+ non_numeric_counts[k][v] += 1
for k, values in numeric_acc.items():
if not values:
continue
- values_mean = sum(values) / len(values)
- aggregated_info[k] = values_mean
+ mean = sum(values) / len(values)
+ aggregated_info[k] = mean
if len(values) > 1:
try:
aggregated_info[f"{k}_std"] = stdev(values)
for key in ("reward", "pnl"):
values = numeric_acc.get(key)
if values:
- aggregated_info[f"{key}_min"] = float(min(values))
- aggregated_info[f"{key}_max"] = float(max(values))
+ try:
+ aggregated_info[f"{key}_min"] = float(min(values))
+ aggregated_info[f"{key}_max"] = float(max(values))
+ percentiles = np.percentile(values, [25, 50, 75, 90])
+ aggregated_info[f"{key}_p25"] = float(percentiles[0])
+ aggregated_info[f"{key}_p50"] = float(percentiles[1])
+ aggregated_info[f"{key}_p75"] = float(percentiles[2])
+ aggregated_info[f"{key}_p90"] = float(percentiles[3])
+ med = float(percentiles[1])
+ mad = float(np.median(np.abs(np.array(values) - med)))
+ aggregated_info[f"{key}_mad"] = mad
+ except Exception:
+ pass
- for k, values in non_numeric_acc.items():
- aggregated_info[k] = next(iter(values)) if len(values) == 1 else "mixed"
+ for k, counts in non_numeric_counts.items():
+ if not counts:
+ continue
+ if len(counts) == 1:
+ try:
+ aggregated_info[f"{k}_mode"] = next(iter(counts.keys()))
+ except Exception:
+ pass
+ else:
+ aggregated_info[f"{k}_mode"] = "mixed"
+
+ try:
+ self.logger.record("info/n_envs", int(len(infos_list)))
+ except Exception:
+ pass
+
+ if filtered_values > 0:
+ try:
+ self.logger.record("info/filtered_values", int(filtered_values))
+ except Exception:
+ try:
+ self.logger.record(
+ "info/filtered_values",
+ int(filtered_values),
+ exclude=("tensorboard",),
+ )
+ except Exception:
+ pass
if self.training_env is None:
return True
try:
tensorboard_metrics_list = self.training_env.get_attr("tensorboard_metrics")
- tensorboard_metrics = (
- tensorboard_metrics_list[0] if tensorboard_metrics_list else {}
- )
except Exception:
- tensorboard_metrics = {}
+ tensorboard_metrics_list = []
+
+ aggregated_tensorboard_metrics: Dict[str, Dict[str, Any]] = defaultdict(dict)
+ aggregate_tensorboard_counts: Dict[str, Dict[str, int]] = defaultdict(dict)
+ for env_metrics in tensorboard_metrics_list or []:
+ if not isinstance(env_metrics, dict):
+ continue
+ for category, metrics in env_metrics.items():
+ if not isinstance(metrics, dict):
+ continue
+ cat_dict = aggregated_tensorboard_metrics.setdefault(category, {})
+ cnt_dict = aggregate_tensorboard_counts.setdefault(category, {})
+ for metric, value in metrics.items():
+ if _is_finite_number(value):
+ v = float(value)
+ try:
+ base = float(cat_dict.get(metric, 0.0))
+ except Exception:
+ base = 0.0
+ cat_dict[metric] = base + v
+ cnt_dict[metric] = cnt_dict.get(metric, 0) + 1
+ else:
+ if (
+ aggregate_tensorboard_counts.get(category, {}).get(
+ metric, 0
+ )
+ == 0
+ ):
+ cat_dict[metric] = value
for metric, value in aggregated_info.items():
- self.logger.record(f"info/{metric}", value)
+ try:
+ self.logger.record(f"info/{metric}", value)
+ except Exception:
+ try:
+ self.logger.record(
+ f"info/{metric}", value, exclude=("tensorboard",)
+ )
+ except Exception:
+ pass
- for category, metrics in tensorboard_metrics.items():
+ if isinstance(infos_list, list) and infos_list:
+ cat_keys = ("force_action", "action", "position")
+ cat_counts: Dict[str, Dict[Any, int]] = {
+ k: defaultdict(int) for k in cat_keys
+ }
+ cat_totals: Dict[str, int] = {k: 0 for k in cat_keys}
+ for info_dict in infos_list:
+ if not isinstance(info_dict, dict):
+ continue
+ for k in cat_keys:
+ if k in info_dict:
+ v = info_dict.get(k)
+ cat_counts[k][v] += 1
+ cat_totals[k] += 1
+
+ for k, counts in cat_counts.items():
+ cat_total = max(1, int(cat_totals.get(k, 0)))
+ for name, cnt in counts.items():
+ try:
+ self.logger.record(f"info/{k}/{name}_count", int(cnt))
+ self.logger.record(
+ f"info/{k}/{name}_ratio", float(cnt) / float(cat_total)
+ )
+ except Exception:
+ try:
+ self.logger.record(
+ f"info/{k}/{name}_count",
+ int(cnt),
+ exclude=("tensorboard",),
+ )
+ self.logger.record(
+ f"info/{k}/{name}_ratio",
+ float(cnt) / float(cat_total),
+ exclude=("tensorboard",),
+ )
+ except Exception:
+ pass
+
+ for category, metrics in aggregated_tensorboard_metrics.items():
if isinstance(metrics, dict):
for metric, value in metrics.items():
- self.logger.record(f"{category}/{metric}", value)
+ try:
+ self.logger.record(f"{category}/{metric}", value)
+ except Exception:
+ try:
+ self.logger.record(
+ f"{category}/{metric}", value, exclude=("tensorboard",)
+ )
+ except Exception:
+ pass
+ try:
+ count = aggregate_tensorboard_counts.get(category, {}).get(
+ metric
+ )
+ if isinstance(value, (int, float)) and count and count > 0:
+ mean = float(value) / float(count)
+ self.logger.record(f"{category}/{metric}_mean", mean)
+ except Exception:
+ try:
+ count = aggregate_tensorboard_counts.get(category, {}).get(
+ metric
+ )
+ if isinstance(value, (int, float)) and count and count > 0:
+ mean = float(value) / float(count)
+ self.logger.record(
+ f"{category}/{metric}_mean",
+ mean,
+ exclude=("tensorboard",),
+ )
+ except Exception:
+ pass
return True
for i, fig in enumerate(figures):
figure = Figure(fig, close=True)
self.logger.record(
- f"best/train_env_{i}", figure, exclude=("stdout", "log", "json", "csv")
+ f"best/train_env{i}", figure, exclude=("stdout", "log", "json", "csv")
)
return True