self._model_params_cache: Optional[Dict[str, Any]] = None
self.unset_unsupported()
- @staticmethod
- def _is_short_allowed(trading_mode: str) -> bool:
- if trading_mode in {"margin", "futures"}:
- return True
- elif trading_mode == "spot":
- return False
- else:
- raise ValueError(f"Invalid trading_mode: {trading_mode}")
-
@staticmethod
def _normalize_position(position: Any) -> Positions:
if isinstance(position, Positions):
@staticmethod
def get_action_masks(
- trading_mode: str,
+ can_short: bool,
position: Positions,
force_action: Optional[ForceActions] = None,
) -> NDArray[np.bool_]:
position = ReforceXY._normalize_position(position)
cache_key = (
- trading_mode,
+ can_short,
position.value,
force_action.value if force_action else None,
)
if cache_key in ReforceXY._action_masks_cache:
return ReforceXY._action_masks_cache[cache_key]
- is_short_allowed = ReforceXY._is_short_allowed(trading_mode)
-
action_masks = np.zeros(len(Actions), dtype=np.bool_)
if force_action is not None and position in (Positions.Long, Positions.Short):
action_masks[Actions.Neutral.value] = True
if position == Positions.Neutral:
action_masks[Actions.Long_enter.value] = True
- if is_short_allowed:
+ if can_short:
action_masks[Actions.Short_enter.value] = True
elif position == Positions.Long:
action_masks[Actions.Long_exit.value] = True
if self.action_masking and self.inference_masking:
action_masks_param["action_masks"] = ReforceXY.get_action_masks(
- self.config.get("trading_mode"), virtual_position
+ self.can_short, virtual_position
)
action, _ = model.predict(
def _is_valid(self, action: int) -> bool:
return ReforceXY.get_action_masks(
- self.config.get("trading_mode"), self._position, self._force_action
+ self.can_short, self._position, self._force_action
)[action]
def reset_env(
if not self.is_tradesignal(action):
return None
+ if (
+ action in (Actions.Short_enter.value, Actions.Short_exit.value)
+ and not self.can_short
+ ):
+ return None
+
# Enter trade based on action
if action in (Actions.Long_enter.value, Actions.Short_enter.value):
self._enter_trade(action)
def action_masks(self) -> NDArray[np.bool_]:
return ReforceXY.get_action_masks(
- self.config.get("trading_mode"), self._position, self._force_action
+ self.can_short, self._position, self._force_action
)
def get_feature_value(