-from collections import deque
import copy
import gc
import json
raise ValueError(
"FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
)
- self.observations_buffer: Dict[str, deque] = {}
self.is_maskable: bool = (
self.model_type == "MaskablePPO"
) # Enable action masking
rollout_plot_callback = None
verbose = self.model_training_parameters.get("verbose", 0)
- eval_freq //= self.n_envs
-
if self.plot_new_best:
rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
model = self.dd.model_dictionary[dk.pair]
model.set_env(self.train_env)
- callbacks = self.get_callbacks(train_timesteps, str(dk.data_path))
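+ # NOTE: presumably forwarded as the callbacks' eval_freq; SB3 counts eval_freq
+ # in per-env calls, so it is divided by the number of vectorized environments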
+ callbacks = self.get_callbacks(
+ train_timesteps // self.n_envs, str(dk.data_path)
+ )
try:
model.learn(total_timesteps=total_timesteps, callback=callbacks)
finally:
:param dk: FreqaiDataKitchen = data kitchen for the current pair
:param model: Any = the trained model used for inference on the features.
"""
- if not self.observations_buffer.get(dk.pair):
- buffer_size = max(1, self.frame_stacking)
- initial_observation = dataframe.iloc[0].to_numpy(dtype=np.float32)
- self.observations_buffer[dk.pair] = deque(
- [initial_observation] * buffer_size, maxlen=buffer_size
- )
def _is_valid(action: int, position: float) -> bool:
"""
observation = observation.to_numpy(dtype=np.float32)
- self.observations_buffer[dk.pair].append(observation)
-
- stacked_observations = np.concatenate(
- self.observations_buffer[dk.pair], axis=1
- )
+ if self.frame_stacking:
+ # FIXME: proper observation stacking needs more work
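+ # for now, repeat the latest observation along the feature axis so its width
+ # matches the frame-stacked (n_features * frame_stacking) input the model expects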
+ observations = np.repeat(
+ observation, axis=1, repeats=self.frame_stacking
+ )
+ else:
+ observations = observation
action, _ = model.predict(
- stacked_observations, deterministic=True, **action_masks_param
+ observations, deterministic=True, **action_masks_param
)
return action
else:
study_name = identifier
storage = self.get_storage()
- eval_freq = len(train_df) // self.n_envs
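+ # clamp to at least 1 so eval_freq is never 0 when train_df has fewer rows than n_envs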
+ eval_freq = max(1, len(train_df) // self.n_envs)
study: Study = create_study(
study_name=study_name,
sampler=TPESampler(
),
pruner=HyperbandPruner(
min_resource=3,
- max_resource=total_timesteps // eval_freq,
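+ # ceil(total_timesteps / eval_freq): a trailing partial evaluation window still counts as a resource step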
+ max_resource=(total_timesteps + eval_freq - 1) // eval_freq,
reduction_factor=3,
),
direction=StudyDirection.MAXIMIZE,
**params,
)
- callbacks = self.get_callbacks(len(train_df), str(dk.data_path), trial)
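+ # per-env eval frequency here as well, matching the training path above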
+ callbacks = self.get_callbacks(
+ len(train_df) // self.n_envs, str(dk.data_path), trial
+ )
try:
model.learn(total_timesteps=total_timesteps, callback=callbacks)
except AssertionError:
"""
Reset is called at the beginning of every episode
"""
- _, history = super().reset(seed, **kwargs)
+ observation, history = super().reset(seed, **kwargs)
self._force_action: Optional[ForceActions] = None
self._last_closed_position: Optional[Positions] = None
self._last_closed_trade_tick: int = 0
- return self._get_observation(), history
+ return observation, history
def _get_reward_factor_at_trade_exit(
self,