From 9181a6eae58a743b7da0922a50558331d456058d Mon Sep 17 00:00:00 2001
From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?=
Date: Wed, 19 Nov 2025 02:58:49 +0100
Subject: [PATCH] refactor(reforcexy): more constants usage
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Signed-off-by: Jérôme Benoit
---
 ReforceXY/user_data/freqaimodels/ReforceXY.py | 117 +++++++++++-------
 1 file changed, 71 insertions(+), 46 deletions(-)

diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index 37222c9..37ab6e2 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -1084,7 +1084,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
         else:
             raise ValueError(
-                f"Unsupported sampler: {sampler!r}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
+                f"Unsupported sampler: {sampler}. Supported samplers: {', '.join(self._SAMPLER_TYPES)}"
             )
 
     @staticmethod
@@ -1656,7 +1656,7 @@ class MyRLEnv(Base5ActionRLEnv):
         if self._entry_additive_enabled or self._exit_additive_enabled:
             logger.info(
                 "PBRS canonical mode: additive rewards disabled with Φ(terminal)=0. PBRS invariance is preserved. "
-                f"To use additive rewards, set exit_potential_mode={ReforceXY._EXIT_POTENTIAL_MODES[1]!r}."
+                f"To use additive rewards, set exit_potential_mode={ReforceXY._EXIT_POTENTIAL_MODES[1]}."
             )
             self._entry_additive_enabled = False
             self._exit_additive_enabled = False
@@ -3175,7 +3175,7 @@ class InfoMetricsCallback(TensorboardCallback):
                 hparam_dict.update({"n_updates": int(n_updates)})
             except Exception:
                 pass
-        if "PPO" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__:  # "PPO"
             cr = getattr(self.model, "clip_range", None)
             cr_schedule, cr_iv, cr_fv = get_schedule_type(cr)
             hparam_dict.update(
@@ -3193,7 +3193,9 @@ class InfoMetricsCallback(TensorboardCallback):
             )
             if getattr(self.model, "target_kl", None) is not None:
                 hparam_dict["target_kl"] = float(self.model.target_kl)
-        if "RecurrentPPO" in self.model.__class__.__name__:
+        if (
+            ReforceXY._MODEL_TYPES[1] in self.model.__class__.__name__
+        ):  # "RecurrentPPO"
             policy = getattr(self.model, "policy", None)
             if policy is not None:
                 lstm_actor = getattr(policy, "lstm_actor", None)
@@ -3204,7 +3206,7 @@ class InfoMetricsCallback(TensorboardCallback):
                             "n_lstm_layers": int(lstm_actor.num_layers),
                         }
                     )
-        if "DQN" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__:  # "DQN"
             hparam_dict.update(
                 {
                     "buffer_size": int(self.model.buffer_size),
@@ -3224,7 +3226,7 @@ class InfoMetricsCallback(TensorboardCallback):
             )
             if train_freq is not None:
                 hparam_dict.update({"train_freq": train_freq})
-            if "QRDQN" in self.model.__class__.__name__:
+            if ReforceXY._MODEL_TYPES[4] in self.model.__class__.__name__:  # "QRDQN"
                 hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)})
         metric_dict: dict[str, float | int] = {
             "eval/mean_reward": 0.0,
@@ -3240,7 +3242,7 @@ class InfoMetricsCallback(TensorboardCallback):
             "info/trade_count": 0,
             "info/trade_duration": 0,
         }
-        if "PPO" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__:  # "PPO"
             metric_dict.update(
                 {
                     "train/approx_kl": 0.0,
@@ -3252,7 +3254,7 @@ class InfoMetricsCallback(TensorboardCallback):
                     "train/explained_variance": 0.0,
                 }
             )
-        if "DQN" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__:  # "DQN"
            metric_dict.update(
                 {
                     "train/loss": 0.0,
@@ -3491,7 +3493,7 @@ class InfoMetricsCallback(TensorboardCallback):
         except Exception:
             pass
 
-        if "PPO" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[0] in self.model.__class__.__name__:  # "PPO"
             try:
                 cr = getattr(self.model, "clip_range", None)
                 cr = _eval_schedule(cr)
@@ -3502,7 +3504,7 @@ class InfoMetricsCallback(TensorboardCallback):
             except Exception:
                 pass
 
-        if "DQN" in self.model.__class__.__name__:
+        if ReforceXY._MODEL_TYPES[3] in self.model.__class__.__name__:  # "DQN"
             try:
                 er = getattr(self.model, "exploration_rate", None)
                 if _is_finite_number(er):
@@ -3738,80 +3740,100 @@ def get_schedule_type(
     if isinstance(schedule, (int, float)):
         try:
             schedule = float(schedule)
-            return "constant", schedule, schedule
+            return ReforceXY._SCHEDULE_TYPES[1], schedule, schedule  # "constant"
         except Exception:
-            return "constant", np.nan, np.nan
+            return ReforceXY._SCHEDULE_TYPES[1], np.nan, np.nan  # "constant"
     elif isinstance(schedule, ConstantSchedule):
         try:
-            return "constant", schedule(1.0), schedule(0.0)
+            return (
+                ReforceXY._SCHEDULE_TYPES[1],
+                schedule(1.0),
+                schedule(0.0),
+            )  # "constant"
         except Exception:
-            return "constant", np.nan, np.nan
+            return ReforceXY._SCHEDULE_TYPES[1], np.nan, np.nan  # "constant"
     elif isinstance(schedule, SimpleLinearSchedule):
         try:
-            return "linear", schedule(1.0), schedule(0.0)
+            return (
+                ReforceXY._SCHEDULE_TYPES[0],
+                schedule(1.0),
+                schedule(0.0),
+            )  # "linear"
         except Exception:
-            return "linear", np.nan, np.nan
+            return ReforceXY._SCHEDULE_TYPES[0], np.nan, np.nan  # "linear"
 
-    return "unknown", np.nan, np.nan
+    return ReforceXY._SCHEDULE_TYPES[2], np.nan, np.nan  # "unknown"
 
 
 def get_schedule(
     schedule_type: Literal["linear", "constant"],
     initial_value: float,
 ) -> Callable[[float], float]:
-    if schedule_type == ReforceXY._SCHEDULE_TYPES[0]:  # "linear"
+    if schedule_type == ReforceXY._SCHEDULE_TYPES[0]:
         return SimpleLinearSchedule(initial_value)
-    elif schedule_type == ReforceXY._SCHEDULE_TYPES[1]:  # "constant"
+    elif schedule_type == ReforceXY._SCHEDULE_TYPES[1]:
         return ConstantSchedule(initial_value)
     else:
         return ConstantSchedule(initial_value)
 
 
 def get_net_arch(
-    model_type: str, net_arch_type: Literal["small", "medium", "large", "extra_large"]
+    model_type: str, net_arch_type: NetArchSize
 ) -> Union[List[int], Dict[str, List[int]]]:
     """
     Get network architecture
     """
-    if "PPO" in model_type:
+    if ReforceXY._MODEL_TYPES[0] in model_type:  # "PPO"
         return {
-            "small": {"pi": [128, 128], "vf": [128, 128]},
-            "medium": {"pi": [256, 256], "vf": [256, 256]},
-            "large": {"pi": [512, 512], "vf": [512, 512]},
-            "extra_large": {"pi": [1024, 1024], "vf": [1024, 1024]},
+            ReforceXY._NET_ARCH_SIZES[0]: {
+                "pi": [128, 128],
+                "vf": [128, 128],
+            },  # "small"
+            ReforceXY._NET_ARCH_SIZES[1]: {
+                "pi": [256, 256],
+                "vf": [256, 256],
+            },  # "medium"
+            ReforceXY._NET_ARCH_SIZES[2]: {
+                "pi": [512, 512],
+                "vf": [512, 512],
+            },  # "large"
+            ReforceXY._NET_ARCH_SIZES[3]: {
+                "pi": [1024, 1024],
+                "vf": [1024, 1024],
+            },  # "extra_large"
         }.get(net_arch_type, {"pi": [128, 128], "vf": [128, 128]})
     return {
-        "small": [128, 128],
-        "medium": [256, 256],
-        "large": [512, 512],
-        "extra_large": [1024, 1024],
+        ReforceXY._NET_ARCH_SIZES[0]: [128, 128],  # "small"
+        ReforceXY._NET_ARCH_SIZES[1]: [256, 256],  # "medium"
+        ReforceXY._NET_ARCH_SIZES[2]: [512, 512],  # "large"
+        ReforceXY._NET_ARCH_SIZES[3]: [1024, 1024],  # "extra_large"
    }.get(net_arch_type, [128, 128])
 
 
 def get_activation_fn(
-    activation_fn_name: Literal["tanh", "relu", "elu", "leaky_relu"],
+    activation_fn_name: ActivationFunction,
 ) -> Type[th.nn.Module]:
     """
     Get activation function
     """
     return {
-        "tanh": th.nn.Tanh,
-        "relu": th.nn.ReLU,
-        "elu": th.nn.ELU,
-        "leaky_relu": th.nn.LeakyReLU,
+        ReforceXY._ACTIVATION_FUNCTIONS[0]: th.nn.Tanh,  # "tanh"
+        ReforceXY._ACTIVATION_FUNCTIONS[1]: th.nn.ReLU,  # "relu"
+        ReforceXY._ACTIVATION_FUNCTIONS[2]: th.nn.ELU,  # "elu"
+        ReforceXY._ACTIVATION_FUNCTIONS[3]: th.nn.LeakyReLU,  # "leaky_relu"
     }.get(activation_fn_name, th.nn.ReLU)
 
 
 def get_optimizer_class(
-    optimizer_class_name: Literal["adam", "adamw", "rmsprop"],
+    optimizer_class_name: OptimizerClass,
 ) -> Type[th.optim.Optimizer]:
     """
     Get optimizer class
     """
     return {
-        "adam": th.optim.Adam,
-        "adamw": th.optim.AdamW,
-        "rmsprop": th.optim.RMSprop,
+        ReforceXY._OPTIMIZER_CLASSES[0]: th.optim.Adam,  # "adam"
+        ReforceXY._OPTIMIZER_CLASSES[1]: th.optim.AdamW,  # "adamw"
+        ReforceXY._OPTIMIZER_CLASSES[2]: th.optim.RMSprop,  # "rmsprop"
     }.get(optimizer_class_name, th.optim.Adam)
@@ -3823,14 +3845,12 @@ def convert_optuna_params_to_model_params(
 
     lr = optuna_params.get("learning_rate")
     if lr is None:
-        raise ValueError(
-            f"missing {'learning_rate'!r} in optuna params for {model_type}"
-        )
+        raise ValueError(f"missing learning_rate in optuna params for {model_type}")
     lr = get_schedule(
         optuna_params.get("lr_schedule", ReforceXY._SCHEDULE_TYPES[1]), float(lr)
     )  # default: "constant"
 
-    if "PPO" in model_type:
+    if ReforceXY._MODEL_TYPES[0] in model_type:  # "PPO"
         required_ppo_params = [
             "clip_range",
             "n_steps",
@@ -3867,12 +3887,12 @@ def convert_optuna_params_to_model_params(
         )
         if optuna_params.get("target_kl") is not None:
             model_params["target_kl"] = float(optuna_params.get("target_kl"))
-        if "RecurrentPPO" in model_type:
+        if ReforceXY._MODEL_TYPES[1] in model_type:  # "RecurrentPPO"
             policy_kwargs["lstm_hidden_size"] = int(
                 optuna_params.get("lstm_hidden_size")
             )
             policy_kwargs["n_lstm_layers"] = int(optuna_params.get("n_lstm_layers"))
-    elif "DQN" in model_type:
+    elif ReforceXY._MODEL_TYPES[3] in model_type:  # "DQN"
         required_dqn_params = [
             "gamma",
             "batch_size",
@@ -3915,7 +3935,10 @@ def convert_optuna_params_to_model_params(
                 "learning_starts": int(optuna_params.get("learning_starts")),
             }
         )
-        if "QRDQN" in model_type and optuna_params.get("n_quantiles") is not None:
+        if (
+            ReforceXY._MODEL_TYPES[4] in model_type
+            and optuna_params.get("n_quantiles") is not None
+        ):  # "QRDQN"
             policy_kwargs["n_quantiles"] = int(optuna_params["n_quantiles"])
     else:
         raise ValueError(f"Model {model_type} not supported")
@@ -4038,7 +4061,9 @@ def get_common_dqn_optuna_params(trial: Trial) -> Dict[str, Any]:
             "batch_size", [64, 128, 256, 512, 1024]
         ),
         "learning_rate": trial.suggest_float("learning_rate", 1e-5, 3e-3, log=True),
-        "lr_schedule": trial.suggest_categorical("lr_schedule", ["linear", "constant"]),
+        "lr_schedule": trial.suggest_categorical(
+            "lr_schedule", list(ReforceXY._SCHEDULE_TYPES[:2])
+        ),  # ["linear", "constant"]
         "buffer_size": trial.suggest_categorical(
             "buffer_size", [int(1e4), int(5e4), int(1e5), int(2e5)]
         ),
-- 
2.43.0