"""
self.close_envs()
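+ # Sanity-check user config before building the vectorized environments:
+ # fall back to safe defaults instead of failing later during env creation.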
+ if not isinstance(self.n_envs, int) or self.n_envs < 1:
+ logger.warning("Invalid n_envs=%s. Forcing n_envs=1", self.n_envs)
+ self.n_envs = 1
+ if not isinstance(self.frame_stacking, int) or self.frame_stacking < 0:
+ logger.warning(
+ "Invalid frame_stacking=%s. Forcing frame_stacking=0",
+ self.frame_stacking,
+ )
+ self.frame_stacking = 0
+
train_df = data_dictionary.get("train_features")
test_df = data_dictionary.get("test_features")
env_dict = self.pack_env_dict(dk.pair)
seed = self.get_model_params().get("seed", 42)
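+ # Seed the global RNGs once; train envs are seeded per-rank from `seed`,
+ # while eval envs start at `seed + 10_000` so the two rank ranges cannot collide.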
+ set_random_seed(seed)
+ logger.info("Seeding RNGs with seed=%s (train), %s (eval)", seed, seed + 10_000)
if self.check_envs:
logger.info("Checking environments...")
self.MyRLEnv,
f"eval_env{i}",
i,
- seed,
+ seed + 10_000,
test_df,
prices_test,
env_info=env_dict,
)
if self.frame_stacking == 1:
logger.warning(
- "frame_stacking=1 is equivalent to no stacking; use >=2 or 0"
+ "frame_stacking=%s is equivalent to no stacking; use >=2 or 0",
+ self.frame_stacking,
)
if self.frame_stacking:
+ logger.info(
+ "Observation space shape pre-stacking: %s",
+ train_env.observation_space.shape,
+ )
logger.info("Frame stacking: %s", self.frame_stacking)
train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)
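+ # With VecFrameStack's default behaviour for non-image spaces, frames are
+ # stacked on the last axis: (window_size, total_features) becomes
+ # (window_size, total_features * frame_stacking).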
+ logger.info(
+ "Observation space shape post-stacking: %s",
+ train_env.observation_space.shape,
+ )
self.train_env = VecMonitor(train_env)
- if self.frame_stacking and not self.train_env.observation_space.shape:
- raise ValueError("Frame stacking requires predefined observation shape")
self.eval_env = VecMonitor(eval_env)
def get_model_params(self) -> Dict[str, Any]:
:return:
model Any = trained model to be used for inference in dry/live/backtesting
"""
+ if self.train_env is None or self.eval_env is None:
+ raise RuntimeError("Environments not set. Cannot run model training")
train_df = data_dictionary.get("train_features")
train_timesteps = len(train_df)
test_df = data_dictionary.get("test_features")
:param model: Any = the trained model used for inference on the features.
"""
- def _is_valid(action: int, position: float) -> bool:
+ def _normalize_position(position: Any) -> Positions:
+ if isinstance(position, Positions):
+ return position
+ try:
+ f = float(position)
+ if f == float(Positions.Long.value):
+ return Positions.Long
+ if f == float(Positions.Short.value):
+ return Positions.Short
+ return Positions.Neutral
+ except Exception:
+ return Positions.Neutral
+
+ def _is_valid(action: int, position: Any) -> bool:
"""
Determine if the action is valid for the step
"""
+ position = _normalize_position(position)
# Agent should only try to exit if it is in position
if action in (Actions.Short_exit.value, Actions.Long_exit.value):
if position not in (Positions.Short, Positions.Long):
return False
- def _action_masks(position: float) -> list[bool]:
+ def _action_masks(position: Any) -> list[bool]:
return [_is_valid(action.value, position) for action in Actions]
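+ # Example: while flat (Positions.Neutral) the exit actions mask to False, so
+ # only Neutral and the entry actions remain selectable by the policy.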
def _predict(window) -> int:
- observation: DataFrame = dataframe.iloc[window.index]
+ observation: DataFrame = dataframe.loc[window.index]
action_masks_param: Dict[str, Any] = {}
if self.rl_config.get("add_state_info", False):
fb_padded = [fb[0]] * pad_needed + fb
else:
fb_padded = fb
- stacked_observations = np.stack(fb_padded, axis=0)
+ stacked_observations = np.concatenate(fb_padded, axis=1)
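+ # Concatenating on axis=1 mirrors VecFrameStack's last-axis stacking, so the
+ # flattened inference observation has the same layout the model saw in training.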
observations = stacked_observations.reshape(1, -1)
else:
observations = np_observation.reshape(1, -1)
"""
Defines a single trial for hyperparameter optimization using Optuna
"""
+ if self.train_env is None or self.eval_env is None:
+ raise RuntimeError("Environments not set. Cannot run HPO model training")
if "PPO" in self.model_type:
params = sample_params_ppo(trial, self.n_envs)
if params.get("n_steps", 0) > total_timesteps:
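# A single PPO rollout collects n_steps transitions per env, so a sampled
# n_steps above the trial's timestep budget cannot finish even one update.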
finally:
if self.progressbar_callback:
self.progressbar_callback.on_training_end()
- self.close_envs()
- if hasattr(model, "env") and model.env is not None:
- model.env.close()
if nan_encountered:
raise TrialPruned("NaN encountered during training")
Closes the training and evaluation environments if they are open
"""
if self.train_env:
- self.train_env.close()
+ try:
+ self.train_env.close()
+ finally:
+ self.train_env = None
if self.eval_env:
- self.eval_env.close()
+ try:
+ self.eval_env.close()
+ finally:
+ self.eval_env = None
class MyRLEnv(Base5ActionRLEnv):
"""
Resets the environment when the agent fails
"""
super().reset_env(df, prices, window_size, reward_kwargs, starting_point)
+ base_features = self.signal_features.shape[1]
if self.add_state_info:
# STATE_INFO
self.state_features = ["pnl", "position", "trade_duration"]
- self.total_features = self.signal_features.shape[1] + len(
- self.state_features
- )
- self.shape = (window_size, self.total_features)
+ self.total_features = base_features + len(self.state_features)
+ else:
+ self.state_features = []
+ self.total_features = base_features
+ self.shape = (window_size, self.total_features)
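+ # Either branch leaves total_features defined, so the Box space below is
+ # always (window_size, total_features).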
self.observation_space = Box(
low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
)
"trade_count": len(self.trade_history),
}
self.execute_trade(action)
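+ # Refresh the info fields after the trade executes so monitors/callbacks see
+ # the post-trade position, pnl, duration and trade count for this step.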
+ info["position"] = self._position.value
+ info["force_action"] = (
+ self._force_action.name if self._force_action else None
+ )
+ info["pnl"] = self.get_unrealized_profit()
+ info["trade_duration"] = self.get_trade_duration()
+ info["trade_count"] = len(self.trade_history)
self._position_history.append(self._position)
self._update_history(info)
return (
_lr = (
float(_lr) if isinstance(_lr, (int, float, np.floating)) else "lr_schedule"
)
+ n_stack = 1
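+ # get_attr returns one value per underlying env when the attribute exists;
+ # anything else falls through to the default n_stack of 1 (no stacking).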
+ if self.training_env is not None and hasattr(self.training_env, "get_attr"):
+ try:
+ stacks = self.training_env.get_attr("n_stack")
+ if isinstance(stacks, (list, tuple)) and stacks and stacks[0]:
+ n_stack = int(stacks[0])
+ except Exception:
+ pass
hparam_dict: Dict[str, Any] = {
"algorithm": self.model.__class__.__name__,
+ "n_envs": int(self.model.n_envs),
+ "n_stack": n_stack,
"learning_rate": _lr,
"gamma": float(self.model.gamma),
"batch_size": int(self.model.batch_size),
df=train_df, prices=price, id=env_id, seed=seed + rank, **env_info
)
- set_random_seed(seed)
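+ # Global RNG seeding now happens once, up front, when the environments are
+ # created; each env still receives its own per-rank seed above.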
return _init