From: Jérôme Benoit
Date: Sat, 11 Oct 2025 16:26:04 +0000 (+0200)
Subject: refactor(reforcexy): code cleanups
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=56f1a4322f66680260793780779255807a95c5d8;p=freqai-strategies.git

refactor(reforcexy): code cleanups

Signed-off-by: Jérôme Benoit
---
diff --git a/ReforceXY/reward_space_analysis/README.md b/ReforceXY/reward_space_analysis/README.md
index 63cf10b..db7a2d6 100644
--- a/ReforceXY/reward_space_analysis/README.md
+++ b/ReforceXY/reward_space_analysis/README.md
@@ -190,7 +190,7 @@ None - all parameters have sensible defaults.
 
 - Maximum trade duration in candles (from environment config)
 - Should match your actual trading environment setting
-Drives idle grace: when `max_idle_duration_candles` ≤ 0 the fallback = `2 * max_trade_duration`
+Drives idle grace: when `max_idle_duration_candles` is unset, the fallback is `2 * max_trade_duration`
 
 ### Reward Configuration
 
diff --git a/ReforceXY/reward_space_analysis/reward_space_analysis.py b/ReforceXY/reward_space_analysis/reward_space_analysis.py
index b87c5c3..80fb259 100644
--- a/ReforceXY/reward_space_analysis/reward_space_analysis.py
+++ b/ReforceXY/reward_space_analysis/reward_space_analysis.py
@@ -31,7 +31,7 @@ import random
 import warnings
 from enum import Enum, IntEnum
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Mapping
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -74,7 +74,7 @@ def _to_bool(value: Any) -> bool:
 
 
 def _get_param_float(
-    params: Mapping[str, RewardParamValue], key: str, default: RewardParamValue
+    params: RewardParams, key: str, default: RewardParamValue
 ) -> float:
     """Extract float parameter with type safety and default fallback."""
     value = params.get(key, default)
@@ -135,7 +135,7 @@ DEFAULT_MODEL_REWARD_PARAMETERS: RewardParams = {
     # Idle penalty (env defaults)
     "idle_penalty_scale": 0.5,
     "idle_penalty_power": 1.025,
-    # Fallback semantics: 2 * max_trade_duration_candles
+    # Fallback: 2 * max_trade_duration_candles
     "max_idle_duration_candles": None,
     # Holding keys (env defaults)
     "holding_penalty_scale": 0.25,
@@ -299,7 +299,7 @@ def add_tunable_cli_args(parser: argparse.ArgumentParser) -> None:
         if key == "exit_attenuation_mode":
             parser.add_argument(
                 f"--{key}",
-                type=str,  # case preserved; validation + silent fallback occurs before factor computation
+                type=str,
                 choices=sorted(ALLOWED_EXIT_MODES),
                 default=None,
                 help=help_text,
@@ -755,7 +755,7 @@ def simulate_samples(
     rng = random.Random(seed)
     short_allowed = _is_short_allowed(trading_mode)
     action_masking = _to_bool(params.get("action_masking", True))
-    samples: list[dict[str, float]] = []
+    samples: list[Dict[str, float]] = []
     for _ in range(num_samples):
         if short_allowed:
             position_choices = [Positions.Neutral, Positions.Long, Positions.Short]
diff --git a/ReforceXY/reward_space_analysis/test_reward_space_analysis.py b/ReforceXY/reward_space_analysis/test_reward_space_analysis.py
index 0e608db..7f8d890 100644
--- a/ReforceXY/reward_space_analysis/test_reward_space_analysis.py
+++ b/ReforceXY/reward_space_analysis/test_reward_space_analysis.py
@@ -887,7 +887,7 @@ class TestRewardAlignment(RewardSpaceTestBase):
                 position=Positions.Neutral,
                 action=Actions.Neutral,
             ),
-            # Holding penalty (maintained position)
+            # Holding penalty
             RewardContext(
                 pnl=0.0,
                 trade_duration=80,
@@ -1097,7 +1097,6 @@ class TestPublicAPI(RewardSpaceTestBase):
                 "trade_duration": np.random.uniform(5, 150, 300),
                 "idle_duration": idle_duration,
                 "position": np.random.choice([0.0, 0.5, 1.0], 300),
-                "is_force_exit": np.random.choice([0.0, 1.0], 300, p=[0.85, 0.15]),
             }
         )
 
@@ -1397,7 +1396,6 @@ class TestStatisticalValidation(RewardSpaceTestBase):
                 "trade_duration": np.random.uniform(5, 150, 300),
                 "idle_duration": np.random.uniform(0, 100, 300),
                 "position": np.random.choice([0.0, 0.5, 1.0], 300),
-                "is_force_exit": np.random.choice([0.0, 1.0], 300, p=[0.8, 0.2]),
             }
         )
 
@@ -1802,7 +1800,6 @@ class TestHelperFunctions(RewardSpaceTestBase):
                     [np.random.uniform(5, 50, 50), np.zeros(150)]
                 ),
                 "position": np.random.choice([0.0, 0.5, 1.0], 200),
-                "is_force_exit": np.random.choice([0.0, 1.0], 200, p=[0.8, 0.2]),
             }
         )
 
diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index a41a237..fb855da 100644
--- a/ReforceXY/user_data/freqaimodels/ReforceXY.py
+++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py
@@ -117,7 +117,7 @@ class ReforceXY(BaseReinforcementLearningModel):
     """
 
     _LOG_2 = math.log(2.0)
-    _action_masks_cache: Dict[Tuple[bool, int], NDArray[np.bool_]] = {}
+    _action_masks_cache: Dict[Tuple[bool, float], NDArray[np.bool_]] = {}
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -191,7 +191,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         cache_key = (
             can_short,
-            position.value,
+            float(position.value),
         )
         if cache_key in ReforceXY._action_masks_cache:
             return ReforceXY._action_masks_cache[cache_key]
@@ -1726,7 +1726,7 @@ class MyRLEnv(Base5ActionRLEnv):
         delta_pnl = pnl - pre_pnl
         info = {
             "tick": self._current_tick,
-            "position": self._position.value,
+            "position": float(self._position.value),
             "action": action,
             "pre_pnl": round(pre_pnl, 5),
             "pnl": round(pnl, 5),
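
Note on the idle-grace fallback touched by the README and DEFAULT_MODEL_REWARD_PARAMETERS hunks above: the sketch below shows how the resolution could look. It is an assumption-laden illustration, not code from this repository — the helper name `resolve_max_idle_duration`, the simplified `RewardParams` alias, the treatment of non-positive values like the unset case, and the `max_trade_duration=128` figure are all hypothetical; only the parameter name and the `2 * max_trade_duration` fallback come from this patch.

```python
# Illustrative sketch only (not repository code): resolve the idle grace
# window, falling back to 2 * max_trade_duration when
# "max_idle_duration_candles" is unset (None) or non-positive (assumed).
from typing import Any, Dict

RewardParams = Dict[str, Any]  # simplified stand-in for the project's alias


def resolve_max_idle_duration(params: RewardParams, max_trade_duration: int) -> int:
    raw = params.get("max_idle_duration_candles")
    try:
        value = int(raw) if raw is not None else 0
    except (TypeError, ValueError):
        value = 0
    # Documented fallback: 2 * max_trade_duration_candles
    return value if value > 0 else 2 * max_trade_duration


# Example: the default of None triggers the fallback (128 -> 256 candles),
# while an explicit positive setting is used as-is.
assert resolve_max_idle_duration({"max_idle_duration_candles": None}, 128) == 256
assert resolve_max_idle_duration({"max_idle_duration_candles": 96}, 128) == 96
```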