Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): code cleanups
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 11 Oct 2025 16:26:04 +0000 (18:26 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 11 Oct 2025 16:26:04 +0000 (18:26 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/reward_space_analysis/README.md
ReforceXY/reward_space_analysis/reward_space_analysis.py
ReforceXY/reward_space_analysis/test_reward_space_analysis.py
ReforceXY/user_data/freqaimodels/ReforceXY.py

diff --git a/ReforceXY/reward_space_analysis/README.md b/ReforceXY/reward_space_analysis/README.md
index 63cf10b9923beee9eb61a437ad4d58e3361c6896..db7a2d69b78064b8322ba9884e0253f382ac0bae 100644
@@ -190,7 +190,7 @@ None - all parameters have sensible defaults.
 
 - Maximum trade duration in candles (from environment config)
 - Should match your actual trading environment setting
-- Drives idle grace: when `max_idle_duration_candles` ≤ 0 the fallback = `2 * max_trade_duration`
+- Drives idle grace: when `max_idle_duration_candles` is unset, the fallback = `2 * max_trade_duration`
 
 ### Reward Configuration
 
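The README change above tightens the wording around the idle-grace fallback. A minimal sketch of that fallback, assuming a `None`-means-unset convention (the default for `max_idle_duration_candles` in `reward_space_analysis.py` is `None`); the helper name is hypothetical, not code from this commit:

```python
from typing import Optional


def effective_max_idle_duration(
    max_idle_duration_candles: Optional[int], max_trade_duration: int
) -> int:
    """Return the idle grace window, falling back to 2 * max_trade_duration
    when max_idle_duration_candles is left unset."""
    if max_idle_duration_candles is None:
        return 2 * max_trade_duration
    return int(max_idle_duration_candles)
```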
diff --git a/ReforceXY/reward_space_analysis/reward_space_analysis.py b/ReforceXY/reward_space_analysis/reward_space_analysis.py
index b87c5c3cef8ec476c3040772550d073277c0f49b..80fb259ef7732d83ae3a6e793536484fd370587d 100644
@@ -31,7 +31,7 @@ import random
 import warnings
 from enum import Enum, IntEnum
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Mapping
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -74,7 +74,7 @@ def _to_bool(value: Any) -> bool:
 
 
 def _get_param_float(
-    params: Mapping[str, RewardParamValue], key: str, default: RewardParamValue
+    params: RewardParams, key: str, default: RewardParamValue
 ) -> float:
     """Extract float parameter with type safety and default fallback."""
     value = params.get(key, default)
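The hunk above only narrows the annotation from `Mapping[str, RewardParamValue]` to the module's `RewardParams` type; the body is unchanged. For context, a self-contained sketch of how such a typed getter typically completes (the tail of the function lies outside the hunk, so this continuation is an assumption):

```python
from typing import Any, Dict

# Stand-in for the RewardParams type used in reward_space_analysis.py.
RewardParams = Dict[str, Any]


def _get_param_float(params: RewardParams, key: str, default: float) -> float:
    """Extract float parameter with type safety and default fallback."""
    value = params.get(key, default)
    try:
        return float(value)
    except (TypeError, ValueError):
        return float(default)
```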
@@ -135,7 +135,7 @@ DEFAULT_MODEL_REWARD_PARAMETERS: RewardParams = {
     # Idle penalty (env defaults)
     "idle_penalty_scale": 0.5,
     "idle_penalty_power": 1.025,
-    # Fallback semantics: 2 * max_trade_duration_candles
+    # Fallback: 2 * max_trade_duration_candles
     "max_idle_duration_candles": None,
     # Holding keys (env defaults)
     "holding_penalty_scale": 0.25,
@@ -299,7 +299,7 @@ def add_tunable_cli_args(parser: argparse.ArgumentParser) -> None:
         if key == "exit_attenuation_mode":
             parser.add_argument(
                 f"--{key}",
-                type=str,  # case preserved; validation + silent fallback occurs before factor computation
+                type=str,
                 choices=sorted(ALLOWED_EXIT_MODES),
                 default=None,
                 help=help_text,
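For reference, a minimal standalone sketch of the CLI wiring shown above; the mode names in `ALLOWED_EXIT_MODES` are placeholders for illustration, not the repository's actual set:

```python
import argparse

# Placeholder values; the real set is defined in reward_space_analysis.py.
ALLOWED_EXIT_MODES = {"linear", "power", "half_life"}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--exit_attenuation_mode",
    type=str,
    choices=sorted(ALLOWED_EXIT_MODES),
    default=None,  # None signals "not overridden on the CLI"
    help="Exit attenuation mode",
)
args = parser.parse_args(["--exit_attenuation_mode", "linear"])
assert args.exit_attenuation_mode == "linear"
```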
@@ -755,7 +755,7 @@ def simulate_samples(
     rng = random.Random(seed)
     short_allowed = _is_short_allowed(trading_mode)
     action_masking = _to_bool(params.get("action_masking", True))
-    samples: list[dict[str, float]] = []
+    samples: list[Dict[str, float]] = []
     for _ in range(num_samples):
         if short_allowed:
             position_choices = [Positions.Neutral, Positions.Long, Positions.Short]
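A standalone sketch of the position sampling seen in `simulate_samples`, with the `Positions` enum stubbed using the same float values the tests use (`0.0`, `0.5`, `1.0`); the mapping in the stub is an assumption for illustration:

```python
import random
from enum import Enum


class Positions(Enum):
    # Stub mirroring the float-valued enum used by the environment
    # (assumed mapping, consistent with the 0.0 / 0.5 / 1.0 values in the tests).
    Short = 0.0
    Neutral = 0.5
    Long = 1.0


rng = random.Random(42)
short_allowed = True
position_choices = (
    [Positions.Neutral, Positions.Long, Positions.Short]
    if short_allowed
    else [Positions.Neutral, Positions.Long]
)
position = rng.choice(position_choices)
print(position, float(position.value))
```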
diff --git a/ReforceXY/reward_space_analysis/test_reward_space_analysis.py b/ReforceXY/reward_space_analysis/test_reward_space_analysis.py
index 0e608db923df9108e3282b407e58c43e74a26e42..7f8d890a2af38dfafb55196083d946a54b901c0d 100644
@@ -887,7 +887,7 @@ class TestRewardAlignment(RewardSpaceTestBase):
                 position=Positions.Neutral,
                 action=Actions.Neutral,
             ),
-            # Holding penalty (maintained position)
+            # Holding penalty
             RewardContext(
                 pnl=0.0,
                 trade_duration=80,
@@ -1097,7 +1097,6 @@ class TestPublicAPI(RewardSpaceTestBase):
                 "trade_duration": np.random.uniform(5, 150, 300),
                 "idle_duration": idle_duration,
                 "position": np.random.choice([0.0, 0.5, 1.0], 300),
-                "is_force_exit": np.random.choice([0.0, 1.0], 300, p=[0.85, 0.15]),
             }
         )
 
@@ -1397,7 +1396,6 @@ class TestStatisticalValidation(RewardSpaceTestBase):
                 "trade_duration": np.random.uniform(5, 150, 300),
                 "idle_duration": np.random.uniform(0, 100, 300),
                 "position": np.random.choice([0.0, 0.5, 1.0], 300),
-                "is_force_exit": np.random.choice([0.0, 1.0], 300, p=[0.8, 0.2]),
             }
         )
 
@@ -1802,7 +1800,6 @@ class TestHelperFunctions(RewardSpaceTestBase):
                     [np.random.uniform(5, 50, 50), np.zeros(150)]
                 ),
                 "position": np.random.choice([0.0, 0.5, 1.0], 200),
-                "is_force_exit": np.random.choice([0.0, 1.0], 200, p=[0.8, 0.2]),
             }
         )
 
diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py
index a41a23789ad7e9af5be678e79a07bbfdb872738e..fb855da9d0199b8f223e43da466bc59ef1f844ff 100644
@@ -117,7 +117,7 @@ class ReforceXY(BaseReinforcementLearningModel):
     """
 
     _LOG_2 = math.log(2.0)
-    _action_masks_cache: Dict[Tuple[bool, int], NDArray[np.bool_]] = {}
+    _action_masks_cache: Dict[Tuple[bool, float], NDArray[np.bool_]] = {}
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -191,7 +191,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         cache_key = (
             can_short,
-            position.value,
+            float(position.value),
         )
         if cache_key in ReforceXY._action_masks_cache:
             return ReforceXY._action_masks_cache[cache_key]
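The annotation change above follows from the cache key now storing `float(position.value)`, presumably because one of the position values is a non-integer float, as the 0.0/0.5/1.0 values in the tests suggest. A self-contained sketch of the cache pattern, with the enum stubbed and the mask contents a placeholder:

```python
from enum import Enum
from typing import Dict, Tuple

import numpy as np
from numpy.typing import NDArray


class Positions(Enum):
    # Assumed float-valued stub; only the key type matters for this sketch.
    Short = 0.0
    Neutral = 0.5
    Long = 1.0


_action_masks_cache: Dict[Tuple[bool, float], NDArray[np.bool_]] = {}


def action_masks(can_short: bool, position: Positions) -> NDArray[np.bool_]:
    cache_key = (can_short, float(position.value))
    if cache_key in _action_masks_cache:
        return _action_masks_cache[cache_key]
    # Placeholder: the real masks depend on the 5-action space and position.
    masks = np.ones(5, dtype=np.bool_)
    _action_masks_cache[cache_key] = masks
    return masks
```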
@@ -1726,7 +1726,7 @@ class MyRLEnv(Base5ActionRLEnv):
         delta_pnl = pnl - pre_pnl
         info = {
             "tick": self._current_tick,
-            "position": self._position.value,
+            "position": float(self._position.value),
             "action": action,
             "pre_pnl": round(pre_pnl, 5),
             "pnl": round(pnl, 5),
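The same cast appears in the step `info` dict: storing `float(self._position.value)` keeps the entry a plain number rather than an `Enum` member, which is friendlier to logging and serialisation. A tiny illustration with a stubbed enum (assumed value):

```python
from enum import Enum


class Positions(Enum):
    Neutral = 0.5  # assumed stub value


info = {"position": float(Positions.Neutral.value)}
print(info)  # {'position': 0.5} -- a plain float, not an Enum member
```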