return True
- def _action_masks(position: float):
+ def _action_masks(position: float) -> list[bool]:
return [_is_valid(action.value, position) for action in Actions]
- def _predict(window):
+ def _predict(window) -> int:
observation: DataFrame = dataframe.iloc[window.index]
action_masks_param: Dict[str, Any] = {}
- if self.live and self.rl_config.get("add_state_info", False):
- position, pnl, trade_duration = self.get_state_info(dk.pair)
+ if self.rl_config.get("add_state_info", False):
+ if self.live:
+ position, pnl, trade_duration = self.get_state_info(dk.pair)
+ else:
+ # TODO: state info handling for backtests
+ position = Positions.Neutral
+ pnl = 0.0
+ trade_duration = 0
+
# STATE_INFO
observation["pnl"] = pnl
observation["position"] = position
if self.is_maskable:
action_masks_param = {"action_masks": _action_masks(position)}
- observation = observation.to_numpy(dtype=np.float32)
-
- if self.frame_stacking:
- # FIXME: proper observation stacking needs more work
- observations = np.repeat(
- observation, axis=1, repeats=self.frame_stacking
- )
+ np_observation = observation.to_numpy(dtype=np.float32)
+
+ frame_stacking = self.frame_stacking
+ if frame_stacking and frame_stacking > 1:
+ if not hasattr(_predict, "_frame_buffer"):
+ _predict._frame_buffer = []
+ fb: list[np.ndarray] = getattr(_predict, "_frame_buffer")
+ fb.append(np_observation)
+ if len(fb) > frame_stacking:
+ del fb[0 : len(fb) - frame_stacking]
+ if len(fb) < frame_stacking:
+ pad_needed = frame_stacking - len(fb)
+ fb_padded = [fb[0]] * pad_needed + fb
+ else:
+ fb_padded = fb
+ stacked_observations = np.concatenate(fb_padded, axis=0)
+ observations = stacked_observations.flatten()
else:
- observations = observation
+ observations = np_observation.flatten()
action, _ = model.predict(
observations, deterministic=True, **action_masks_param
)
- return action
+ return int(action)
- output = DataFrame(np.zeros(len(dataframe)), columns=dk.label_list)
- output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
+ actions = dataframe.iloc[:, 0].rolling(window=self.CONV_WIDTH).apply(_predict)
+ output = DataFrame({label: actions for label in dk.label_list})
return output
def get_storage(self, pair: Optional[str] = None) -> BaseStorage: