import logging
import time
import warnings
+from collections import defaultdict
from collections.abc import Mapping
from enum import IntEnum
from functools import lru_cache
from pathlib import Path
+from statistics import stdev
from typing import Any, Callable, Dict, Optional, Tuple, Type
import matplotlib
if self.check_envs:
logger.info("Checking environments...")
- check_env(
- self.MyRLEnv(
- id="train_env_check", df=train_df, prices=prices_train, **env_dict
- )
+ _train_env_check = self.MyRLEnv(
+ id="train_env_check", df=train_df, prices=prices_train, **env_dict
)
- check_env(
- self.MyRLEnv(
- id="eval_env_check", df=test_df, prices=prices_test, **env_dict
- )
+ try:
+ check_env(_train_env_check)
+ finally:
+ _train_env_check.close()
+ _eval_env_check = self.MyRLEnv(
+ id="eval_env_check", df=test_df, prices=prices_test, **env_dict
)
+ try:
+ check_env(_eval_env_check)
+ finally:
+ _eval_env_check.close()
logger.info("Populating environments: %s", self.n_envs)
train_env = DummyVecEnv(
for i in range(self.n_envs)
]
)
+ if self.frame_stacking == 1:
+ logger.warning(
+ "frame_stacking=1 is equivalent to no stacking; use >=2 or 0"
+ )
if self.frame_stacking:
logger.info("Frame stacking: %s", self.frame_stacking)
train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
best_trial_params = self.study(train_df, total_timesteps, dk)
if best_trial_params is None:
logger.error(
- "Hyperopt failed. Using default configured model params instead."
+ "Hyperopt failed. Using default configured model params instead"
)
best_trial_params = self.get_model_params()
model_params = best_trial_params
)
else:
logger.info(
- "Continual training activated - starting training from previously trained agent."
+ "Continual training activated - starting training from previously trained agent"
)
model = self.dd.model_dictionary[dk.pair]
model.set_env(self.train_env)
model_filename = dk.model_filename if dk.model_filename else "best"
model_path = Path(dk.data_path / f"{model_filename}_model.zip")
if model_path.is_file():
- logger.info(f"Callback found a best model: {model_path}.")
+ logger.info(f"Callback found a best model: {model_path}")
try:
best_model = self.MODELCLASS.load(
dk.data_path / f"{model_filename}_model"
except Exception as e:
logger.error(f"Error loading best model: {repr(e)}", exc_info=True)
- logger.info("Couldn't find best model, using final model instead.")
+ logger.info("Couldn't find best model, using final model instead")
return model
Get environment data from the first to the last trade
"""
if not self.history or not self.trade_history:
- logger.warning("History or trade history is empty.")
+ logger.warning("history or trade_history is empty")
return DataFrame()
_history_df = DataFrame.from_dict(self.history)
"tick" not in _history_df.columns
or "tick" not in _trade_history_df.columns
):
- logger.warning(
- "'tick' column is missing from history or trade history."
- )
+ logger.warning("'tick' column is missing from history or trade history")
return DataFrame()
_rollout_history = merge(
"n_epochs": self.model.n_epochs,
"ent_coef": self.model.ent_coef,
"vf_coef": self.model.vf_coef,
+ "max_grad_norm": self.model.max_grad_norm,
}
)
+ if getattr(self.model, "target_kl", None) is not None:
+ hparam_dict["target_kl"] = self.model.target_kl
if "DQN" in self.model.__class__.__name__:
hparam_dict.update(
{
)
def _on_step(self) -> bool:
- local_info = self.locals.get("infos", [{}])[0]
+ infos_list: list[dict[str, Any]] | None = self.locals.get("infos")
+ aggregated_info: Dict[str, Any] = {}
+ if infos_list:
+ numeric_acc: dict[str, list[float]] = defaultdict(list)
+ non_numeric_acc: dict[str, set[Any]] = defaultdict(set)
+ for info_dict in infos_list:
+ if not isinstance(info_dict, dict):
+ continue
+ for k, v in info_dict.items():
+ if k in ("episode", "terminal_observation", "TimeLimit.truncated"):
+ continue
+ if isinstance(v, (int, float)) and not isinstance(v, bool):
+ numeric_acc[k].append(float(v))
+ else:
+ non_numeric_acc[k].add(v)
+ for k, values in numeric_acc.items():
+ if not values:
+ continue
+ mean_v = sum(values) / len(values)
+ aggregated_info[k] = mean_v
+ if len(values) > 1:
+ try:
+ aggregated_info[f"{k}_std"] = stdev(values)
+ except Exception:
+ pass
+ for k, values in non_numeric_acc.items():
+ if len(values) == 1:
+ aggregated_info[k] = next(iter(values))
+ else:
+ aggregated_info[k] = "mixed"
+
if self.training_env is None:
return True
- tensorboard_metrics = self.training_env.get_attr("tensorboard_metrics")[0]
- for metric in local_info:
- if metric not in ["episode", "terminal_observation", "TimeLimit.truncated"]:
- self.logger.record(f"info/{metric}", local_info.get(metric))
- for category in tensorboard_metrics:
- for metric in tensorboard_metrics.get(category, {}):
- self.logger.record(
- f"{category}/{metric}",
- tensorboard_metrics.get(category, {}).get(metric),
- )
+
+ try:
+ tensorboard_metrics_list = self.training_env.get_attr("tensorboard_metrics")
+ except Exception:
+ tensorboard_metrics_list = []
+ tensorboard_metrics = (
+ tensorboard_metrics_list[0] if tensorboard_metrics_list else {}
+ )
+
+ for metric, value in aggregated_info.items():
+ self.logger.record(f"info/{metric}", value)
+
+ for category, metrics in tensorboard_metrics.items():
+ if not isinstance(metrics, dict):
+ continue
+ for metric, value in metrics.items():
+ self.logger.record(f"{category}/{metric}", value)
return True
self.is_pruned = False
def _on_step(self) -> bool:
+ if self.is_pruned:
+ return False
+
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
super()._on_step()
self.eval_idx += 1
+ last_mean_reward = getattr(self, "last_mean_reward", np.nan)
+ if not isinstance(last_mean_reward, (int, float)) or not np.isfinite(
+ float(last_mean_reward)
+ ):
+ self.is_pruned = True
+ return False
if hasattr(self.trial, "report"):
- self.trial.report(self.last_mean_reward, self.eval_idx)
-
- # Prune trial if needed
+ try:
+ self.trial.report(last_mean_reward, self.eval_idx)
+ except Exception:
+ self.is_pruned = True
+ return False
if hasattr(self.trial, "should_prune") and self.trial.should_prune():
self.is_pruned = True
return False