self.unset_unsupported()
@staticmethod
- def build_action_mask(
+ @lru_cache(maxsize=8)
+ def get_action_masks(
position: Positions, force_action: Optional[ForceActions] = None
- ) -> np.ndarray:
- action_mask = np.zeros(len(Actions), dtype=bool)
+ ) -> NDArray[np.bool_]:
+ """
+ Return a boolean mask of length len(Actions) marking the allowed actions.
+ When force_action is not None and position is Long or Short, only the
+ matching exit action is allowed. Otherwise Neutral is always allowed, plus
+ both enter actions when position is Neutral, or the matching exit action
+ when position is Long or Short.
+ Results are memoized by lru_cache: the same ndarray object is returned for
+ repeated arguments, so callers must not mutate it in place (copy it first).
+ """
+ action_masks = np.zeros(len(Actions), dtype=bool)
- action_mask[Actions.Neutral.value] = True
+ if force_action is not None and position in (Positions.Long, Positions.Short):
+ if position == Positions.Long:
+ action_masks[Actions.Long_exit.value] = True
+ else:
+ action_masks[Actions.Short_exit.value] = True
+ return action_masks
+ action_masks[Actions.Neutral.value] = True
if position == Positions.Neutral:
- action_mask[Actions.Long_enter.value] = True
- action_mask[Actions.Short_enter.value] = True
+ action_masks[Actions.Long_enter.value] = True
+ action_masks[Actions.Short_enter.value] = True
elif position == Positions.Long:
- action_mask[Actions.Long_exit.value] = True
+ action_masks[Actions.Long_exit.value] = True
elif position == Positions.Short:
- action_mask[Actions.Short_exit.value] = True
-
- if force_action is not None and position in (Positions.Long, Positions.Short):
- force_action_mask = np.zeros(len(Actions), dtype=bool)
- try:
- if position == Positions.Long:
- force_action_mask[Actions.Long_exit.value] = True
- elif position == Positions.Short:
- force_action_mask[Actions.Short_exit.value] = True
- except Exception:
- return action_mask
- if force_action_mask.any():
- return force_action_mask
- return action_mask
+ action_masks[Actions.Short_exit.value] = True
- if not action_mask.any():
- try:
- action_mask[Actions.Neutral.value] = True
- except Exception:
- action_mask = np.ones_like(action_mask, dtype=bool)
- return action_mask
+ return action_masks
def unset_unsupported(self) -> None:
"""
observations = np_observation.reshape(1, -1)
if self.action_masking and self.inference_masking:
- action_masks_param["action_masks"] = ReforceXY.build_action_mask(
+ action_masks_param["action_masks"] = ReforceXY.get_action_masks(
simulated_position, None
)
low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32
)
+ def _is_valid(self, action: int) -> bool:
+ # Indexing the mask yields a numpy.bool_; coerce it so the declared
+ # builtin bool return type actually holds.
+ return bool(
+ ReforceXY.get_action_masks(self._position, self._force_action)[action]
+ )
+
def reset_env(
self,
df: DataFrame,
of weights in NN)
"""
# first, penalize if the action is not valid
- if (
- not self.action_masking
- and not self._force_action
- and not self._is_valid(action)
- ):
+ if not self.action_masking and not self._is_valid(action):
self.tensorboard_log("invalid", category="actions")
return self.rl_config.get("model_reward_parameters", {}).get(
"invalid_action", -2.0
)
)
- def action_masks(self):
- try:
- return ReforceXY.build_action_mask(self._position, self._force_action)
- except Exception:
- action_mask = np.zeros(len(Actions), dtype=bool)
- try:
- action_mask[Actions.Neutral.value] = True
- except Exception:
- action_mask = np.ones(len(Actions), dtype=bool)
- return action_mask
+ def action_masks(self) -> NDArray[np.bool_]:
+ """
+ Return the action mask for the current position and force action.
+ ReforceXY.get_action_masks is wrapped in lru_cache and hands out the same
+ ndarray for repeated arguments, so a copy is returned here to keep any
+ in-place modification by the caller from corrupting the cached mask.
+ """
+ return ReforceXY.get_action_masks(self._position, self._force_action).copy()
def get_feature_value(
self,