+from collections import deque
import copy
import gc
import json
"max_trade_duration_candles": 96, // Timeout exit value used with force_actions
"force_actions": false, // Utilize minimal_roi, stoploss, and max_trade_duration_candles as TP/SL/Timeout in the environment
"n_envs": 1, // Number of DummyVecEnv environments
- "frame_staking": 0, // Number of VecFrameStack stacks (set > 1 to use)
+ "frame_stacking": 0, // Number of VecFrameStack stacks (set > 1 to use)
"lr_schedule": false, // Enable learning rate linear schedule
"cr_schedule": false, // Enable clip range linear schedule
"max_no_improvement_evals": 0, // Maximum consecutive evaluations without a new best model
raise ValueError(
    "FreqAI model requires the StaticPairList method in the pairlists configuration "
    "and pair_whitelist defined in the exchange section of the configuration"
)
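+ # Per-pair rolling buffer of recent observations, used to emulate frame stacking at inference time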
+ self.observations_buffer: Dict[str, deque] = {}
self.is_maskable: bool = (
self.model_type == "MaskablePPO"
) # Enable action masking
self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
self.n_envs: int = self.rl_config.get("n_envs", 1)
- self.frame_staking: int = self.rl_config.get("frame_staking", 0)
- self.frame_staking += 1 if self.frame_staking == 1 else 0
+ self.frame_stacking: int = self.rl_config.get("frame_stacking", 0)
self.max_no_improvement_evals: int = self.rl_config.get(
"max_no_improvement_evals", 0
)
If the user has activated any custom functions that may conflict, this
function will set them to false and warn the user
"""
- if self.continual_learning and self.frame_staking:
+ if self.continual_learning and self.frame_stacking:
logger.warning(
- "User tried to use continual_learning with frame_staking. \
+ "User tried to use continual_learning with frame_stacking. \
Deactivating continual_learning"
)
self.continual_learning = False
for i in range(self.n_envs)
]
)
- if self.frame_staking:
- logger.info("Frame staking: %s", self.frame_staking)
- train_env = VecFrameStack(train_env, n_stack=self.frame_staking)
- eval_env = VecFrameStack(eval_env, n_stack=self.frame_staking)
+ if self.frame_stacking:
+ logger.info("Frame stacking: %s", self.frame_stacking)
+ train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
+ eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)
self.train_env = VecMonitor(train_env)
+ if self.frame_stacking and not self.train_env.observation_space.shape:
+ raise ValueError("Frame stacking requires predefined observation shape")
self.eval_env = VecMonitor(eval_env)
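A toy illustration of the shape VecFrameStack produces (assuming channels-last 1D observations, values hypothetical):

import numpy as np

n_features, n_stack = 5, 4
frames = [np.zeros((1, n_features), dtype=np.float32) for _ in range(n_stack)]
stacked = np.concatenate(frames, axis=1)  # most recent frames side by side
assert stacked.shape == (1, n_features * n_stack)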
def get_model_params(self) -> Dict:
def get_callbacks(
self, eval_freq: int, data_path: str, trial: Trial = None
- ) -> list:
+ ) -> list[BaseCallback]:
"""
Get the model specific callbacks
"""
rollout_plot_callback = None
verbose = self.model_training_parameters.get("verbose", 0)
- if self.n_envs > 1:
- eval_freq //= self.n_envs
+ eval_freq //= self.n_envs
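+ # each env.step() call on a vectorized env advances n_envs timesteps, so the division above keeps the evaluation cadence measured in total timesteps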
if self.plot_new_best:
rollout_plot_callback = RolloutPlotCallback(verbose=verbose)
train_df = data_dictionary["train_features"]
train_timesteps = len(train_df)
test_timesteps = len(data_dictionary["test_features"])
- train_cycles = int(self.rl_config.get("train_cycles", 250))
- total_timesteps = train_timesteps * train_cycles
+ train_cycles = max(1, int(self.rl_config.get("train_cycles", 25)))
+ total_timesteps = train_timesteps * train_cycles * self.n_envs
train_days = steps_to_days(train_timesteps, self.config["timeframe"])
total_days = steps_to_days(total_timesteps, self.config["timeframe"])
logger.info("Action masking: %s", self.is_maskable)
logger.info(
- "Train: %s steps (%s days) * %s cycles = Total %s (%s days)",
+ "Train: %s steps (%s days) * %s cycles * %s environments = Total %s (%s days)",
train_timesteps,
train_days,
train_cycles,
+ self.n_envs,
total_timesteps,
total_days,
)
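+ # e.g. 10_000 train steps * 25 cycles * 4 envs = 1_000_000 total timesteps (illustrative values)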
finally:
if self.progressbar_callback:
self.progressbar_callback.on_training_end()
+ self.close_envs()
+ model.env.close()
time_spent = time.time() - start
self.dd.update_metric_tracker("fit_time", time_spent, dk.pair)
model_path = Path(dk.data_path / f"{model_filename}_model.zip")
if model_path.is_file():
logger.info(f"Callback found a best model: {model_path}.")
- best_model = self.MODELCLASS.load(dk.data_path / f"{model_filename}_model")
- return best_model
+ try:
+ best_model = self.MODELCLASS.load(
+ dk.data_path / f"{model_filename}_model"
+ )
+ return best_model
+ except Exception as e:
+ logger.error(f"Error loading best model: {e}", exc_info=True)
logger.info("Couldn't find best model, using final model instead.")
:param dk: FreqaiDataKitchen = data kitchen for the current pair
:param model: Any = the trained model used for inference on the features.
"""
+ if not self.observations_buffer.get(dk.pair):
+ buffer_size = max(1, self.frame_stacking)
+ # a 2D slice (iloc[[0]]) keeps the seed frame shaped like the window slices appended in _predict
+ initial_observation = dataframe.iloc[[0]].to_numpy(dtype=np.float32)
+ self.observations_buffer[dk.pair] = deque(
+ [initial_observation] * buffer_size, maxlen=buffer_size
+ )
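The deque's maxlen is what gives the buffer its sliding-window behavior (toy values):

from collections import deque

buf = deque([0, 0, 0], maxlen=3)
buf.append(1)  # the oldest frame is dropped automatically
assert list(buf) == [0, 0, 1]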
def _is_valid(action: int, position: float) -> bool:
"""
return [_is_valid(action.value, position) for action in Actions]
def _predict(window):
- observations: DataFrame = dataframe.iloc[window.index]
+ observation: DataFrame = dataframe.iloc[window.index]
action_masks_param: dict = {}
if self.live and self.rl_config.get("add_state_info", False):
position, pnl, trade_duration = self.get_state_info(dk.pair)
# STATE_INFO
- observations["pnl"] = pnl
- observations["position"] = position
- observations["trade_duration"] = trade_duration
+ observation["pnl"] = pnl
+ observation["position"] = position
+ observation["trade_duration"] = trade_duration
if self.is_maskable:
action_masks_param = {"action_masks": _action_masks(position)}
- observations = observations.to_numpy(dtype=np.float32)
+ observation = observation.to_numpy(dtype=np.float32)
- if self.frame_staking:
- observations = np.repeat(
- observations, axis=1, repeats=self.frame_staking
- )
+ self.observations_buffer[dk.pair].append(observation)
+
+ stacked_observations = np.concatenate(
+ self.observations_buffer[dk.pair], axis=1
+ )
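+ # axis=1 mirrors VecFrameStack, which stacks 1D observations along the feature axis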
action, _ = model.predict(
- observations, deterministic=True, **action_masks_param
+ stacked_observations, deterministic=True, **action_masks_param
)
return action
output = output.rolling(window=self.CONV_WIDTH).apply(_predict)
return output
- def get_storage(self, pair: str | None = None) -> BaseStorage:
+ def get_storage(self, pair: str | None = None) -> BaseStorage | None:
"""
Get the storage for Optuna
"""
else:
study_name = identifier
storage = self.get_storage()
+ eval_freq = len(train_df) // self.n_envs
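+ # total_timesteps // eval_freq is the number of evaluation reports per trial, used as the Hyperband max_resource below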
study: Study = create_study(
study_name=study_name,
sampler=TPESampler(
group=True,
),
pruner=HyperbandPruner(
- min_resource=1, max_resource=self.optuna_n_trials, reduction_factor=3
+ min_resource=3,
+ max_resource=total_timesteps // eval_freq,
+ reduction_factor=3,
),
direction=StudyDirection.MAXIMIZE,
storage=storage,
"""
if "PPO" in self.model_type:
params = sample_params_ppo(trial)
+ if params.get("n_steps", 0) > total_timesteps:
+ raise TrialPruned("n_steps exceeds total_timesteps")
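+ # a trial whose n_steps exceeds the timestep budget could never complete a single PPO rollout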
elif "QRDQN" in self.model_type:
params = sample_params_qrdqn(trial)
elif "DQN" in self.model_type:
tensorboard_log=tensorboard_log_path,
**params,
)
callbacks = self.get_callbacks(len(train_df), str(dk.data_path), trial)
try:
model.learn(total_timesteps=total_timesteps, callback=callbacks)
except AssertionError:
and falling prices in Short positions.
The actions Sell/Buy or Hold during a Long position trigger the sell/buy-fee.
"""
+ if self._current_tick <= 0:
+ return 0.0
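+ # the guard above covers tick 0, which has no previous candle to diff against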
if self._position == Positions.Long:
current_price = self.prices.iloc[self._current_tick].open
previous_price = self.prices.iloc[self._current_tick - 1].open
_rollout_history = _history_df.merge(
_trade_history_df, on="tick", how="left"
- )
+ ).ffill()
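+ # forward-fill propagates the last known trade state across ticks without trade events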
_price_history = (
self.prices.iloc[_rollout_history.tick].copy().reset_index()
)
"medium": {"pi": [256, 256], "vf": [256, 256]},
"large": {"pi": [512, 512], "vf": [512, 512]},
"extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
- }[net_arch_type]
+ }.get(net_arch_type, {"pi": [128, 128], "vf": [128, 128]})
return {
"small": [128, 128],
"medium": [256, 256],
"large": [512, 512],
"extra_large": [1024, 1024],
- }[net_arch_type]
+ }.get(net_arch_type, [128, 128])
def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]:
    return {
        "relu": th.nn.ReLU,
"elu": th.nn.ELU,
"leaky_relu": th.nn.LeakyReLU,
- }[activation_fn_name]
+ }.get(activation_fn_name, th.nn.ReLU)
def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]:
return {
"adam": th.optim.Adam,
"rmsprop": th.optim.RMSprop,
- }[optimizer_class_name]
+ }.get(optimizer_class_name, th.optim.Adam)
def sample_params_ppo(trial: Trial) -> Dict[str, Any]: