From: Jérôme Benoit Date: Mon, 3 Mar 2025 21:01:13 +0000 (+0100) Subject: refactor(reforcexy): cleanup hyperparams handling X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=2648820a76f0507161e60bc6453d84adcc26083b;p=freqai-strategies.git refactor(reforcexy): cleanup hyperparams handling Signed-off-by: Jérôme Benoit --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 7274b3a..c485ed0 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -234,30 +234,21 @@ class ReforceXY(BaseReinforcementLearningModel): model_params["policy_kwargs"] = {} net_arch = model_params["policy_kwargs"].get("net_arch", [128, 128]) - model_params["policy_kwargs"]["net_arch"] = net_arch if "PPO" in self.model_type: model_params["policy_kwargs"]["net_arch"] = { "pi": net_arch, "vf": net_arch, } + else: + model_params["policy_kwargs"]["net_arch"] = net_arch + + model_params["policy_kwargs"]["activation_fn"] = model_params[ + "policy_kwargs" + ].get("activation_fn", "relu") - activation_fn = model_params["policy_kwargs"].get("activation_fn", "relu") - if activation_fn == "tanh": - model_params["policy_kwargs"]["activation_fn"] = th.nn.Tanh - elif activation_fn == "relu": - model_params["policy_kwargs"]["activation_fn"] = th.nn.ReLU - elif activation_fn == "elu": - model_params["policy_kwargs"]["activation_fn"] = th.nn.ELU - elif activation_fn == "leaky_relu": - model_params["policy_kwargs"]["activation_fn"] = th.nn.LeakyReLU - - optimizer_class = model_params["policy_kwargs"].get("optimizer_class", "adam") - if optimizer_class == "adam": - model_params["policy_kwargs"]["optimizer_class"] = th.optim.Adam - elif optimizer_class == "rmsprop": - model_params["policy_kwargs"]["optimizer_class"] = th.optim.RMSprop - elif optimizer_class == "sgd": - model_params["policy_kwargs"]["optimizer_class"] = th.optim.SGD + model_params["policy_kwargs"]["optimizer_class"] = model_params[ + "policy_kwargs" + ].get("optimizer_class", "adam") return model_params @@ -1422,6 +1413,50 @@ def steps_to_days(steps: int, timeframe: str) -> float: return round(days, 1) +def get_net_arch( + model_type: str, net_arch_type: str +) -> Dict[str, list[int] | Dict[str, list[int]]]: + """ + Get network architecture + """ + if "PPO" in model_type: + return { + "small": {"pi": [128, 128], "vf": [128, 128]}, + "medium": {"pi": [256, 256], "vf": [256, 256]}, + "large": {"pi": [512, 512], "vf": [512, 512]}, + "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]}, + }[net_arch_type] + return { + "small": [128, 128], + "medium": [256, 256], + "large": [512, 512], + "extra_large": [1024, 1024], + }[net_arch_type] + + +def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]: + """ + Get activation function + """ + return { + "tanh": th.nn.Tanh, + "relu": th.nn.ReLU, + "elu": th.nn.ELU, + "leaky_relu": th.nn.LeakyReLU, + }[activation_fn_name] + + +def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]: + """ + Get optimizer class + """ + return { + "adam": th.optim.Adam, + "rmsprop": th.optim.RMSprop, + "sgd": th.optim.SGD, + }[optimizer_class_name] + + def sample_params_ppo(trial: Trial) -> Dict[str, Any]: """ Sampler for PPO hyperparams @@ -1448,29 +1483,15 @@ def sample_params_ppo(trial: Trial) -> Dict[str, Any]: net_arch_type: str = trial.suggest_categorical( "net_arch", ["small", "medium", "large", "extra_large"] ) - net_arch = { - "small": {"pi": [128, 128], "vf": [128, 128]}, - "medium": {"pi": [256, 256], "vf": [256, 256]}, - "large": {"pi": [512, 512], "vf": [512, 512]}, - "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]}, - }[net_arch_type] + net_arch = get_net_arch("PPO", net_arch_type) activation_fn_name: str = trial.suggest_categorical( "activation_fn", ["tanh", "relu", "elu", "leaky_relu"] ) - activation_fn: Dict[str, type[th.nn.Module]] = { - "tanh": th.nn.Tanh, - "relu": th.nn.ReLU, - "elu": th.nn.ELU, - "leaky_relu": th.nn.LeakyReLU, - }[activation_fn_name] + activation_fn = get_activation_fn(activation_fn_name) optimizer_class_name = trial.suggest_categorical( "optimizer_class", ["adam", "rmsprop", "sgd"] ) - optimizer_class: Dict[str, type[th.optim.Optimizer]] = { - "adam": th.optim.Adam, - "rmsprop": th.optim.RMSprop, - "sgd": th.optim.SGD, - }[optimizer_class_name] + optimizer_class = get_optimizer_class(optimizer_class_name) return { "n_steps": n_steps, "batch_size": batch_size, @@ -1498,11 +1519,11 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]: gamma = trial.suggest_categorical( "gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999] ) + batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512]) learning_rate = trial.suggest_float("learning_rate", 1e-6, 0.01, log=True) lr_schedule = trial.suggest_categorical("lr_schedule", ["linear", "constant"]) if lr_schedule == "linear": learning_rate = linear_schedule(learning_rate) - batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512]) buffer_size = trial.suggest_categorical( "buffer_size", [int(1e4), int(5e4), int(1e5)] ) @@ -1522,33 +1543,19 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]: net_arch_type: str = trial.suggest_categorical( "net_arch", ["small", "medium", "large", "extra_large"] ) - net_arch = { - "small": [128, 128], - "medium": [256, 256], - "large": [512, 512], - "extra_large": [1024, 1024], - }[net_arch_type] + net_arch = get_net_arch("DQN", net_arch_type) activation_fn_name: str = trial.suggest_categorical( "activation_fn", ["tanh", "relu", "elu", "leaky_relu"] ) - activation_fn: Dict[str, type[th.nn.Module]] = { - "tanh": th.nn.Tanh, - "relu": th.nn.ReLU, - "elu": th.nn.ELU, - "leaky_relu": th.nn.LeakyReLU, - }[activation_fn_name] + activation_fn = get_activation_fn(activation_fn_name) optimizer_class_name = trial.suggest_categorical( "optimizer_class", ["adam", "rmsprop", "sgd"] ) - optimizer_class: Dict[str, type[th.optim.Optimizer]] = { - "adam": th.optim.Adam, - "rmsprop": th.optim.RMSprop, - "sgd": th.optim.SGD, - }[optimizer_class_name] + optimizer_class = get_optimizer_class(optimizer_class_name) return { "gamma": gamma, - "learning_rate": learning_rate, "batch_size": batch_size, + "learning_rate": learning_rate, "buffer_size": buffer_size, "train_freq": train_freq, "gradient_steps": gradient_steps,