From c76fba6789675a72d820f6b137116abc40a93f0e Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sat, 22 Nov 2025 23:48:27 +0100 Subject: [PATCH] refactor(reforcexy): use proper types MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- ReforceXY/user_data/freqaimodels/ReforceXY.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index 7cb833a..8360b11 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -150,26 +150,26 @@ class ReforceXY(BaseReinforcementLearningModel): _LOG_2: Final[float] = math.log(2.0) DEFAULT_IDLE_DURATION_MULTIPLIER: Final[int] = 4 - _MODEL_TYPES: Final[tuple[ModelType, ...]] = ( + _MODEL_TYPES: Final[Tuple[ModelType, ...]] = ( "PPO", "RecurrentPPO", "MaskablePPO", "DQN", "QRDQN", ) - _SCHEDULE_TYPES_KNOWN: Final[tuple[ScheduleTypeKnown, ...]] = ("linear", "constant") - _SCHEDULE_TYPES: Final[tuple[ScheduleType, ...]] = ( + _SCHEDULE_TYPES_KNOWN: Final[Tuple[ScheduleTypeKnown, ...]] = ("linear", "constant") + _SCHEDULE_TYPES: Final[Tuple[ScheduleType, ...]] = ( *_SCHEDULE_TYPES_KNOWN, "unknown", ) - _EXIT_POTENTIAL_MODES: Final[tuple[ExitPotentialMode, ...]] = ( + _EXIT_POTENTIAL_MODES: Final[Tuple[ExitPotentialMode, ...]] = ( "canonical", "non_canonical", "progressive_release", "spike_cancel", "retain_previous", ) - _TRANSFORM_FUNCTIONS: Final[tuple[TransformFunction, ...]] = ( + _TRANSFORM_FUNCTIONS: Final[Tuple[TransformFunction, ...]] = ( "tanh", "softsign", "arctan", @@ -177,36 +177,36 @@ class ReforceXY(BaseReinforcementLearningModel): "asinh", "clip", ) - _EXIT_ATTENUATION_MODES: Final[tuple[ExitAttenuationMode, ...]] = ( + _EXIT_ATTENUATION_MODES: Final[Tuple[ExitAttenuationMode, ...]] = ( "legacy", "sqrt", "linear", "power", "half_life", ) - _ACTIVATION_FUNCTIONS: Final[tuple[ActivationFunction, ...]] = ( + _ACTIVATION_FUNCTIONS: Final[Tuple[ActivationFunction, ...]] = ( "relu", "tanh", "elu", "leaky_relu", ) - _OPTIMIZER_CLASSES_OPTUNA: Final[tuple[OptimizerClassOptuna, ...]] = ( + _OPTIMIZER_CLASSES_OPTUNA: Final[Tuple[OptimizerClassOptuna, ...]] = ( "adamw", "rmsprop", ) - _OPTIMIZER_CLASSES: Final[tuple[OptimizerClass, ...]] = ( + _OPTIMIZER_CLASSES: Final[Tuple[OptimizerClass, ...]] = ( *_OPTIMIZER_CLASSES_OPTUNA, "adam", ) - _NET_ARCH_SIZES: Final[tuple[NetArchSize, ...]] = ( + _NET_ARCH_SIZES: Final[Tuple[NetArchSize, ...]] = ( "small", "medium", "large", "extra_large", ) - _STORAGE_BACKENDS: Final[tuple[StorageBackend, ...]] = ("sqlite", "file") - _SAMPLER_TYPES: Final[tuple[SamplerType, ...]] = ("tpe", "auto") - _PPO_N_STEPS: Final[tuple[int, ...]] = (512, 1024, 2048, 4096) + _STORAGE_BACKENDS: Final[Tuple[StorageBackend, ...]] = ("sqlite", "file") + _SAMPLER_TYPES: Final[Tuple[SamplerType, ...]] = ("tpe", "auto") + _PPO_N_STEPS: Final[Tuple[int, ...]] = (512, 1024, 2048, 4096) _action_masks_cache: ClassVar[Dict[Tuple[bool, float], NDArray[np.bool_]]] = {} @@ -3249,7 +3249,7 @@ class InfoMetricsCallback(TensorboardCallback): hparam_dict.update({"train_freq": train_freq}) if ReforceXY._MODEL_TYPES[4] in self.model.__class__.__name__: # "QRDQN" hparam_dict.update({"n_quantiles": int(self.model.n_quantiles)}) - metric_dict: dict[str, float | int] = { + metric_dict: Dict[str, float | int] = { "eval/mean_reward": 0.0, "eval/mean_reward_std": 0.0, "rollout/ep_rew_mean": 0.0, -- 2.43.0