)
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.logger import Figure, HParam
-from stable_baselines3.common.utils import set_random_seed
+from stable_baselines3.common.utils import ConstantSchedule, set_random_seed
from stable_baselines3.common.vec_env import (
DummyVecEnv,
SubprocVecEnv,
lr = model_params.get("learning_rate", 0.0003)
if isinstance(lr, (int, float)):
lr = float(lr)
- model_params["learning_rate"] = linear_schedule(lr)
+ model_params["learning_rate"] = get_schedule("linear", lr)
logger.info(
"Learning rate linear schedule enabled, initial value: %s", lr
)
cr = model_params.get("clip_range", 0.2)
if isinstance(cr, (int, float)):
cr = float(cr)
- model_params["clip_range"] = linear_schedule(cr)
+ model_params["clip_range"] = get_schedule("linear", cr)
logger.info("Clip range linear schedule enabled, initial value: %s", cr)
if "DQN" in self.model_type:
self._update_total_profit()
self._last_closed_position = self._position
self._position = Positions.Neutral
- self._last_closed_trade_tick = self._current_tick
self._last_trade_tick = None
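+        # Record the tick at which the trade was closed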
+ self._last_closed_trade_tick = self._current_tick
def execute_trade(self, action: int) -> Optional[str]:
"""
pass
def _on_training_start(self) -> None:
- lr_schedule = "unknown"
- lr_iv = np.nan
- lr_fv = np.nan
lr = getattr(self.model, "learning_rate", None)
- if callable(lr):
- lr_schedule = "linear"
- try:
- lr_iv = lr(1.0)
- except Exception:
- lr_iv = np.nan
- try:
- lr_fv = lr(0.0)
- except Exception:
- lr_fv = np.nan
- elif isinstance(lr, (int, float)):
- lr_schedule = "constant"
- lr_iv = float(lr)
- lr_fv = float(lr)
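+        # lr_iv / lr_fv: schedule value at progress_remaining 1.0 (initial)
+        # and 0.0 (final), as reported by get_schedule_type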
+ lr_schedule, lr_iv, lr_fv = get_schedule_type(lr)
n_stack = 1
env = getattr(self, "training_env", None)
while env is not None:
"batch_size": int(self.model.batch_size),
}
if "PPO" in self.model.__class__.__name__:
- cr_schedule = "unknown"
- cr_iv = np.nan
- cr_fv = np.nan
cr = getattr(self.model, "clip_range", None)
- if callable(cr):
- cr_schedule = "linear"
- try:
- cr_iv = cr(1.0)
- except Exception:
- cr_iv = np.nan
- try:
- cr_fv = cr(0.0)
- except Exception:
- cr_fv = np.nan
- elif isinstance(cr, (int, float)):
- cr_schedule = "constant"
- cr_iv = float(cr)
- cr_fv = float(cr)
+ cr_schedule, cr_iv, cr_fv = get_schedule_type(cr)
hparam_dict.update(
{
"cr_schedule": cr_schedule,
except Exception:
progress_remaining = 1.0
+        def _eval_schedule(schedule: Any) -> float | None:
+            schedule_type, _, _ = get_schedule_type(schedule)
+            try:
+                if schedule_type == "linear":
+                    return float(schedule(progress_remaining))
+                if schedule_type == "constant":
+                    if callable(schedule):
+                        return float(schedule(0.0))
+                    if isinstance(schedule, (int, float)):
+                        return float(schedule)
+                return None
+            except Exception:
+                return None
+
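+        # Illustration: with learning_rate = SimpleLinearSchedule(3e-4) and
+        # progress_remaining = 0.25, _eval_schedule returns 7.5e-05; schedules
+        # that get_schedule_type cannot classify come back as None and are
+        # skipped below.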
try:
lr = getattr(self.model, "learning_rate", None)
- if callable(lr):
- lr = lr(progress_remaining)
+ lr = _eval_schedule(lr)
if _is_finite_number(lr):
self._safe_logger_record(
"train/learning_rate", float(lr), exclude=logger_exclude
if "PPO" in self.model.__class__.__name__:
try:
cr = getattr(self.model, "clip_range", None)
- if callable(cr):
- cr = cr(progress_remaining)
+ cr = _eval_schedule(cr)
if _is_finite_number(cr):
self._safe_logger_record(
"train/clip_range", float(cr), exclude=logger_exclude
return True
+class SimpleLinearSchedule:
+ """
+ Linear schedule (from initial value to zero),
+ simpler than sb3 LinearSchedule.
+
+ :param initial_value: (float or str) The initial value for the schedule
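+
+    Example (values follow from ``progress_remaining * initial_value``)::
+
+        schedule = SimpleLinearSchedule(0.001)
+        schedule(1.0)  # 0.001 at the start of training
+        schedule(0.5)  # 0.0005 halfway through
+        schedule(0.0)  # 0.0 at the end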
+ """
+
+ def __init__(self, initial_value: Union[float, str]) -> None:
+ # Force conversion to float
+ self.initial_value = float(initial_value)
+
+ def __call__(self, progress_remaining: float) -> float:
+ return progress_remaining * self.initial_value
+
+ def __repr__(self) -> str:
+ return f"SimpleLinearSchedule(initial_value={self.initial_value})"
+
+
def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively merge two dicts without mutating inputs"""
dst_copy = copy.deepcopy(dst)
return dst_copy
-def linear_schedule(initial_value: float) -> Callable[[float], float]:
- def func(progress_remaining: float) -> float:
- return progress_remaining * initial_value
-
- return func
-
-
def _compute_gradient_steps(tf: int, ss: int) -> int:
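+    # e.g. tf=1024, ss=8 -> min(1024, 128) = 128 gradient steps; a step size
+    # larger than tf still yields at least one gradient step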
if tf > 0 and ss > 0:
return min(tf, max(tf // ss, 1))
return round(days, 1)
+def get_schedule_type(
+    schedule: Any,
+) -> Tuple[Literal["constant", "linear", "unknown"], float, float]:
+    if isinstance(schedule, (int, float)):
+        try:
+            schedule = float(schedule)
+            return "constant", schedule, schedule
+        except Exception:
+            return "constant", np.nan, np.nan
+    elif isinstance(schedule, ConstantSchedule):
+        try:
+            return "constant", schedule(1.0), schedule(0.0)
+        except Exception:
+            return "constant", np.nan, np.nan
+    elif isinstance(schedule, SimpleLinearSchedule):
+        try:
+            return "linear", schedule(1.0), schedule(0.0)
+        except Exception:
+            return "linear", np.nan, np.nan
+
+    return "unknown", np.nan, np.nan
+
+
+def get_schedule(
+    schedule_type: Literal["linear", "constant"],
+    initial_value: float,
+) -> Callable[[float], float]:
+    if schedule_type == "linear":
+        return SimpleLinearSchedule(initial_value)
+    elif schedule_type == "constant":
+        return ConstantSchedule(initial_value)
+    raise ValueError(f"Invalid schedule type: {schedule_type!r}")
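+# e.g. get_schedule("linear", 3e-4) -> SimpleLinearSchedule(3e-4), while
+# get_schedule("constant", 3e-4) wraps the value in sb3's ConstantSchedule.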
+
+
def get_net_arch(
model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
) -> Union[list[int], Dict[str, list[int]]]:
lr = optuna_params.get("learning_rate")
if lr is None:
raise ValueError(f"missing 'learning_rate' in optuna params for {model_type}")
- lr: float | Callable[[float], float] = float(lr)
- if optuna_params.get("lr_schedule") == "linear":
- lr = linear_schedule(lr)
+ lr = get_schedule(optuna_params.get("lr_schedule", "constant"), float(lr))
if "PPO" in model_type:
cr = optuna_params.get("clip_range")
if cr is None:
raise ValueError(f"missing 'clip_range' in optuna params for {model_type}")
- cr: float | Callable[[float], float] = float(cr)
- if optuna_params.get("cr_schedule") == "linear":
- cr = linear_schedule(cr)
+ cr = get_schedule(optuna_params.get("cr_schedule", "constant"), float(cr))
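+    # Both lr_schedule and cr_schedule default to "constant" when absent
+    # from the optuna params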
model_params.update(
{