]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): cleanup hyperparams handling
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 3 Mar 2025 21:01:13 +0000 (22:01 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 3 Mar 2025 21:01:13 +0000 (22:01 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 7274b3a69965c5732a77c766569714f6872246cd..c485ed0208e5d25f0f889993db32676354bf6a6b 100644 (file)
@@ -234,30 +234,21 @@ class ReforceXY(BaseReinforcementLearningModel):
             model_params["policy_kwargs"] = {}
 
         net_arch = model_params["policy_kwargs"].get("net_arch", [128, 128])
-        model_params["policy_kwargs"]["net_arch"] = net_arch
         if "PPO" in self.model_type:
             model_params["policy_kwargs"]["net_arch"] = {
                 "pi": net_arch,
                 "vf": net_arch,
             }
+        else:
+            model_params["policy_kwargs"]["net_arch"] = net_arch
+
+        model_params["policy_kwargs"]["activation_fn"] = model_params[
+            "policy_kwargs"
+        ].get("activation_fn", "relu")
 
-        activation_fn = model_params["policy_kwargs"].get("activation_fn", "relu")
-        if activation_fn == "tanh":
-            model_params["policy_kwargs"]["activation_fn"] = th.nn.Tanh
-        elif activation_fn == "relu":
-            model_params["policy_kwargs"]["activation_fn"] = th.nn.ReLU
-        elif activation_fn == "elu":
-            model_params["policy_kwargs"]["activation_fn"] = th.nn.ELU
-        elif activation_fn == "leaky_relu":
-            model_params["policy_kwargs"]["activation_fn"] = th.nn.LeakyReLU
-
-        optimizer_class = model_params["policy_kwargs"].get("optimizer_class", "adam")
-        if optimizer_class == "adam":
-            model_params["policy_kwargs"]["optimizer_class"] = th.optim.Adam
-        elif optimizer_class == "rmsprop":
-            model_params["policy_kwargs"]["optimizer_class"] = th.optim.RMSprop
-        elif optimizer_class == "sgd":
-            model_params["policy_kwargs"]["optimizer_class"] = th.optim.SGD
+        model_params["policy_kwargs"]["optimizer_class"] = model_params[
+            "policy_kwargs"
+        ].get("optimizer_class", "adam")
 
         return model_params
 
@@ -1422,6 +1413,50 @@ def steps_to_days(steps: int, timeframe: str) -> float:
     return round(days, 1)
 
 
+def get_net_arch(
+    model_type: str, net_arch_type: str
+) -> Dict[str, list[int] | Dict[str, list[int]]]:
+    """
+    Get network architecture
+    """
+    if "PPO" in model_type:
+        return {
+            "small": {"pi": [128, 128], "vf": [128, 128]},
+            "medium": {"pi": [256, 256], "vf": [256, 256]},
+            "large": {"pi": [512, 512], "vf": [512, 512]},
+            "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
+        }[net_arch_type]
+    return {
+        "small": [128, 128],
+        "medium": [256, 256],
+        "large": [512, 512],
+        "extra_large": [1024, 1024],
+    }[net_arch_type]
+
+
+def get_activation_fn(activation_fn_name: str) -> type[th.nn.Module]:
+    """
+    Get activation function
+    """
+    return {
+        "tanh": th.nn.Tanh,
+        "relu": th.nn.ReLU,
+        "elu": th.nn.ELU,
+        "leaky_relu": th.nn.LeakyReLU,
+    }[activation_fn_name]
+
+
+def get_optimizer_class(optimizer_class_name: str) -> type[th.optim.Optimizer]:
+    """
+    Get optimizer class
+    """
+    return {
+        "adam": th.optim.Adam,
+        "rmsprop": th.optim.RMSprop,
+        "sgd": th.optim.SGD,
+    }[optimizer_class_name]
+
+
 def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
     """
     Sampler for PPO hyperparams
@@ -1448,29 +1483,15 @@ def sample_params_ppo(trial: Trial) -> Dict[str, Any]:
     net_arch_type: str = trial.suggest_categorical(
         "net_arch", ["small", "medium", "large", "extra_large"]
     )
-    net_arch = {
-        "small": {"pi": [128, 128], "vf": [128, 128]},
-        "medium": {"pi": [256, 256], "vf": [256, 256]},
-        "large": {"pi": [512, 512], "vf": [512, 512]},
-        "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
-    }[net_arch_type]
+    net_arch = get_net_arch("PPO", net_arch_type)
     activation_fn_name: str = trial.suggest_categorical(
         "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
     )
-    activation_fn: Dict[str, type[th.nn.Module]] = {
-        "tanh": th.nn.Tanh,
-        "relu": th.nn.ReLU,
-        "elu": th.nn.ELU,
-        "leaky_relu": th.nn.LeakyReLU,
-    }[activation_fn_name]
+    activation_fn = get_activation_fn(activation_fn_name)
     optimizer_class_name = trial.suggest_categorical(
         "optimizer_class", ["adam", "rmsprop", "sgd"]
     )
-    optimizer_class: Dict[str, type[th.optim.Optimizer]] = {
-        "adam": th.optim.Adam,
-        "rmsprop": th.optim.RMSprop,
-        "sgd": th.optim.SGD,
-    }[optimizer_class_name]
+    optimizer_class = get_optimizer_class(optimizer_class_name)
     return {
         "n_steps": n_steps,
         "batch_size": batch_size,
@@ -1498,11 +1519,11 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
     gamma = trial.suggest_categorical(
         "gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]
     )
+    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
     learning_rate = trial.suggest_float("learning_rate", 1e-6, 0.01, log=True)
     lr_schedule = trial.suggest_categorical("lr_schedule", ["linear", "constant"])
     if lr_schedule == "linear":
         learning_rate = linear_schedule(learning_rate)
-    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
     buffer_size = trial.suggest_categorical(
         "buffer_size", [int(1e4), int(5e4), int(1e5)]
     )
@@ -1522,33 +1543,19 @@ def sample_params_dqn(trial: Trial) -> Dict[str, Any]:
     net_arch_type: str = trial.suggest_categorical(
         "net_arch", ["small", "medium", "large", "extra_large"]
     )
-    net_arch = {
-        "small": [128, 128],
-        "medium": [256, 256],
-        "large": [512, 512],
-        "extra_large": [1024, 1024],
-    }[net_arch_type]
+    net_arch = get_net_arch("DQN", net_arch_type)
     activation_fn_name: str = trial.suggest_categorical(
         "activation_fn", ["tanh", "relu", "elu", "leaky_relu"]
     )
-    activation_fn: Dict[str, type[th.nn.Module]] = {
-        "tanh": th.nn.Tanh,
-        "relu": th.nn.ReLU,
-        "elu": th.nn.ELU,
-        "leaky_relu": th.nn.LeakyReLU,
-    }[activation_fn_name]
+    activation_fn = get_activation_fn(activation_fn_name)
     optimizer_class_name = trial.suggest_categorical(
         "optimizer_class", ["adam", "rmsprop", "sgd"]
     )
-    optimizer_class: Dict[str, type[th.optim.Optimizer]] = {
-        "adam": th.optim.Adam,
-        "rmsprop": th.optim.RMSprop,
-        "sgd": th.optim.SGD,
-    }[optimizer_class_name]
+    optimizer_class = get_optimizer_class(optimizer_class_name)
     return {
         "gamma": gamma,
-        "learning_rate": learning_rate,
         "batch_size": batch_size,
+        "learning_rate": learning_rate,
         "buffer_size": buffer_size,
         "train_freq": train_freq,
         "gradient_steps": gradient_steps,