]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): do not overwrite configuration policy_kwargs
authorJérôme Benoit <jerome.benoit@sap.com>
Mon, 3 Mar 2025 13:38:40 +0000 (14:38 +0100)
committerJérôme Benoit <jerome.benoit@sap.com>
Mon, 3 Mar 2025 13:38:40 +0000 (14:38 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@sap.com>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index f4a33b0275aa2ad0a18d07d73d5779ca282b07b2..c180825de93924b655897137424c2ed7e3abe79c 100644 (file)
@@ -230,16 +230,19 @@ class ReforceXY(BaseReinforcementLearningModel):
             model_params["clip_range"] = linear_schedule(_cr)
             logger.info("Clip range linear schedule enabled, initial value: %s", _cr)
 
-        model_params["policy_kwargs"] = {
-            "net_arch": self.net_arch,
-            "activation_fn": th.nn.ReLU,
-            "optimizer_class": th.optim.Adam,
-            # "ortho_init": True
-        }
+        net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
+        model_params["policy_kwargs"].update(
+            {
+                "net_arch": net_arch,
+                # TODO: how to propagate these values from the configuration?
+                "activation_fn": th.nn.ReLU,
+                "optimizer_class": th.optim.Adam,
+            }
+        )
         if "PPO" in self.model_type:
             model_params["policy_kwargs"]["net_arch"] = {
-                "pi": self.net_arch,
-                "vf": self.net_arch,
+                "pi": net_arch,
+                "vf": net_arch,
             }
 
         return model_params
@@ -538,7 +541,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             dk.pair if self.rl_config_optuna.get("per_pair", False) else None,
         )
 
-        return {**self.model_training_parameters, **study.best_trial.params}
+        return self.model_training_parameters.update(study.best_trial.params)
 
     def save_best_params(self, best_params: Dict, pair: str | None = None) -> None:
         """
@@ -601,7 +604,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise NotImplementedError
 
         # Ensure that the sampled parameters take precedence
-        params = {**self.model_training_parameters, **params}
+        params = self.model_training_parameters.update(params)
 
         nan_encountered = False
 
@@ -1229,17 +1232,37 @@ class InfoMetricsCallback(TensorboardCallback):
     def _on_training_start(self) -> None:
         _lr = self.model.learning_rate
         _lr = _lr if isinstance(_lr, float) else "lr_schedule"
-        _cr = self.model.clip_range
-        _cr = _cr if isinstance(_cr, float) else "cr_schedule"
         hparam_dict = {
             "algorithm": self.model.__class__.__name__,
             "learning_rate": _lr,
-            "clip_range": _cr,
-            # "gamma": self.model.gamma,
-            # "gae_lambda": self.model.gae_lambda,
-            # "n_steps": self.model.n_steps,
-            # "batch_size": self.model.batch_size,
+            "gamma": self.model.gamma,
+            "batch_size": self.model.batch_size,
         }
+        if "PPO" in self.model_type:
+            _cr = self.model.clip_range
+            _cr = _cr if isinstance(_cr, float) else "cr_schedule"
+            hparam_dict.update(
+                {
+                    "clip_range": _cr,
+                    "gae_lambda": self.model.gae_lambda,
+                    "n_steps": self.model.n_steps,
+                    "n_epochs": self.model.n_epochs,
+                    "ent_coef": self.model.ent_coef,
+                    "vf_coef": self.model.vf_coef,
+                }
+            )
+        if "DQN" in self.model_type:
+            hparam_dict.update(
+                {
+                    "buffer_size": self.model.buffer_size,
+                    "gradient_steps": self.model.gradient_steps,
+                    "train_freq": self.model.train_freq,
+                    "learning_starts": self.model.learning_starts,
+                    "target_update_interval": self.model.target_update_interval,
+                    "exploration_fraction": self.model.exploration_fraction,
+                    "exploration_final_eps": self.model.exploration_final_eps,
+                }
+            )
         metric_dict = {
             "info/total_reward": 0,
             "info/total_profit": 0,
@@ -1397,28 +1420,28 @@ def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
     gae_lambda = trial.suggest_float("gae_lambda", 0.1, 0.99, step=0.01)
     max_grad_norm = trial.suggest_float("max_grad_norm", 0.1, 5, step=0.01)
     vf_coef = trial.suggest_float("vf_coef", 0, 1, step=0.01)
-    net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium", "large"])
-    # ortho_init = True
     ortho_init = trial.suggest_categorical("ortho_init", [False, True])
-    activation_fn_name = trial.suggest_categorical(
-        "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
-    )
     lr_schedule = trial.suggest_categorical("lr_schedule", ["linear", "constant"])
     if lr_schedule == "linear":
         learning_rate = linear_schedule(learning_rate)
-
     cr_schedule = trial.suggest_categorical("cr_schedule", ["linear", "constant"])
     if cr_schedule == "linear":
         clip_range = linear_schedule(clip_range)
     if batch_size > n_steps:
         batch_size = n_steps
+    net_arch_type: str = trial.suggest_categorical(
+        "net_arch", ["small", "medium", "large", "extra_large"]
+    )
     net_arch = {
         "small": {"pi": [128, 128], "vf": [128, 128]},
         "medium": {"pi": [256, 256], "vf": [256, 256]},
         "large": {"pi": [512, 512], "vf": [512, 512]},
+        "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
     }[net_arch_type]
-
-    activation_fn = {
+    activation_fn_name: str = trial.suggest_categorical(
+        "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
+    )
+    activation_fn: Dict[str, type[th.nn.Module]] = {
         "tanh": th.nn.Tanh,
         "relu": th.nn.ReLU,
         "elu": th.nn.ELU,
@@ -1451,7 +1474,10 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
         "gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]
     )
     learning_rate = trial.suggest_float("learning_rate", 1e-6, 0.01, log=True)
-    batch_size = trial.suggest_categorical("batch_size", [64, 256, 512])
+    lr_schedule = trial.suggest_categorical("lr_schedule", ["linear", "constant"])
+    if lr_schedule == "linear":
+        learning_rate = linear_schedule(learning_rate)
+    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
     buffer_size = trial.suggest_categorical(
         "buffer_size", [int(1e4), int(5e4), int(1e5)]
     )
@@ -1463,15 +1489,29 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
         "target_update_interval", [1000, 5000, 10000]
     )
     learning_starts = trial.suggest_categorical("learning_starts", [1000, 5000, 10000])
-    train_freq = trial.suggest_categorical("train_freq", [1, 4, 8, 16, 128, 256, 1000])
-    subsample_steps = trial.suggest_categorical("subsample_steps", [1, 2, 4, 8])
+    train_freq = trial.suggest_categorical(
+        "train_freq", [2, 4, 8, 16, 128, 256, 512, 1024]
+    )
+    subsample_steps = trial.suggest_categorical("subsample_steps", [2, 4, 8])
     gradient_steps = max(train_freq // subsample_steps, 1)
-    net_arch_type = trial.suggest_categorical("net_arch", ["small", "medium", "large"])
+    net_arch_type: str = trial.suggest_categorical(
+        "net_arch", ["small", "medium", "large", "extra_large"]
+    )
     net_arch = {
         "small": [128, 128],
         "medium": [256, 256],
         "large": [512, 512],
+        "extra_large": [1024, 1024],
     }[net_arch_type]
+    activation_fn_name: str = trial.suggest_categorical(
+        "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
+    )
+    activation_fn: Dict[str, type[th.nn.Module]] = {
+        "tanh": th.nn.Tanh,
+        "relu": th.nn.ReLU,
+        "elu": th.nn.ELU,
+        "leaky_relu": th.nn.LeakyReLU,
+    }[activation_fn_name]
     return {
         "gamma": gamma,
         "learning_rate": learning_rate,
@@ -1483,7 +1523,7 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
         "exploration_final_eps": exploration_final_eps,
         "target_update_interval": target_update_interval,
         "learning_starts": learning_starts,
-        "policy_kwargs": dict(net_arch=net_arch),
+        "policy_kwargs": dict(net_arch=net_arch, activation_fn=activation_fn),
     }
 
 
@@ -1491,7 +1531,7 @@ def sample_params_qrdqn(trial: Trial) -> Dict[str, Any]:
     """
     Sampler for QRDQN hyperparams
     """
-    hyperparams = sample_params_dqn(trial)
+    params = sample_params_dqn(trial)
     n_quantiles = trial.suggest_int("n_quantiles", 5, 200)
-    hyperparams["policy_kwargs"].update({"n_quantiles": n_quantiles})
-    return hyperparams
+    params["policy_kwargs"].update({"n_quantiles": n_quantiles})
+    return params