action_masks_param = {"action_masks": _action_masks(position)}
np_observation = observation.to_numpy(dtype=np.float32)
+ shape = getattr(self, "shape", None)
+ if shape and np_observation.shape != shape:
+ logger.error(
+ "Frame shape mismatch: got %s expected %s",
+ np_observation.shape,
+ shape,
+ )
+ raise ValueError("Frame shape mismatch")
frame_stacking = self.frame_stacking
if frame_stacking and frame_stacking > 1:
fb_padded = [fb[0]] * pad_needed + fb
else:
fb_padded = fb
- stacked_observations = np.concatenate(fb_padded, axis=0)
- observations = stacked_observations.flatten()
- else:
- observations = np_observation.flatten()
-
- if observations.ndim == 1:
- observations = observations.reshape(1, -1)
+ stacked_observations = np.stack(fb_padded, axis=0)
+ observations = stacked_observations.reshape(1, -1)
else:
- observations = observations
+ observations = np_observation.reshape(1, -1)
action, _ = model.predict(
observations, deterministic=True, **action_masks_param