from collections.abc import Mapping
from enum import IntEnum
from pathlib import Path
-from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
import matplotlib
import matplotlib.pyplot as plt
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.pairs: list[str] = self.config.get("exchange", {}).get("pair_whitelist")
+ self.pairs: List[str] = self.config.get("exchange", {}).get("pair_whitelist")
if not self.pairs:
raise ValueError(
"FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
if not model_params.get("policy_kwargs"):
model_params["policy_kwargs"] = {}
- default_net_arch: list[int] = [128, 128]
+ default_net_arch: List[int] = [128, 128]
net_arch: Union[
- list[int],
- Dict[str, list[int]],
+ List[int],
+ Dict[str, List[int]],
Literal["small", "medium", "large", "extra_large"],
] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch)
"vf": net_arch,
}
elif isinstance(net_arch, dict):
- pi: Optional[list[int]] = net_arch.get("pi")
- vf: Optional[list[int]] = net_arch.get("vf")
+ pi: Optional[List[int]] = net_arch.get("pi")
+ vf: Optional[List[int]] = net_arch.get("vf")
if not isinstance(pi, list) or not isinstance(vf, list):
model_params["policy_kwargs"]["net_arch"] = {
"pi": pi
eval_freq: int,
data_path: str,
trial: Optional[Trial] = None,
- ) -> list[BaseCallback]:
+ ) -> List[BaseCallback]:
"""
Get the model specific callbacks
"""
- callbacks: list[BaseCallback] = []
+ callbacks: List[BaseCallback] = []
no_improvement_callback = None
rollout_plot_callback = None
verbose = int(self.get_model_params().get("verbose", 0))
if "PPO" in self.model_type:
min_timesteps = 2 * model_params.get("n_steps", 0) * self.n_envs
- if total_timesteps < min_timesteps:
+ if total_timesteps <= min_timesteps:
logger.warning(
- "total_timesteps=%s is less than 2*n_steps*n_envs=%s. This may lead to suboptimal training results",
+ "total_timesteps=%s is less than or equal to 2*n_steps*n_envs=%s. This may lead to suboptimal training results for model %s",
total_timesteps,
min_timesteps,
+ self.model_type,
)
if self.activate_tensorboard:
return Positions.Neutral
return position
- frame_buffer: list[NDArray[np.float32]] = []
+ frame_buffer: List[NDArray[np.float32]] = []
def _predict(window) -> int:
observation: DataFrame = dataframe.iloc[window.index]
np_observation = observation.to_numpy(dtype=np.float32)
- fb: list[NDArray[np.float32]] = frame_buffer
+ fb: List[NDArray[np.float32]] = frame_buffer
frame_stacking = self.frame_stacking
if frame_stacking and frame_stacking > 1:
fb.append(np_observation.copy())
)
return int(action)
- predicted_actions: list[int] = []
+ predicted_actions: List[int] = []
for window_end in range(self.CONV_WIDTH, len(dataframe) + 1):
window = dataframe.iloc[window_end - self.CONV_WIDTH : window_end]
action = _predict(window)
logger.error("logger.record retry on stdout failed at %r: %r", key, e)
pass
+ @staticmethod
+ def build_train_freq(
+ train_freq: Optional[Union[int, Tuple[int, ...], List[int]]],
+ ) -> Optional[int]:
+ train_freq_val: Optional[int] = None
+ try:
+ if isinstance(train_freq, int):
+ train_freq_val = train_freq
+ elif isinstance(train_freq, (tuple, list)) and train_freq:
+ if isinstance(train_freq[0], int):
+ train_freq_val = train_freq[0]
+ elif hasattr(train_freq, "freq"):
+ freq = getattr(train_freq, "freq")
+ if isinstance(freq, int):
+ train_freq_val = freq
+ except Exception:
+ train_freq_val = None
+ return train_freq_val
+
def _on_training_start(self) -> None:
lr = getattr(self.model, "learning_rate", None)
lr_schedule, lr_iv, lr_fv = get_schedule_type(lr)
"exploration_rate": float(self.model.exploration_rate),
}
)
- train_freq = getattr(self.model, "train_freq", None)
- train_freq_val: int | None = None
- try:
- if isinstance(train_freq, int):
- train_freq_val = train_freq
- elif isinstance(train_freq, (tuple, list)) and train_freq:
- if isinstance(train_freq[0], int):
- train_freq_val = train_freq[0]
- elif hasattr(train_freq, "freq"):
- freq = getattr(train_freq, "freq")
- if isinstance(freq, int):
- train_freq_val = freq
- except Exception:
- train_freq_val = None
- if train_freq_val is not None:
- hparam_dict.update({"train_freq": train_freq_val})
+ train_freq = InfoMetricsCallback.build_train_freq(
+ getattr(self.model, "train_freq", None)
+ )
+ if train_freq is not None:
+ hparam_dict.update({"train_freq": train_freq})
if "QRDQN" in self.model.__class__.__name__:
hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)})
metric_dict: dict[str, float | int] = {
except Exception:
return False
- infos_list: list[Dict[str, Any]] | None = self.locals.get("infos")
+ infos_list: List[Dict[str, Any]] | None = self.locals.get("infos")
aggregated_info: Dict[str, Any] = {}
if isinstance(infos_list, list) and infos_list:
- numeric_acc: Dict[str, list[float]] = defaultdict(list)
+ numeric_acc: Dict[str, List[float]] = defaultdict(list)
non_numeric_counts: Dict[str, Dict[Any, int]] = defaultdict(
lambda: defaultdict(int)
)
def get_net_arch(
model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
-) -> Union[list[int], Dict[str, list[int]]]:
+) -> Union[List[int], Dict[str, List[int]]]:
"""
Get network architecture
"""