Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): properly handle activation_fn and optimizer_class
author Jérôme Benoit <jerome.benoit@sap.com>
Mon, 3 Mar 2025 15:51:21 +0000 (16:51 +0100)
committer Jérôme Benoit <jerome.benoit@sap.com>
Mon, 3 Mar 2025 15:51:21 +0000 (16:51 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@sap.com>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 456fbcf5e60bcb4977525bd62c0c90fa5e5f803f..00d9f34d9afd6837bfb3fe8f7348a028ce9b45f7 100644 (file)
@@ -230,23 +230,35 @@ class ReforceXY(BaseReinforcementLearningModel):
             model_params["clip_range"] = linear_schedule(_cr)
             logger.info("Clip range linear schedule enabled, initial value: %s", _cr)
 
-        net_arch = model_params.get("policy_kwargs", {}).get("net_arch", [128, 128])
         if not model_params.get("policy_kwargs"):
             model_params["policy_kwargs"] = {}
-        model_params["policy_kwargs"].update(
-            {
-                "net_arch": net_arch,
-                # TODO: how to propagate these values from the configuration?
-                "activation_fn": th.nn.ReLU,
-                "optimizer_class": th.optim.Adam,
-            }
-        )
+
+        net_arch = model_params["policy_kwargs"].get("net_arch", [128, 128])
+        model_params["policy_kwargs"]["net_arch"] = net_arch
         if "PPO" in self.model_type:
             model_params["policy_kwargs"]["net_arch"] = {
                 "pi": net_arch,
                 "vf": net_arch,
             }
 
+        activation_fn = model_params["policy_kwargs"].get("activation_fn", "relu")
+        if activation_fn == "tanh":
+            model_params["policy_kwargs"]["activation_fn"] = th.nn.Tanh
+        elif activation_fn == "relu":
+            model_params["policy_kwargs"]["activation_fn"] = th.nn.ReLU
+        elif activation_fn == "elu":
+            model_params["policy_kwargs"]["activation_fn"] = th.nn.ELU
+        elif activation_fn == "leaky_relu":
+            model_params["policy_kwargs"]["activation_fn"] = th.nn.LeakyReLU
+
+        optimizer_class = model_params["policy_kwargs"].get("optimizer_class", "adam")
+        if optimizer_class == "adam":
+            model_params["policy_kwargs"]["optimizer_class"] = th.optim.Adam
+        elif optimizer_class == "rmsprop":
+            model_params["policy_kwargs"]["optimizer_class"] = th.optim.RMSprop
+        elif optimizer_class == "sgd":
+            model_params["policy_kwargs"]["optimizer_class"] = th.optim.SGD
+
         return model_params
 
     def get_callbacks(
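
Note: with this hunk, activation_fn and optimizer_class are no longer hard-coded to ReLU and Adam but read from policy_kwargs as strings and mapped to the corresponding torch classes. A minimal sketch of the config fragment this expects, assuming the usual FreqAI layout where model_training_parameters is forwarded as model_params (the surrounding key name is an assumption; the inner keys come from the diff):

# Hypothetical config fragment (Python dict form); the strings below are
# translated to th.nn.* / th.optim.* classes by the code in this hunk.
model_training_parameters = {
    "policy_kwargs": {
        "net_arch": [128, 128],
        "activation_fn": "relu",      # tanh | relu | elu | leaky_relu
        "optimizer_class": "adam",    # adam | rmsprop | sgd
    },
}
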
@@ -543,7 +555,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             dk.pair if self.rl_config_optuna.get("per_pair", False) else None,
         )
 
-        return self.model_training_parameters.update(study.best_trial.params)
+        return {**self.model_training_parameters, **study.best_trial.params}
 
     def save_best_params(self, best_params: Dict, pair: str | None = None) -> None:
         """
@@ -606,7 +618,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             raise NotImplementedError
 
         # Ensure that the sampled parameters take precedence
-        params = self.model_training_parameters.update(params)
+        params = {**self.model_training_parameters, **params}
 
         nan_encountered = False
 
@@ -670,7 +682,6 @@ class ReforceXY(BaseReinforcementLearningModel):
             self.timeout: int = self.rl_config.get("max_trade_duration_candles", 128)
             self._last_closed_position: Positions = None
             self._last_closed_trade_tick: int = 0
-            # self.reward_range = (-1, 1)
             if self.force_actions:
                 logger.info(
                     "%s - take_profit: %s, stop_loss: %s, timeout: %s candles (%s days), observation_space: %s",
@@ -1234,13 +1245,13 @@ class InfoMetricsCallback(TensorboardCallback):
     def _on_training_start(self) -> None:
         _lr = self.model.learning_rate
         _lr = _lr if isinstance(_lr, float) else "lr_schedule"
-        hparam_dict = {
+        hparam_dict: Dict = {
             "algorithm": self.model.__class__.__name__,
             "learning_rate": _lr,
             "gamma": self.model.gamma,
             "batch_size": self.model.batch_size,
         }
-        if "PPO" in self.model_type:
+        if "PPO" in self.model.__class__.__name__:
             _cr = self.model.clip_range
             _cr = _cr if isinstance(_cr, float) else "cr_schedule"
             hparam_dict.update(
@@ -1253,7 +1264,7 @@ class InfoMetricsCallback(TensorboardCallback):
                     "vf_coef": self.model.vf_coef,
                 }
             )
-        if "DQN" in self.model_type:
+        if "DQN" in self.model.__class__.__name__:
             hparam_dict.update(
                 {
                     "buffer_size": self.model.buffer_size,
@@ -1265,6 +1276,8 @@ class InfoMetricsCallback(TensorboardCallback):
                     "exploration_final_eps": self.model.exploration_final_eps,
                 }
             )
+        if "QRDQN" in self.model.__class__.__name__:
+            hparam_dict.update({"n_quantiles": self.model.n_quantiles})
         metric_dict = {
             "info/total_reward": 0,
             "info/total_profit": 0,
@@ -1449,6 +1462,14 @@ def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
         "elu": th.nn.ELU,
         "leaky_relu": th.nn.LeakyReLU,
     }[activation_fn_name]
+    optimizer_class_name = trial.suggest_categorical(
+        "optimizer_class", ["adam", "rmsprop", "sgd"]
+    )
+    optimizer_class: type[th.optim.Optimizer] = {
+        "adam": th.optim.Adam,
+        "rmsprop": th.optim.RMSprop,
+        "sgd": th.optim.SGD,
+    }[optimizer_class_name]
     return {
         "n_steps": n_steps,
         "batch_size": batch_size,
@@ -1463,6 +1484,7 @@ def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
         "policy_kwargs": dict(
             net_arch=net_arch,
             activation_fn=activation_fn,
+            optimizer_class=optimizer_class,
             ortho_init=ortho_init,
         ),
     }
@@ -1514,6 +1536,14 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
         "elu": th.nn.ELU,
         "leaky_relu": th.nn.LeakyReLU,
     }[activation_fn_name]
+    optimizer_class_name = trial.suggest_categorical(
+        "optimizer_class", ["adam", "rmsprop", "sgd"]
+    )
+    optimizer_class: type[th.optim.Optimizer] = {
+        "adam": th.optim.Adam,
+        "rmsprop": th.optim.RMSprop,
+        "sgd": th.optim.SGD,
+    }[optimizer_class_name]
     return {
         "gamma": gamma,
         "learning_rate": learning_rate,
@@ -1525,7 +1555,11 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
         "exploration_final_eps": exploration_final_eps,
         "target_update_interval": target_update_interval,
         "learning_starts": learning_starts,
-        "policy_kwargs": dict(net_arch=net_arch, activation_fn=activation_fn),
+        "policy_kwargs": dict(
+            net_arch=net_arch,
+            activation_fn=activation_fn,
+            optimizer_class=optimizer_class,
+        ),
     }
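
Note: for context, a minimal sketch (not the strategy's actual training call) of where the mapped classes end up. stable-baselines3 policies take activation_fn and optimizer_class as class objects in policy_kwargs, which is why the strings coming from the configuration or from Optuna are translated before model construction. The CartPole environment below is a placeholder assumption.

import gymnasium as gym
import torch as th
from stable_baselines3 import PPO

# Placeholder environment and minimal settings, for illustration only.
policy_kwargs = {
    "net_arch": {"pi": [128, 128], "vf": [128, 128]},
    "activation_fn": th.nn.ReLU,        # mapped from "relu"
    "optimizer_class": th.optim.Adam,   # mapped from "adam"
}
model = PPO("MlpPolicy", gym.make("CartPole-v1"), policy_kwargs=policy_kwargs)
model.learn(total_timesteps=1_000)
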