self._model_params_cache: Optional[Dict[str, Any]] = None
self.unset_unsupported()
+ @staticmethod
+ def is_short_allowed(trading_mode: str) -> bool:
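+ """Return True if the given trading mode allows opening short positions."""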
+ if trading_mode in {"margin", "futures"}:
+ return True
+ elif trading_mode == "spot":
+ return False
+ else:
+ raise ValueError(f"Invalid trading_mode: {trading_mode}")
+
@staticmethod
def _normalize_position(position: Any) -> Positions:
if isinstance(position, Positions):
@staticmethod
def get_action_masks(
- position: Positions, force_action: Optional[ForceActions] = None
+ trading_mode: str,
+ position: Positions,
+ force_action: Optional[ForceActions] = None,
) -> NDArray[np.bool_]:
+ is_short_allowed = ReforceXY.is_short_allowed(trading_mode)
position = ReforceXY._normalize_position(position)
action_masks = np.zeros(len(Actions), dtype=np.bool_)
action_masks[Actions.Neutral.value] = True
if position == Positions.Neutral:
action_masks[Actions.Long_enter.value] = True
- action_masks[Actions.Short_enter.value] = True
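+ # Only allow short entries when the trading mode supports shorting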
+ if is_short_allowed:
+ action_masks[Actions.Short_enter.value] = True
elif position == Positions.Long:
action_masks[Actions.Long_exit.value] = True
elif position == Positions.Short:
:param model: Any = the trained model used to run inference on the features.
"""
- simulated_position: Positions = Positions.Neutral
+ virtual_position: Positions = Positions.Neutral
- def _update_simulated_position(action: int, position: Positions) -> Positions:
+ def _update_virtual_position(action: int, position: Positions) -> Positions:
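+ """Return the virtual position resulting from applying the predicted action."""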
if action == Actions.Long_enter.value and position == Positions.Neutral:
return Positions.Long
if action == Actions.Short_enter.value and position == Positions.Neutral:
if self.action_masking and self.inference_masking:
action_masks_param["action_masks"] = ReforceXY.get_action_masks(
- simulated_position
+ self.config.get("trading_mode"), virtual_position
)
action, _ = model.predict(
window = dataframe.iloc[window_end - self.CONV_WIDTH : window_end]
action = _predict(window)
predicted_actions.append(action)
- simulated_position = _update_simulated_position(action, simulated_position)
+ virtual_position = _update_virtual_position(action, virtual_position)
pad = [np.nan] * (self.CONV_WIDTH - 1)
actions_list = pad + predicted_actions
else self.get_storage()
)
if "PPO" in self.model_type:
- resource_eval_freq = max(PPO_N_STEPS)
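+ # Evaluate at the smallest PPO n_steps value during hyperopt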
+ resource_eval_freq = min(PPO_N_STEPS)
else:
resource_eval_freq = self.get_eval_freq(total_timesteps, hyperopt=True)
reduction_factor = 3
train_env = DummyVecEnv(train_fns)
eval_env = DummyVecEnv(eval_fns)
- train_env = VecMonitor(train_env)
- eval_env = VecMonitor(eval_env)
-
if self.frame_stacking:
train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)
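+ # Wrap with VecMonitor last so episode statistics are recorded on the outermost env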
+ train_env = VecMonitor(train_env)
+ eval_env = VecMonitor(eval_env)
+
return train_env, eval_env
def objective(
)
def _is_valid(self, action: int) -> bool:
- return ReforceXY.get_action_masks(self._position, self._force_action)[action]
+ return ReforceXY.get_action_masks(
+ self.config.get("trading_mode"), self._position, self._force_action
+ )[action]
def reset_env(
self,
)
def action_masks(self) -> NDArray[np.bool_]:
- return ReforceXY.get_action_masks(self._position, self._force_action)
+ return ReforceXY.get_action_masks(
+ self.config.get("trading_mode"), self._position, self._force_action
+ )
def get_feature_value(
self,
right_index=True,
how="left",
)
- except Exception:
- try:
- _price_history = (
- self.prices.iloc[_rollout_history.tick]
- .copy()
- .reset_index(drop=True)
- )
- history = merge(
- _rollout_history,
- _price_history,
- left_index=True,
- right_index=True,
- )
- except Exception as e:
- logger.error(
- f"Failed to merge history with prices: {repr(e)}",
- exc_info=True,
- )
- return DataFrame()
+ except Exception as e:
+ logger.error(
+ f"Failed to merge history with prices: {repr(e)}",
+ exc_info=True,
+ )
+ return DataFrame()
return history
def get_env_plot(self) -> plt.Figure: