)
else:
raise ValueError(
- f"Unsupported sampler: {sampler!r}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
+ f"Unsupported sampler: {sampler}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
)
@staticmethod
if self._entry_additive_enabled or self._exit_additive_enabled:
logger.info(
"PBRS canonical mode: additive rewards disabled with Φ(terminal)=0. PBRS invariance is preserved. "
- f"To use additive rewards, set exit_potential_mode={ReforceXY._EXIT_POTENTIAL_MODES[1]!r}."
+ f"To use additive rewards, set exit_potential_mode={ReforceXY._EXIT_POTENTIAL_MODES[1]}."
)
self._entry_additive_enabled = False
self._exit_additive_enabled = False
hparam_dict.update({"n_updates": int(n_updates)})
except Exception:
pass
- if "PPO" in self.model.__class__.__name__:
+ if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__: # "PPO"
cr = getattr(self.model, "clip_range", None)
cr_schedule, cr_iv, cr_fv = get_schedule_type(cr)
hparam_dict.update(
)
if getattr(self.model, "target_kl", None) is not None:
hparam_dict["target_kl"] = float(self.model.target_kl)
- if "RecurrentPPO" in self.model.__class__.__name__:
+ if (
+ ReforceXY._MODEL_TYPES[1] in self.model.__class__.__name__
+ ): # "RecurrentPPO"
policy = getattr(self.model, "policy", None)
if policy is not None:
lstm_actor = getattr(policy, "lstm_actor", None)
"n_lstm_layers": int(lstm_actor.num_layers),
}
)
- if "DQN" in self.model.__class__.__name__:
+ if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__: # "DQN"
hparam_dict.update(
{
"buffer_size": int(self.model.buffer_size),
)
if train_freq is not None:
hparam_dict.update({"train_freq": train_freq})
- if "QRDQN" in self.model.__class__.__name__:
+ if ReforceXY._MODEL_TYPES[4] in self.model.__class__.__name__: # "QRDQN"
hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)})
metric_dict: dict[str, float | int] = {
"eval/mean_reward": 0.0,
"info/trade_count": 0,
"info/trade_duration": 0,
}
- if "PPO" in self.model.__class__.__name__:
+ if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__: # "PPO"
metric_dict.update(
{
"train/approx_kl": 0.0,
"train/explained_variance": 0.0,
}
)
- if "DQN" in self.model.__class__.__name__:
+ if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__: # "DQN"
metric_dict.update(
{
"train/loss": 0.0,
except Exception:
pass
- if "PPO" in self.model.__class__.__name__:
+ if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__: # "PPO"
try:
cr = getattr(self.model, "clip_range", None)
cr = _eval_schedule(cr)
except Exception:
pass
- if "DQN" in self.model.__class__.__name__:
+ if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__: # "DQN"
try:
er = getattr(self.model, "exploration_rate", None)
if _is_finite_number(er):
if isinstance(schedule, (int, float)):
try:
schedule = float(schedule)
- return "constant", schedule, schedule
+ return ReforceXY._SCHEDULE_TYPES[1], schedule, schedule # "constant"
except Exception:
- return "constant", np.nan, np.nan
+ return ReforceXY._SCHEDULE_TYPES[1], np.nan, np.nan # "constant"
elif isinstance(schedule, ConstantSchedule):
try:
- return "constant", schedule(1.0), schedule(0.0)
+ return (
+ ReforceXY._SCHEDULE_TYPES[1],
+ schedule(1.0),
+ schedule(0.0),
+ ) # "constant"
except Exception:
- return "constant", np.nan, np.nan
+ return ReforceXY._SCHEDULE_TYPES[1], np.nan, np.nan # "constant"
elif isinstance(schedule, SimpleLinearSchedule):
try:
- return "linear", schedule(1.0), schedule(0.0)
+ return (
+ ReforceXY._SCHEDULE_TYPES[0],
+ schedule(1.0),
+ schedule(0.0),
+ ) # "linear"
except Exception:
- return "linear", np.nan, np.nan
+ return ReforceXY._SCHEDULE_TYPES[0], np.nan, np.nan # "linear"
- return "unknown", np.nan, np.nan
+ return ReforceXY._SCHEDULE_TYPES[2], np.nan, np.nan # "unknown"
def get_schedule(
schedule_type: Literal["linear", "constant"],
initial_value: float,
) -> Callable[[float], float]:
- if schedule_type == ReforceXY._SCHEDULE_TYPES[0]: # "linear"
+ if schedule_type == ReforceXY._SCHEDULE_TYPES[0]: # "linear" (first entry of _SCHEDULE_TYPES)
return SimpleLinearSchedule(initial_value)
- elif schedule_type == ReforceXY._SCHEDULE_TYPES[1]: # "constant"
+ elif schedule_type == ReforceXY._SCHEDULE_TYPES[1]: # "constant"; the else branch below returns the same schedule for any other value
return ConstantSchedule(initial_value)
else:
return ConstantSchedule(initial_value)
def get_net_arch(
- model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
+ model_type: str, net_arch_type: NetArchSize
) -> Union[List[int], Dict[str, List[int]]]:
"""
Get network architecture
"""
- if "PPO" in model_type:
+ if ReforceXY._MODEL_TYPES[0] in model_type: # "PPO"
return {
- "small": {"pi": [128, 128], "vf": [128, 128]},
- "medium": {"pi": [256, 256], "vf": [256, 256]},
- "large": {"pi": [512, 512], "vf": [512, 512]},
- "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
+ ReforceXY._NET_ARCH_SIZES[0]: {
+ "pi": [128, 128],
+ "vf": [128, 128],
+ }, # "small"
+ ReforceXY._NET_ARCH_SIZES[1]: {
+ "pi": [256, 256],
+ "vf": [256, 256],
+ }, # "medium"
+ ReforceXY._NET_ARCH_SIZES[2]: {
+ "pi": [512, 512],
+ "vf": [512, 512],
+ }, # "large"
+ ReforceXY._NET_ARCH_SIZES[3]: {
+ "pi": [1024, 1024],
+ "vf": [1024, 1024],
+ }, # "extra_large"
}.get(net_arch_type, {"pi": [128, 128], "vf": [128, 128]})
return {
- "small": [128, 128],
- "medium": [256, 256],
- "large": [512, 512],
- "extra_large": [1024, 1024],
+ ReforceXY._NET_ARCH_SIZES[0]: [128, 128], # "small"
+ ReforceXY._NET_ARCH_SIZES[1]: [256, 256], # "medium"
+ ReforceXY._NET_ARCH_SIZES[2]: [512, 512], # "large"
+ ReforceXY._NET_ARCH_SIZES[3]: [1024, 1024], # "extra_large"
}.get(net_arch_type, [128, 128])
def get_activation_fn(
- activation_fn_name: Literal["tanh", "relu", "elu", "leaky_relu"],
+ activation_fn_name: ActivationFunction,
) -> Type[th.nn.Module]:
"""
Get activation function
"""
return {
- "tanh": th.nn.Tanh,
- "relu": th.nn.ReLU,
- "elu": th.nn.ELU,
- "leaky_relu": th.nn.LeakyReLU,
+ ReforceXY._ACTIVATION_FUNCTIONS[0]: th.nn.Tanh, # "tanh"
+ ReforceXY._ACTIVATION_FUNCTIONS[1]: th.nn.ReLU, # "relu"; also the fallback for unknown names
+ ReforceXY._ACTIVATION_FUNCTIONS[2]: th.nn.ELU, # "elu"
+ ReforceXY._ACTIVATION_FUNCTIONS[3]: th.nn.LeakyReLU, # "leaky_relu"
}.get(activation_fn_name, th.nn.ReLU)
def get_optimizer_class(
- optimizer_class_name: Literal["adam", "adamw", "rmsprop"],
+ optimizer_class_name: OptimizerClass,
) -> Type[th.optim.Optimizer]:
"""
Get optimizer class
"""
return {
- "adam": th.optim.Adam,
- "adamw": th.optim.AdamW,
- "rmsprop": th.optim.RMSprop,
+ ReforceXY._OPTIMIZER_CLASSES[0]: th.optim.Adam, # "adam"; also the fallback for unknown names
+ ReforceXY._OPTIMIZER_CLASSES[1]: th.optim.AdamW, # "adamw"
+ ReforceXY._OPTIMIZER_CLASSES[2]: th.optim.RMSprop, # "rmsprop"
}.get(optimizer_class_name, th.optim.Adam)
lr = optuna_params.get("learning_rate")
if lr is None:
- raise ValueError(
- f"missing {'learning_rate'!r} in optuna params for {model_type}"
- )
+ raise ValueError(f"missing learning_rate in optuna params for {model_type}") # plain text: no placeholder needed around a constant name
lr = get_schedule(
optuna_params.get("lr_schedule", ReforceXY._SCHEDULE_TYPES[1]), float(lr)
) # default: "constant"
- if "PPO" in model_type:
+ if ReforceXY._MODEL_TYPES[0] in model_type: # "PPO"
required_ppo_params = [
"clip_range",
"n_steps",
)
if optuna_params.get("target_kl") is not None:
model_params["target_kl"] = float(optuna_params.get("target_kl"))
- if "RecurrentPPO" in model_type:
+ if ReforceXY._MODEL_TYPES[1] in model_type: # "RecurrentPPO"
policy_kwargs["lstm_hidden_size"] = int(
optuna_params.get("lstm_hidden_size")
)
policy_kwargs["n_lstm_layers"] = int(optuna_params.get("n_lstm_layers"))
- elif "DQN" in model_type:
+ elif ReforceXY._MODEL_TYPES[3] in model_type: # "DQN"
required_dqn_params = [
"gamma",
"batch_size",
"learning_starts": int(optuna_params.get("learning_starts")),
}
)
- if "QRDQN" in model_type and optuna_params.get("n_quantiles") is not None:
+ if (
+ ReforceXY._MODEL_TYPES[4] in model_type
+ and optuna_params.get("n_quantiles") is not None
+ ): # "QRDQN"
policy_kwargs["n_quantiles"] = int(optuna_params["n_quantiles"])
else:
raise ValueError(f"Model {model_type} not supported")
"batch_size", [64, 128, 256, 512, 1024]
),
"learning_rate": trial.suggest_float("learning_rate", 1e-5, 3e-3, log=True),
- "lr_schedule": trial.suggest_categorical("lr_schedule", ["linear", "constant"]),
+ "lr_schedule": trial.suggest_categorical(
+ "lr_schedule", list(ReforceXY._SCHEDULE_TYPES[:2])
+ ), # ["linear", "constant"]
"buffer_size": trial.suggest_categorical(
"buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]
),