from collections import defaultdict, deque
from collections.abc import Mapping
from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+ cast,
+)
import matplotlib
import matplotlib.pyplot as plt
VecMonitor,
)
+ModelType = Literal["PPO", "RecurrentPPO", "MaskablePPO", "DQN", "QRDQN"]
+ScheduleType = Literal["linear", "constant", "unknown"]
+ExitPotentialMode = Literal[
+ "canonical",
+ "non_canonical",
+ "progressive_release",
+ "spike_cancel",
+ "retain_previous",
+]
+TransformFunction = Literal["tanh", "softsign", "arctan", "sigmoid", "asinh", "clip"]
+ExitAttenuationMode = Literal["legacy", "sqrt", "linear", "power", "half_life"]
+ActivationFunction = Literal["tanh", "relu", "elu", "leaky_relu"]
+OptimizerClass = Literal["adam", "adamw", "rmsprop"]
+NetArchSize = Literal["small", "medium", "large", "extra_large"]
+StorageBackend = Literal["sqlite", "file"]
+SamplerType = Literal["tpe", "auto"]
+
matplotlib.use("Agg")
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
_LOG_2 = math.log(2.0)
DEFAULT_IDLE_DURATION_MULTIPLIER: int = 4
+
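+ # Runtime counterparts of the Literal aliases above; keep member order in
+ # sync with each alias, since the [index] accesses below depend on it.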
+ _MODEL_TYPES: tuple[ModelType, ...] = (
+ "PPO",
+ "RecurrentPPO",
+ "MaskablePPO",
+ "DQN",
+ "QRDQN",
+ )
+ _SCHEDULE_TYPES: tuple[ScheduleType, ...] = ("linear", "constant", "unknown")
+ _EXIT_POTENTIAL_MODES: tuple[ExitPotentialMode, ...] = (
+ "canonical",
+ "non_canonical",
+ "progressive_release",
+ "spike_cancel",
+ "retain_previous",
+ )
+ _TRANSFORM_FUNCTIONS: tuple[TransformFunction, ...] = (
+ "tanh",
+ "softsign",
+ "arctan",
+ "sigmoid",
+ "asinh",
+ "clip",
+ )
+ _EXIT_ATTENUATION_MODES: tuple[ExitAttenuationMode, ...] = (
+ "legacy",
+ "sqrt",
+ "linear",
+ "power",
+ "half_life",
+ )
+ _ACTIVATION_FUNCTIONS: tuple[ActivationFunction, ...] = (
+ "tanh",
+ "relu",
+ "elu",
+ "leaky_relu",
+ )
+ _OPTIMIZER_CLASSES: tuple[OptimizerClass, ...] = ("adam", "adamw", "rmsprop")
+ _NET_ARCH_SIZES: tuple[NetArchSize, ...] = (
+ "small",
+ "medium",
+ "large",
+ "extra_large",
+ )
+ _STORAGE_BACKENDS: tuple[StorageBackend, ...] = ("sqlite", "file")
+ _SAMPLER_TYPES: tuple[SamplerType, ...] = ("tpe", "auto")
+
_action_masks_cache: Dict[Tuple[bool, float], NDArray[np.bool_]] = {}
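+ # Convenience set views over the option tuples (a fresh set per call),
+ # intended for membership checks on configuration values.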
+ @staticmethod
+ def _model_types_set() -> set[ModelType]:
+ return set(ReforceXY._MODEL_TYPES)
+
+ @staticmethod
+ def _exit_potential_modes_set() -> set[ExitPotentialMode]:
+ return set(ReforceXY._EXIT_POTENTIAL_MODES)
+
+ @staticmethod
+ def _transform_functions_set() -> set[TransformFunction]:
+ return set(ReforceXY._TRANSFORM_FUNCTIONS)
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pairs: List[str] = self.config.get("exchange", {}).get("pair_whitelist")
raise ValueError(
"FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
)
- self.action_masking: bool = self.model_type == "MaskablePPO"
+ self.action_masking: bool = (
+ self.model_type == self._MODEL_TYPES[2]
+ ) # "MaskablePPO"
self.rl_config.setdefault("action_masking", self.action_masking)
self.inference_masking: bool = self.rl_config.get("inference_masking", True)
- self.recurrent: bool = self.model_type == "RecurrentPPO"
+ self.recurrent: bool = self.model_type == self._MODEL_TYPES[1] # "RecurrentPPO"
self.lr_schedule: bool = self.rl_config.get("lr_schedule", False)
self.cr_schedule: bool = self.rl_config.get("cr_schedule", False)
self.n_envs: int = self.rl_config.get("n_envs", 1)
lr = model_params.get("learning_rate", 0.0003)
if isinstance(lr, (int, float)):
lr = float(lr)
- model_params["learning_rate"] = get_schedule("linear", lr)
+ model_params["learning_rate"] = get_schedule(
+ self._SCHEDULE_TYPES[0], lr
+ )
logger.info(
"Learning rate linear schedule enabled, initial value: %s", lr
)
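+ # Illustrative values, assuming the linear schedule scales the initial
+ # value by progress_remaining (the SB3 schedule convention), e.g. lr=3e-4:
+ #   progress_remaining=1.0 -> 3e-4, 0.5 -> 1.5e-4, 0.0 -> 0.0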
- if not self.hyperopt and "PPO" in self.model_type and self.cr_schedule:
+ # "PPO"
+ if (
+ not self.hyperopt
+ and self._MODEL_TYPES[0] in self.model_type
+ and self.cr_schedule
+ ):
cr = model_params.get("clip_range", 0.2)
if isinstance(cr, (int, float)):
cr = float(cr)
- model_params["clip_range"] = get_schedule("linear", cr)
+ model_params["clip_range"] = get_schedule(self._SCHEDULE_TYPES[0], cr)
logger.info("Clip range linear schedule enabled, initial value: %s", cr)
- if "DQN" in self.model_type:
+ # "DQN"
+ if self._MODEL_TYPES[3] in self.model_type:
if model_params.get("gradient_steps") is None:
model_params["gradient_steps"] = compute_gradient_steps(
model_params.get("train_freq"), model_params.get("subsample_steps")
Literal["small", "medium", "large", "extra_large"],
] = model_params.get("policy_kwargs", {}).get("net_arch", default_net_arch)
- if "PPO" in self.model_type:
+ # "PPO"
+ if self._MODEL_TYPES[0] in self.model_type:
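+ # SB3 on-policy PPO policies accept net_arch as dict(pi=[...], vf=[...]);
+ # string presets resolve via get_net_arch and lists are mirrored into
+ # the pi/vf branches below.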
if isinstance(net_arch, str):
- model_params["policy_kwargs"]["net_arch"] = get_net_arch(
- self.model_type, net_arch
- )
+ if net_arch in self._NET_ARCH_SIZES:
+ model_params["policy_kwargs"]["net_arch"] = get_net_arch(
+ self.model_type,
+ cast(NetArchSize, net_arch),
+ )
+ else:
+ logger.warning("Invalid net_arch=%s, using default", net_arch)
+ model_params["policy_kwargs"]["net_arch"] = {
+ "pi": default_net_arch,
+ "vf": default_net_arch,
+ }
elif isinstance(net_arch, list):
model_params["policy_kwargs"]["net_arch"] = {
"pi": net_arch,
}
else:
if isinstance(net_arch, str):
- model_params["policy_kwargs"]["net_arch"] = get_net_arch(
- self.model_type, net_arch
- )
+ if net_arch in self._NET_ARCH_SIZES:
+ model_params["policy_kwargs"]["net_arch"] = get_net_arch(
+ self.model_type,
+ cast(NetArchSize, net_arch),
+ )
+ else:
+ logger.warning("Invalid net_arch=%s, using default", net_arch)
+ model_params["policy_kwargs"]["net_arch"] = default_net_arch
elif isinstance(net_arch, list):
model_params["policy_kwargs"]["net_arch"] = net_arch
else:
model_params["policy_kwargs"]["net_arch"] = default_net_arch
model_params["policy_kwargs"]["activation_fn"] = get_activation_fn(
- model_params.get("policy_kwargs", {}).get("activation_fn", "relu")
+ model_params.get("policy_kwargs", {}).get(
+ "activation_fn", self._ACTIVATION_FUNCTIONS[1]
+ ) # "relu"
)
model_params["policy_kwargs"]["optimizer_class"] = get_optimizer_class(
- model_params.get("policy_kwargs", {}).get("optimizer_class", "adamw")
+ model_params.get("policy_kwargs", {}).get(
+ "optimizer_class", self._OPTIMIZER_CLASSES[1]
+ ) # "adamw"
)
self._model_params_cache = model_params
"""
if total_timesteps <= 0:
return 1
- if "PPO" in self.model_type:
+ # "PPO"
+ if self._MODEL_TYPES[0] in self.model_type:
eval_freq: Optional[int] = None
if model_params:
n_steps = model_params.get("n_steps")
model_params = self.get_model_params()
logger.info("%s params: %s", self.model_type, model_params)
- if "PPO" in self.model_type:
+ # "PPO"
+ if self._MODEL_TYPES[0] in self.model_type:
n_steps = model_params.get("n_steps", 0)
min_timesteps = 2 * n_steps * self.n_envs
if total_timesteps <= min_timesteps:
"""
storage_dir = self.full_path
storage_filename = f"optuna-{pair.split('/')[0]}"
- storage_backend = self.rl_config_optuna.get("storage", "sqlite")
- if storage_backend == "sqlite":
+ storage_backend: StorageBackend = self.rl_config_optuna.get(
+ "storage", self._STORAGE_BACKENDS[0]
+ ) # "sqlite"
+ # "sqlite"
+ if storage_backend == self._STORAGE_BACKENDS[0]:
storage = RDBStorage(
url=f"sqlite:///{storage_dir}/{storage_filename}.sqlite",
heartbeat_interval=60,
failed_trial_callback=RetryFailedTrialCallback(max_retry=3),
)
- elif storage_backend == "file":
+ # "file"
+ elif storage_backend == self._STORAGE_BACKENDS[1]:
storage = JournalStorage(
JournalFileBackend(f"{storage_dir}/{storage_filename}.log")
)
else:
raise ValueError(
- f"Unsupported storage backend: {storage_backend}. Supported backends are: 'sqlite' and 'file'"
+ f"Unsupported storage backend: {storage_backend}. Supported backends are: {', '.join(self._STORAGE_BACKENDS)}"
)
return storage
return False
def create_sampler(self) -> BaseSampler:
- sampler = self.rl_config_optuna.get("sampler", "tpe")
- if sampler == "auto":
+ sampler: SamplerType = self.rl_config_optuna.get(
+ "sampler", self._SAMPLER_TYPES[0]
+ ) # "tpe"
+ # "auto"
+ if sampler == self._SAMPLER_TYPES[1]:
return optunahub.load_module("samplers/auto_sampler").AutoSampler(
seed=self.rl_config_optuna.get("seed", 42)
)
- elif sampler == "tpe":
+ # "tpe"
+ elif sampler == self._SAMPLER_TYPES[0]:
return TPESampler(
n_startup_trials=self.optuna_n_startup_trials,
multivariate=True,
)
else:
raise ValueError(
- f"Unsupported sampler: '{sampler}'. Supported samplers: 'tpe', 'auto'"
+ f"Unsupported sampler: {sampler!r}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
)
@staticmethod
continuous = self.rl_config_optuna.get("continuous", False)
if continuous:
ReforceXY.delete_study(study_name, storage)
- if "PPO" in self.model_type:
+ # "PPO"
+ if self._MODEL_TYPES[0] in self.model_type:
resource_eval_freq = min(PPO_N_STEPS)
else:
resource_eval_freq = self.get_eval_freq(total_timesteps, hyperopt=True)
return train_env, eval_env
def get_optuna_params(self, trial: Trial) -> Dict[str, Any]:
- if "RecurrentPPO" in self.model_type:
+ # "RecurrentPPO"
+ if self._MODEL_TYPES[1] in self.model_type:
return sample_params_recurrentppo(trial)
- elif "PPO" in self.model_type:
+ # "PPO"
+ elif self._MODEL_TYPES[0] in self.model_type:
return sample_params_ppo(trial)
- elif "QRDQN" in self.model_type:
+ # "QRDQN"
+ elif self._MODEL_TYPES[4] in self.model_type:
return sample_params_qrdqn(trial)
- elif "DQN" in self.model_type:
+ # "DQN"
+ elif self._MODEL_TYPES[3] in self.model_type:
return sample_params_dqn(trial)
else:
raise NotImplementedError(f"{self.model_type} not supported for hyperopt")
params = self.get_optuna_params(trial)
- if "PPO" in self.model_type:
+ # "PPO"
+ if self._MODEL_TYPES[0] in self.model_type:
n_steps = params.get("n_steps")
if n_steps * self.n_envs > total_timesteps:
raise TrialPruned(
f"{n_steps=} * {self.n_envs=} = {n_steps * self.n_envs} is not divisible by {batch_size=}"
)
- if "DQN" in self.model_type:
+ # "DQN"
+ if self._MODEL_TYPES[3] in self.model_type:
gradient_steps = params.get("gradient_steps")
if isinstance(gradient_steps, int) and gradient_steps <= 0:
raise TrialPruned(f"{gradient_steps=} is negative or zero")
params["seed"] = params.get("seed", 42) + trial.number
logger.info("Trial %s params: %s", trial.number, params)
- if "PPO" in self.model_type:
+ # "PPO"
+
+ if self._MODEL_TYPES[0] in self.model_type:
n_steps = params.get("n_steps", 0)
if n_steps > 0:
rollout = n_steps * self.n_envs
# 'spike_cancel' -> Φ(s')=Φ(s)/γ (Δ ≈ 0, cancels shaping)
# 'retain_previous' -> Φ(s')=Φ(s)
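+ # Worked example with γ=0.95 and Φ(s)=0.2:
+ #   progressive_release (decay=0.5): Φ(s') = 0.2*(1-0.5) = 0.1
+ #   spike_cancel: Φ(s') = 0.2/0.95 ≈ 0.2105, so γΦ(s')-Φ(s) = 0
+ #   retain_previous: Φ(s') = 0.2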
self._exit_potential_mode = str(
- model_reward_parameters.get("exit_potential_mode", "canonical")
- )
- _allowed_exit_modes = {
- "canonical",
- "non_canonical",
- "progressive_release",
- "spike_cancel",
- "retain_previous",
- }
+ model_reward_parameters.get(
+ "exit_potential_mode", ReforceXY._EXIT_POTENTIAL_MODES[0]
+ ) # "canonical"
+ )
+ _allowed_exit_modes = ReforceXY._exit_potential_modes_set()
if self._exit_potential_mode not in _allowed_exit_modes:
logger.warning(
- "Unknown exit_potential_mode '%s'; defaulting to 'canonical'",
+ "Unknown exit_potential_mode %r; defaulting to %r. Valid modes: %s",
self._exit_potential_mode,
+ ReforceXY._EXIT_POTENTIAL_MODES[0],
+ ", ".join(ReforceXY._EXIT_POTENTIAL_MODES),
)
- self._exit_potential_mode = "canonical"
+ self._exit_potential_mode = ReforceXY._EXIT_POTENTIAL_MODES[
+ 0
+ ] # "canonical"
self._exit_potential_decay: float = float(
model_reward_parameters.get("exit_potential_decay", 0.5)
)
self._entry_additive_gain: float = float(
model_reward_parameters.get("entry_additive_gain", 1.0)
)
- self._entry_additive_transform_pnl: str = str(
- model_reward_parameters.get("entry_additive_transform_pnl", "tanh")
+ self._entry_additive_transform_pnl: TransformFunction = cast(
+ TransformFunction,
+ model_reward_parameters.get(
+ "entry_additive_transform_pnl", ReforceXY._TRANSFORM_FUNCTIONS[0]
+ ), # "tanh"
)
- self._entry_additive_transform_duration: str = str(
- model_reward_parameters.get("entry_additive_transform_duration", "tanh")
+ self._entry_additive_transform_duration: TransformFunction = cast(
+ TransformFunction,
+ model_reward_parameters.get(
+ "entry_additive_transform_duration", ReforceXY._TRANSFORM_FUNCTIONS[0]
+ ), # "tanh"
)
# === HOLD POTENTIAL (PBRS function Φ) ===
self._hold_potential_enabled: bool = bool(
self._hold_potential_gain: float = float(
model_reward_parameters.get("hold_potential_gain", 1.0)
)
- self._hold_potential_transform_pnl: str = str(
- model_reward_parameters.get("hold_potential_transform_pnl", "tanh")
+ self._hold_potential_transform_pnl: TransformFunction = cast(
+ TransformFunction,
+ model_reward_parameters.get(
+ "hold_potential_transform_pnl", ReforceXY._TRANSFORM_FUNCTIONS[0]
+ ), # "tanh"
)
- self._hold_potential_transform_duration: str = str(
- model_reward_parameters.get("hold_potential_transform_duration", "tanh")
+ self._hold_potential_transform_duration: TransformFunction = cast(
+ TransformFunction,
+ model_reward_parameters.get(
+ "hold_potential_transform_duration", ReforceXY._TRANSFORM_FUNCTIONS[0]
+ ), # "tanh"
)
# === EXIT ADDITIVE (non-PBRS additive term) ===
self._exit_additive_enabled: bool = bool(
self._exit_additive_gain: float = float(
model_reward_parameters.get("exit_additive_gain", 1.0)
)
- self._exit_additive_transform_pnl: str = str(
- model_reward_parameters.get("exit_additive_transform_pnl", "tanh")
+ self._exit_additive_transform_pnl: TransformFunction = cast(
+ TransformFunction,
+ model_reward_parameters.get(
+ "exit_additive_transform_pnl", ReforceXY._TRANSFORM_FUNCTIONS[0]
+ ), # "tanh"
)
- self._exit_additive_transform_duration: str = str(
- model_reward_parameters.get("exit_additive_transform_duration", "tanh")
+ self._exit_additive_transform_duration: TransformFunction = cast(
+ TransformFunction,
+ model_reward_parameters.get(
+ "exit_additive_transform_duration", ReforceXY._TRANSFORM_FUNCTIONS[0]
+ ), # "tanh"
)
# === PBRS INVARIANCE CHECKS ===
- if self._exit_potential_mode == "canonical":
+ # "canonical"
+ if self._exit_potential_mode == ReforceXY._EXIT_POTENTIAL_MODES[0]:
if self._entry_additive_enabled or self._exit_additive_enabled:
logger.info(
"PBRS canonical mode: additive rewards disabled with Φ(terminal)=0. PBRS invariance is preserved. "
- "To use additive rewards, set exit_potential_mode='non_canonical'."
+ f"To use additive rewards, set exit_potential_mode={ReforceXY._EXIT_POTENTIAL_MODES[1]!r}."
)
self._entry_additive_enabled = False
self._exit_additive_enabled = False
- elif self._exit_potential_mode == "non_canonical":
+ # "non_canonical"
+ elif self._exit_potential_mode == ReforceXY._EXIT_POTENTIAL_MODES[1]:
if self._entry_additive_enabled or self._exit_additive_enabled:
logger.info(
"PBRS non-canonical mode: additive rewards enabled with Φ(terminal)=0. PBRS invariance is intentionally broken."
duration_ratio: float,
scale: float,
gain: float,
- transform_pnl: str,
- transform_duration: str,
+ transform_pnl: TransformFunction,
+ transform_duration: TransformFunction,
) -> float:
"""Generic bounded bi-component signal combining PnL and duration.
Output scaling factor
gain : float
Gain multiplier before transform
- transform_pnl : str
+ transform_pnl : TransformFunction
Transform name for PnL component
- transform_duration : str
+ transform_duration : TransformFunction
Transform name for duration component
Returns
transform_duration=self._entry_additive_transform_duration,
)
- def _potential_transform(self, name: str, x: float) -> float:
+ def _potential_transform(self, name: TransformFunction, x: float) -> float:
"""Apply bounded transform function for potential and additive computations.
Provides numerical stability by mapping unbounded inputs to bounded outputs
Parameters
----------
- name : str
- Transform function name: 'tanh', 'softsign', 'arctan', 'sigmoid',
- 'asinh', or 'clip'
+ name : TransformFunction
+ Transform function name: one of ReforceXY._TRANSFORM_FUNCTIONS
x : float
Input value to transform
float
Bounded output in [-1, 1]
"""
- if name == "tanh":
+ if name == ReforceXY._TRANSFORM_FUNCTIONS[0]: # "tanh"
return math.tanh(x)
- if name == "softsign":
+ if name == ReforceXY._TRANSFORM_FUNCTIONS[1]: # "softsign"
ax = abs(x)
return x / (1.0 + ax)
- if name == "arctan":
+ if name == ReforceXY._TRANSFORM_FUNCTIONS[2]: # "arctan"
return (2.0 / math.pi) * math.atan(x)
- if name == "sigmoid":
+ if name == ReforceXY._TRANSFORM_FUNCTIONS[3]: # "sigmoid"
try:
if x >= 0:
exp_neg_x = math.exp(-x)
except OverflowError:
return 1.0 if x > 0 else -1.0
- if name == "asinh":
+ if name == ReforceXY._TRANSFORM_FUNCTIONS[4]: # "asinh"
return x / math.hypot(1.0, x)
- if name == "clip":
+ if name == ReforceXY._TRANSFORM_FUNCTIONS[5]: # "clip"
return max(-1.0, min(1.0, x))
- logger.warning("Unknown potential transform '%s'; falling back to tanh", name)
+ logger.warning(
+ "Unknown potential transform '%s'; falling back to tanh. Valid transforms: %s",
+ name,
+ ", ".join(ReforceXY._TRANSFORM_FUNCTIONS),
+ )
return math.tanh(x)
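+ # Quick reference: tanh(1.0) ≈ 0.7616; "clip" saturates hard at ±1; note
+ # the "asinh" branch computes the bounded algebraic sigmoid x/sqrt(1+x²)
+ # rather than the unbounded math.asinh.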
def _compute_exit_potential(self, prev_potential: float, gamma: float) -> float:
See ``_apply_potential_shaping`` for complete PBRS documentation.
"""
mode = self._exit_potential_mode
- if mode == "canonical" or mode == "non_canonical":
+ # "canonical" or "non_canonical"
+ if (
+ mode == ReforceXY._EXIT_POTENTIAL_MODES[0]
+ or mode == ReforceXY._EXIT_POTENTIAL_MODES[1]
+ ):
return 0.0
- if mode == "progressive_release":
+ # "progressive_release"
+ if mode == ReforceXY._EXIT_POTENTIAL_MODES[2]:
decay = self._exit_potential_decay
if not np.isfinite(decay) or decay < 0.0:
decay = 0.0
if decay > 1.0:
decay = 1.0
next_potential = prev_potential * (1.0 - decay)
- elif mode == "spike_cancel":
+ # "spike_cancel"
+ elif mode == ReforceXY._EXIT_POTENTIAL_MODES[3]:
if gamma <= 0.0 or not np.isfinite(gamma):
next_potential = prev_potential
else:
next_potential = prev_potential / gamma
- elif mode == "retain_previous":
+ # "retain_previous"
+ elif mode == ReforceXY._EXIT_POTENTIAL_MODES[4]:
next_potential = prev_potential
else:
next_potential = 0.0
bool
True if configuration preserves theoretical PBRS invariance
"""
- return self._exit_potential_mode == "canonical" and not (
+ # "canonical"
+ return self._exit_potential_mode == ReforceXY._EXIT_POTENTIAL_MODES[0] and not (
self._entry_additive_enabled or self._exit_additive_enabled
)
return base_reward + reward_shaping
elif is_exit:
if (
- self._exit_potential_mode == "canonical"
- or self._exit_potential_mode == "non_canonical"
+ self._exit_potential_mode
+ in (
+ ReforceXY._EXIT_POTENTIAL_MODES[0],  # "canonical"
+ ReforceXY._EXIT_POTENTIAL_MODES[1],  # "non_canonical"
+ )
):
next_potential = 0.0
exit_reward_shaping = -prev_potential
model_reward_parameters = self.rl_config.get("model_reward_parameters", {})
exit_attenuation_mode = str(
- model_reward_parameters.get("exit_attenuation_mode", "linear")
+ model_reward_parameters.get(
+ "exit_attenuation_mode", ReforceXY._EXIT_ATTENUATION_MODES[2]
+ ) # "linear"
)
exit_plateau = bool(model_reward_parameters.get("exit_plateau", True))
exit_plateau_grace = float(
return f * math.pow(2.0, -dr / hl)
strategies: Dict[str, Callable[[float, float, Mapping], float]] = {
- "legacy": _legacy,
- "sqrt": _sqrt,
- "linear": _linear,
- "power": _power,
- "half_life": _half_life,
+ ReforceXY._EXIT_ATTENUATION_MODES[0]: _legacy,
+ ReforceXY._EXIT_ATTENUATION_MODES[1]: _sqrt,
+ ReforceXY._EXIT_ATTENUATION_MODES[2]: _linear,
+ ReforceXY._EXIT_ATTENUATION_MODES[3]: _power,
+ ReforceXY._EXIT_ATTENUATION_MODES[4]: _half_life,
}
if exit_plateau:
strategy_fn = strategies.get(exit_attenuation_mode, None)
if strategy_fn is None:
logger.debug(
- "Unknown exit_attenuation_mode '%s'; defaulting to linear",
+ "Unknown exit_attenuation_mode '%s'; defaulting to linear. Valid modes: %s",
exit_attenuation_mode,
+ ", ".join(ReforceXY._EXIT_ATTENUATION_MODES),
)
strategy_fn = _linear
def _eval_schedule(schedule: Any) -> float | None:
schedule_type, _, _ = get_schedule_type(schedule)
try:
- if schedule_type == "linear":
+ if schedule_type == ReforceXY._SCHEDULE_TYPES[0]: # "linear"
return float(schedule(progress_remaining))
- if schedule_type == "constant":
+ if schedule_type == ReforceXY._SCHEDULE_TYPES[1]: # "constant"
if callable(schedule):
return float(schedule(0.0))
if isinstance(schedule, (int, float)):
schedule_type: Literal["linear", "constant"],
initial_value: float,
) -> Callable[[float], float]:
- if schedule_type == "linear":
+ if schedule_type == ReforceXY._SCHEDULE_TYPES[0]: # "linear"
return SimpleLinearSchedule(initial_value)
- elif schedule_type == "constant":
+ elif schedule_type == ReforceXY._SCHEDULE_TYPES[1]: # "constant"
return ConstantSchedule(initial_value)
else:
return ConstantSchedule(initial_value)
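+ # Sketch of the expected schedule semantics (assuming SimpleLinearSchedule
+ # scales initial_value by progress_remaining and ConstantSchedule ignores it,
+ # per the SB3 schedule convention):
+ #   get_schedule("linear", 3e-4)(0.5) -> 1.5e-4
+ #   get_schedule("constant", 3e-4)(0.5) -> 3e-4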
lr = optuna_params.get("learning_rate")
if lr is None:
- raise ValueError(f"missing 'learning_rate' in optuna params for {model_type}")
- lr = get_schedule(optuna_params.get("lr_schedule", "constant"), float(lr))
+ raise ValueError(
+ f"missing 'learning_rate' in optuna params for {model_type}"
+ )
+ lr = get_schedule(
+ optuna_params.get("lr_schedule", ReforceXY._SCHEDULE_TYPES[1]), float(lr)
+ ) # default: "constant"
if "PPO" in model_type:
required_ppo_params = [
if optuna_params.get(param) is None:
raise ValueError(f"missing '{param}' in optuna params for {model_type}")
cr = optuna_params.get("clip_range")
- cr = get_schedule(optuna_params.get("cr_schedule", "constant"), float(cr))
+ cr = get_schedule(
+ optuna_params.get("cr_schedule", ReforceXY._SCHEDULE_TYPES[1]),
+ float(cr),
+ ) # default: "constant"
model_params.update(
{
raise ValueError(f"Model {model_type} not supported")
if optuna_params.get("net_arch"):
- policy_kwargs["net_arch"] = get_net_arch(
- model_type, str(optuna_params["net_arch"])
- )
+ net_arch_value = str(optuna_params["net_arch"])
+ if net_arch_value in ReforceXY._NET_ARCH_SIZES:
+ policy_kwargs["net_arch"] = get_net_arch(
+ model_type,
+ cast(NetArchSize, net_arch_value),
+ )
+ else:
+ logger.warning("Invalid net_arch=%s, ignoring", net_arch_value)
if optuna_params.get("activation_fn"):
- policy_kwargs["activation_fn"] = get_activation_fn(
- str(optuna_params["activation_fn"])
- )
+ activation_fn_value = str(optuna_params["activation_fn"])
+ if activation_fn_value in ReforceXY._ACTIVATION_FUNCTIONS:
+ policy_kwargs["activation_fn"] = get_activation_fn(
+ cast(ActivationFunction, activation_fn_value)
+ )
+ else:
+ logger.warning("Invalid activation_fn=%s, ignoring", activation_fn_value)
if optuna_params.get("optimizer_class"):
- policy_kwargs["optimizer_class"] = get_optimizer_class(
- str(optuna_params["optimizer_class"])
- )
+ optimizer_value = str(optuna_params["optimizer_class"])
+ if optimizer_value in ReforceXY._OPTIMIZER_CLASSES:
+ policy_kwargs["optimizer_class"] = get_optimizer_class(
+ cast(OptimizerClass, optimizer_value)
+ )
+ else:
+ logger.warning("Invalid optimizer_class=%s, ignoring", optimizer_value)
if optuna_params.get("ortho_init") is not None:
policy_kwargs["ortho_init"] = bool(optuna_params["ortho_init"])
"gae_lambda": trial.suggest_float("gae_lambda", 0.9, 0.99, step=0.01),
"max_grad_norm": trial.suggest_float("max_grad_norm", 0.3, 1.0, step=0.05),
"vf_coef": trial.suggest_float("vf_coef", 0.0, 1.0, step=0.05),
- "lr_schedule": trial.suggest_categorical("lr_schedule", ["linear", "constant"]),
- "cr_schedule": trial.suggest_categorical("cr_schedule", ["linear", "constant"]),
+ "lr_schedule": trial.suggest_categorical(
+ "lr_schedule", list(ReforceXY._SCHEDULE_TYPES)
+ ),
+ "cr_schedule": trial.suggest_categorical(
+ "cr_schedule", list(ReforceXY._SCHEDULE_TYPES)
+ ),
"target_kl": trial.suggest_categorical(
"target_kl", [None, 0.01, 0.015, 0.02, 0.03, 0.04]
),
"ortho_init": trial.suggest_categorical("ortho_init", [True, False]),
"net_arch": trial.suggest_categorical(
- "net_arch", ["small", "medium", "large", "extra_large"]
+ "net_arch", list(ReforceXY._NET_ARCH_SIZES)
),
"activation_fn": trial.suggest_categorical(
- "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
+ "activation_fn", list(ReforceXY._ACTIVATION_FUNCTIONS)
),
"optimizer_class": trial.suggest_categorical(
- "optimizer_class", ["adamw", "rmsprop"]
+ "optimizer_class",
+ [ReforceXY._OPTIMIZER_CLASSES[1], ReforceXY._OPTIMIZER_CLASSES[2]],
),
}
"learning_starts", [500, 1000, 2000, 3000, 4000, 5000, 8000, 10000]
),
"net_arch": trial.suggest_categorical(
- "net_arch", ["small", "medium", "large", "extra_large"]
+ "net_arch", list(ReforceXY._NET_ARCH_SIZES)
),
"activation_fn": trial.suggest_categorical(
- "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
+ "activation_fn", list(ReforceXY._ACTIVATION_FUNCTIONS)
),
"optimizer_class": trial.suggest_categorical(
- "optimizer_class", ["adamw", "rmsprop"]
+ "optimizer_class",
+ [ReforceXY._OPTIMIZER_CLASSES[1], ReforceXY._OPTIMIZER_CLASSES[2]],
),
}