From: Jérôme Benoit
Date: Mon, 3 Mar 2025 13:38:40 +0000 (+0100)
Subject: fix(reforcexy): do not overwrite configuration policy_kwargs
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=0e66bc71d912712cce5957a0e2973f5b1270dfd8;p=freqai-strategies.git

fix(reforcexy): do not overwrite configuration policy_kwargs

Signed-off-by: Jérôme Benoit
---

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index f4a33b0..c180825 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -230,16 +230,19 @@ class ReforceXY(BaseReinforcementLearningModel):
             model_params["clip_range"] = linear_schedule(_cr)
             logger.info("Clip range linear schedule enabled, initial value: %s", _cr)

-        model_params["policy_kwargs"] = {
-            "net_arch": self.net_arch,
-            "activation_fn": th.nn.ReLU,
-            "optimizer_class": th.optim.Adam,
-            # "ortho_init": True
-        }
+        net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
+        model_params.setdefault("policy_kwargs", {}).update(
+            {
+                "net_arch": net_arch,
+                # TODO: how to propagate these values from the configuration?
+                "activation_fn": th.nn.ReLU,
+                "optimizer_class": th.optim.Adam,
+            }
+        )
         if "PPO" in self.model_type:
             model_params["policy_kwargs"]["net_arch"] = {
-                "pi": self.net_arch,
-                "vf": self.net_arch,
+                "pi": net_arch,
+                "vf": net_arch,
             }

         return model_params
@@ -538,7 +541,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             dk.pair if self.rl_config_optuna.get("per_pair", False) else None,
         )

-        return {**self.model_training_parameters, **study.best_trial.params}
+        self.model_training_parameters.update(study.best_trial.params)
+        return self.model_training_parameters

 def save_best_params(self, best_params: Dict, pair: str | None = None) -> None:
     """
@@ -601,7 +604,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise NotImplementedError

         # Ensure that the sampled parameters take precedence
-        params = {**self.model_training_parameters, **params}
+        params = dict(self.model_training_parameters, **params)

         nan_encountered = False

@@ -1229,17 +1232,37 @@ class InfoMetricsCallback(TensorboardCallback):
     def _on_training_start(self) -> None:
         _lr = self.model.learning_rate
         _lr = _lr if isinstance(_lr, float) else "lr_schedule"
-        _cr = self.model.clip_range
-        _cr = _cr if isinstance(_cr, float) else "cr_schedule"
         hparam_dict = {
             "algorithm": self.model.__class__.__name__,
             "learning_rate": _lr,
-            "clip_range": _cr,
-            # "gamma": self.model.gamma,
-            # "gae_lambda": self.model.gae_lambda,
-            # "n_steps": self.model.n_steps,
-            # "batch_size": self.model.batch_size,
+            "gamma": self.model.gamma,
+            "batch_size": self.model.batch_size,
         }
+        if "PPO" in self.model_type:
+            _cr = self.model.clip_range
+            _cr = _cr if isinstance(_cr, float) else "cr_schedule"
+            hparam_dict.update(
+                {
+                    "clip_range": _cr,
+                    "gae_lambda": self.model.gae_lambda,
+                    "n_steps": self.model.n_steps,
+                    "n_epochs": self.model.n_epochs,
+                    "ent_coef": self.model.ent_coef,
+                    "vf_coef": self.model.vf_coef,
+                }
+            )
+        if "DQN" in self.model_type:
+            hparam_dict.update(
+                {
+                    "buffer_size": self.model.buffer_size,
+                    "gradient_steps": self.model.gradient_steps,
+                    "train_freq": self.model.train_freq,
+                    "learning_starts": self.model.learning_starts,
+                    "target_update_interval": self.model.target_update_interval,
+                    "exploration_fraction": self.model.exploration_fraction,
+                    "exploration_final_eps": self.model.exploration_final_eps,
+                }
+            )
         metric_dict = {
             "info/total_reward": 0,
             "info/total_profit": 0,
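
Note on the merge semantics the hunks above depend on: dict.update() mutates its
receiver in place and returns None, while {**defaults, **overrides} (or
dict(defaults, **overrides)) builds a new dict in which the overrides win. A
minimal self-contained sketch, with a made-up user configuration for
illustration:

    import torch as th

    model_params = {"policy_kwargs": {"net_arch": [256, 256]}}  # hypothetical config

    # Keep the configured net_arch; only fill in the non-configurable defaults
    net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
    model_params.setdefault("policy_kwargs", {}).update(
        {
            "net_arch": net_arch,
            "activation_fn": th.nn.ReLU,
            "optimizer_class": th.optim.Adam,
        }
    )
    assert model_params["policy_kwargs"]["net_arch"] == [256, 256]  # config preserved
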
@@ -1397,28 +1420,29 @@ def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
     gae_lambda = trial.suggest_float("gae_lambda", 0.1, 0.99, step=0.01)
     max_grad_norm = trial.suggest_float("max_grad_norm", 0.1, 5, step=0.01)
     vf_coef = trial.suggest_float("vf_coef", 0, 1, step=0.01)
-    net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium", "large"])
-    # ortho_init = True
     ortho_init = trial.suggest_categorical("ortho_init", [False, True])
-    activation_fn_name = trial.suggest_categorical(
-        "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
-    )
     lr_schedule = trial.suggest_categorical("lr_schedule", ["linear", "constant"])
     if lr_schedule == "linear":
         learning_rate = linear_schedule(learning_rate)
     cr_schedule = trial.suggest_categorical("cr_schedule", ["linear", "constant"])
     if cr_schedule == "linear":
         clip_range = linear_schedule(clip_range)
     if batch_size > n_steps:
         batch_size = n_steps
+    net_arch_type: str = trial.suggest_categorical(
+        "net_arch", ["small", "medium", "large", "extra_large"]
+    )
     net_arch = {
         "small": {"pi": [128, 128], "vf": [128, 128]},
         "medium": {"pi": [256, 256], "vf": [256, 256]},
         "large": {"pi": [512, 512], "vf": [512, 512]},
+        "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
     }[net_arch_type]
-
-    activation_fn = {
+    activation_fn_name: str = trial.suggest_categorical(
+        "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
+    )
+    activation_fn: type[th.nn.Module] = {
         "tanh": th.nn.Tanh,
         "relu": th.nn.ReLU,
         "elu": th.nn.ELU,
         "leaky_relu": th.nn.LeakyReLU,
@@ -1451,7 +1474,10 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
         "gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]
     )
     learning_rate = trial.suggest_float("learning_rate", 1e-6, 0.01, log=True)
-    batch_size = trial.suggest_categorical("batch_size", [64, 256, 512])
+    lr_schedule = trial.suggest_categorical("lr_schedule", ["linear", "constant"])
+    if lr_schedule == "linear":
+        learning_rate = linear_schedule(learning_rate)
+    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
     buffer_size = trial.suggest_categorical(
         "buffer_size", [int(1e4), int(5e4), int(1e5)]
     )
@@ -1463,15 +1489,29 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
         "target_update_interval", [1000, 5000, 10000]
     )
     learning_starts = trial.suggest_categorical("learning_starts", [1000, 5000, 10000])
-    train_freq = trial.suggest_categorical("train_freq", [1, 4, 8, 16, 128, 256, 1000])
-    subsample_steps = trial.suggest_categorical("subsample_steps", [1, 2, 4, 8])
+    train_freq = trial.suggest_categorical(
+        "train_freq", [2, 4, 8, 16, 128, 256, 512, 1024]
+    )
+    subsample_steps = trial.suggest_categorical("subsample_steps", [2, 4, 8])
     gradient_steps = max(train_freq // subsample_steps, 1)
-    net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium", "large"])
+    net_arch_type: str = trial.suggest_categorical(
+        "net_arch", ["small", "medium", "large", "extra_large"]
+    )
     net_arch = {
         "small": [128, 128],
         "medium": [256, 256],
         "large": [512, 512],
+        "extra_large": [1024, 1024],
     }[net_arch_type]
+    activation_fn_name: str = trial.suggest_categorical(
+        "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
+    )
+    activation_fn: type[th.nn.Module] = {
+        "tanh": th.nn.Tanh,
+        "relu": th.nn.ReLU,
+        "elu": th.nn.ELU,
+        "leaky_relu": th.nn.LeakyReLU,
+    }[activation_fn_name]
     return {
         "gamma": gamma,
         "learning_rate": learning_rate,
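
The samplers above are meant to be driven by an Optuna study; a minimal sketch
of how sample_params_dqn's output would typically be consumed (the objective
body and score are placeholders, not part of this patch):

    import optuna

    def objective(trial: optuna.Trial) -> float:
        params = sample_params_dqn(trial)  # hyperparameters for this trial
        # ... build, train and evaluate the DQN model with `params` ...
        score = 0.0  # placeholder for the real evaluation metric
        return score

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=10)
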
@@ -1483,7 +1523,7 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
         "exploration_final_eps": exploration_final_eps,
         "target_update_interval": target_update_interval,
         "learning_starts": learning_starts,
-        "policy_kwargs": dict(net_arch=net_arch),
+        "policy_kwargs": dict(net_arch=net_arch, activation_fn=activation_fn),
     }
@@ -1491,7 +1531,7 @@ def sample_params_qrdqn(trial: Trial) -> Dict[str, Any]:
     """
     Sampler for QRDQN hyperparams
     """
-    hyperparams = sample_params_dqn(trial)
+    params = sample_params_dqn(trial)
     n_quantiles = trial.suggest_int("n_quantiles", 5, 200)
-    hyperparams["policy_kwargs"].update({"n_quantiles": n_quantiles})
-    return hyperparams
+    params["policy_kwargs"].update({"n_quantiles": n_quantiles})
+    return params
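
For reference, the linear_schedule helper used throughout follows the usual
Stable-Baselines3 callable-schedule convention; a sketch of the common
implementation (the repository's actual helper may differ):

    from typing import Callable

    def linear_schedule(initial_value: float) -> Callable[[float], float]:
        # SB3 calls the schedule with progress_remaining, decaying from 1.0 to 0.0
        def schedule(progress_remaining: float) -> float:
            return progress_remaining * initial_value

        return schedule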