Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(ReforceXY): apply PBRS correction at terminal step
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 22 Dec 2025 17:29:33 +0000 (18:29 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Mon, 22 Dec 2025 17:29:33 +0000 (18:29 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/reward_space_analysis/README.md
ReforceXY/reward_space_analysis/reward_space_analysis.py
ReforceXY/reward_space_analysis/tests/components/test_transforms.py
ReforceXY/reward_space_analysis/tests/helpers/test_internal_branches.py
ReforceXY/reward_space_analysis/tests/pbrs/test_pbrs.py
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 06230e26e4da782f8c94e9cb6c434a29e12943b1..e9606d3d6a2864ec8045f76785f762e44b7f22fe 100644 (file)
@@ -314,14 +314,14 @@ where `kernel_function` depends on `exit_attenuation_mode`. See [Exit Attenuatio
 
 #### PBRS (Potential-Based Reward Shaping)
 
-| Parameter                | Default   | Description                       |
-| ------------------------ | --------- | --------------------------------- |
-| `potential_gamma`        | 0.95      | Discount factor γ for potential Φ |
-| `exit_potential_mode`    | canonical | Potential release mode            |
-| `exit_potential_decay`   | 0.5       | Decay for progressive_release     |
-| `hold_potential_enabled` | true      | Enable hold potential Φ           |
-| `entry_fee_rate`         | 0.0       | Entry fee rate                    |
-| `exit_fee_rate`          | 0.0       | Exit fee rate                     |
+| Parameter                | Default   | Description                          |
+| ------------------------ | --------- | ------------------------------------ |
+| `potential_gamma`        | 0.95      | Discount factor γ for potential Φ    |
+| `exit_potential_mode`    | canonical | Potential release mode               |
+| `exit_potential_decay`   | 0.5       | Decay for progressive_release        |
+| `hold_potential_enabled` | true      | Enable hold potential Φ              |
+| `entry_fee_rate`         | 0.0       | Entry fee rate (`price * (1 + fee)`) |
+| `exit_fee_rate`          | 0.0       | Exit fee rate (`price / (1 + fee)`)  |
 
 PBRS invariance holds when: `exit_potential_mode=canonical`.
 
index ee9720fdfb6adc80e8540443eb1716362bcd1cef..cdeb3c9c44108ae3943b0964f6908addaadebfc9 100644 (file)
@@ -368,6 +368,53 @@ def _get_float_param(params: RewardParams, key: str, default: RewardParamValue)
     return np.nan
 
 
+def _clamp_float_to_bounds(
+    key: str,
+    value: float,
+    *,
+    bounds: Optional[Dict[str, float]] = None,
+    strict: bool,
+) -> tuple[float, list[str]]:
+    """Clamp numeric `value` to bounds for `key`.
+
+    Returns:
+        (adjusted_value, reason_parts)
+
+    Notes:
+        - Uses `_PARAMETER_BOUNDS` when `bounds` is None.
+        - In strict mode, raises on out-of-bounds or non-finite.
+        - In relaxed mode, clamps and emits reason tags.
+    """
+
+    effective_bounds = bounds if bounds is not None else _PARAMETER_BOUNDS.get(key, {})
+    adjusted = float(value)
+    reason_parts: list[str] = []
+
+    if "min" in effective_bounds and adjusted < float(effective_bounds["min"]):
+        if strict:
+            raise ValueError(
+                f"Parameter '{key}'={adjusted} below min {float(effective_bounds['min'])}"
+            )
+        adjusted = float(effective_bounds["min"])
+        reason_parts.append(f"min={float(effective_bounds['min'])}")
+
+    if "max" in effective_bounds and adjusted > float(effective_bounds["max"]):
+        if strict:
+            raise ValueError(
+                f"Parameter '{key}'={adjusted} above max {float(effective_bounds['max'])}"
+            )
+        adjusted = float(effective_bounds["max"])
+        reason_parts.append(f"max={float(effective_bounds['max'])}")
+
+    if not np.isfinite(adjusted):
+        if strict:
+            raise ValueError(f"Parameter '{key}' is non-finite: {adjusted}")
+        adjusted = float(effective_bounds.get("min", 0.0))
+        reason_parts.append("non_finite_reset")
+
+    return adjusted, reason_parts
+
+
 def _get_int_param(params: RewardParams, key: str, default: RewardParamValue) -> int:
     """Extract integer parameter with robust coercion.
 
@@ -538,8 +585,6 @@ def validate_reward_parameters(
             continue
 
         original_numeric = float(coerced)
-        adjusted = original_numeric
-        reason_parts: List[str] = []
 
         # Track type coercion
         if not isinstance(original_val, (int, float)):
@@ -554,23 +599,12 @@ def validate_reward_parameters(
             )
             sanitized[key] = original_numeric
 
-        # Bounds enforcement
-        if "min" in bounds and adjusted < bounds["min"]:
-            if strict:
-                raise ValueError(f"Parameter '{key}'={adjusted} below min {bounds['min']}")
-            adjusted = bounds["min"]
-            reason_parts.append(f"min={bounds['min']}")
-        if "max" in bounds and adjusted > bounds["max"]:
-            if strict:
-                raise ValueError(f"Parameter '{key}'={adjusted} above max {bounds['max']}")
-            adjusted = bounds["max"]
-            reason_parts.append(f"max={bounds['max']}")
-
-        if not np.isfinite(adjusted):
-            if strict:
-                raise ValueError(f"Parameter '{key}' is non-finite: {adjusted}")
-            adjusted = bounds.get("min", 0.0)
-            reason_parts.append("non_finite_reset")
+        adjusted, reason_parts = _clamp_float_to_bounds(
+            key,
+            original_numeric,
+            bounds=bounds,
+            strict=strict,
+        )
 
         if not np.isclose(adjusted, original_numeric):
             sanitized[key] = adjusted
@@ -1019,6 +1053,25 @@ def _is_valid_action(
     return False
 
 
+def _get_next_position(
+    position: Positions,
+    action: Actions,
+    *,
+    short_allowed: bool,
+) -> Positions:
+    """Compute the next position given current position and action."""
+
+    if action == Actions.Long_enter and position == Positions.Neutral:
+        return Positions.Long
+    if action == Actions.Short_enter and position == Positions.Neutral and short_allowed:
+        return Positions.Short
+    if action == Actions.Long_exit and position == Positions.Long:
+        return Positions.Neutral
+    if action == Actions.Short_exit and position == Positions.Short:
+        return Positions.Neutral
+    return position
+
+
 def _idle_penalty(context: RewardContext, idle_factor: float, params: RewardParams) -> float:
     """Compute idle penalty."""
     idle_penalty_scale = _get_float_param(
@@ -1106,14 +1159,15 @@ def calculate_reward(
         context.action,
         short_allowed=short_allowed,
     )
+
+    base_reward: Optional[float] = None
     if not is_valid and not action_masking:
         breakdown.invalid_penalty = _get_float_param(
             params,
             "invalid_action",
             DEFAULT_MODEL_REWARD_PARAMETERS.get("invalid_action", -2.0),
         )
-        breakdown.total = breakdown.invalid_penalty
-        return breakdown
+        base_reward = breakdown.invalid_penalty
 
     factor = _get_float_param(params, "base_factor", base_factor)
 
@@ -1138,53 +1192,71 @@ def calculate_reward(
     )
 
     # Base reward calculation
-    base_reward = 0.0
-
-    if context.action == Actions.Neutral and context.position == Positions.Neutral:
-        base_reward = _idle_penalty(context, idle_factor, params)
-        breakdown.idle_penalty = base_reward
-    elif (
-        context.position in (Positions.Long, Positions.Short) and context.action == Actions.Neutral
-    ):
-        base_reward = _hold_penalty(context, hold_factor, params)
-        breakdown.hold_penalty = base_reward
-    elif context.action == Actions.Long_exit and context.position == Positions.Long:
-        base_reward = _compute_exit_reward(
-            factor, pnl_target, current_duration_ratio, context, params, risk_reward_ratio
-        )
-        breakdown.exit_component = base_reward
-    elif context.action == Actions.Short_exit and context.position == Positions.Short:
-        base_reward = _compute_exit_reward(
-            factor, pnl_target, current_duration_ratio, context, params, risk_reward_ratio
-        )
-        breakdown.exit_component = base_reward
-    else:
-        base_reward = 0.0
+    if base_reward is None:
+        if context.action == Actions.Neutral and context.position == Positions.Neutral:
+            base_reward = _idle_penalty(context, idle_factor, params)
+            breakdown.idle_penalty = base_reward
+        elif (
+            context.position in (Positions.Long, Positions.Short)
+            and context.action == Actions.Neutral
+        ):
+            base_reward = _hold_penalty(context, hold_factor, params)
+            breakdown.hold_penalty = base_reward
+        elif context.action == Actions.Long_exit and context.position == Positions.Long:
+            base_reward = _compute_exit_reward(
+                factor, pnl_target, current_duration_ratio, context, params, risk_reward_ratio
+            )
+            breakdown.exit_component = base_reward
+        elif context.action == Actions.Short_exit and context.position == Positions.Short:
+            base_reward = _compute_exit_reward(
+                factor, pnl_target, current_duration_ratio, context, params, risk_reward_ratio
+            )
+            breakdown.exit_component = base_reward
+        else:
+            base_reward = 0.0
 
-    breakdown.base_reward = base_reward
+    breakdown.base_reward = float(base_reward)
 
     # === PBRS INTEGRATION ===
     current_pnl = context.pnl if context.position != Positions.Neutral else 0.0
 
-    is_entry = context.position == Positions.Neutral and context.action in (
-        Actions.Long_enter,
-        Actions.Short_enter,
+    next_position = _get_next_position(
+        context.position, context.action, short_allowed=short_allowed
     )
-    is_exit = context.position in (
+    is_entry = context.position == Positions.Neutral and next_position in (
         Positions.Long,
         Positions.Short,
-    ) and context.action in (Actions.Long_exit, Actions.Short_exit)
-    is_hold = (
-        context.position in (Positions.Long, Positions.Short) and context.action == Actions.Neutral
     )
-    is_neutral = context.position == Positions.Neutral and context.action == Actions.Neutral
+    is_exit = (
+        context.position
+        in (
+            Positions.Long,
+            Positions.Short,
+        )
+        and next_position == Positions.Neutral
+    )
+    is_hold = context.position in (
+        Positions.Long,
+        Positions.Short,
+    ) and next_position in (Positions.Long, Positions.Short)
+    is_neutral = context.position == Positions.Neutral and next_position == Positions.Neutral
 
     if is_entry:
         next_duration_ratio = 0.0
         if context.action == Actions.Long_enter:
-            next_pnl = _compute_entry_unrealized_pnl_estimate(Positions.Long, params)
+            next_pnl = _compute_unrealized_pnl_estimate(
+                Positions.Long,
+                entry_open=1.0,
+                current_open=1.0,
+                params=params,
+            )
         elif context.action == Actions.Short_enter:
-            next_pnl = _compute_entry_unrealized_pnl_estimate(Positions.Short, params)
+            next_pnl = _compute_unrealized_pnl_estimate(
+                Positions.Short,
+                entry_open=1.0,
+                current_open=1.0,
+                params=params,
+            )
         else:
             next_pnl = current_pnl
     elif is_hold:
@@ -1280,6 +1352,9 @@ def calculate_reward(
         breakdown.total = total_reward
         return breakdown
 
+    prev_potential_safe = float(prev_potential) if np.isfinite(prev_potential) else 0.0
+    breakdown.prev_potential = prev_potential_safe
+    breakdown.next_potential = prev_potential_safe
     breakdown.total = base_reward
 
     return breakdown
@@ -1393,21 +1468,42 @@ def simulate_samples(
     max_unrealized_profit = 0.0
     min_unrealized_profit = 0.0
 
+    # Synthetic market state
+    current_open = 1.0
+    entry_open = current_open
+
     for _ in range(num_samples):
-        # Simulate market movement while in position (PnL as a state variable)
-        if position in (Positions.Long, Positions.Short):
-            duration_ratio = _compute_duration_ratio(trade_duration, max_trade_duration_candles)
-            pnl_std = pnl_base_std * (1.0 + pnl_duration_vol_scale * duration_ratio)
-            step_delta = rng.gauss(0.0, pnl_std)
-
-            # Small directional drift so signals aren't perfectly symmetric.
-            drift = 0.001 * duration_ratio
-            if position == Positions.Long:
-                step_delta += drift
-            else:
-                step_delta -= drift
+        # Simulate synthetic open-price movement.
+        duration_ratio = (
+            _compute_duration_ratio(trade_duration, max_trade_duration_candles)
+            if position in (Positions.Long, Positions.Short)
+            else 0.0
+        )
+        open_return_std = pnl_base_std * (1.0 + pnl_duration_vol_scale * duration_ratio)
+        step_return = rng.gauss(0.0, open_return_std)
+
+        # Small directional drift so long/short trajectories are not perfectly symmetric
+        drift = 0.001 * duration_ratio
+        if position == Positions.Long:
+            step_return += drift
+        elif position == Positions.Short:
+            step_return -= drift
+
+        if not np.isfinite(step_return):
+            step_return = 0.0
+        step_return = float(np.clip(step_return, -0.95, 0.95))
 
-            pnl = min(max(-0.15, pnl + step_delta), 0.15)
+        current_open = float(max(1e-6, current_open * (1.0 + step_return)))
+
+        # Compute fee-aware unrealized PnL from (entry_open, current_open)
+        if position in (Positions.Long, Positions.Short):
+            pnl = _compute_unrealized_pnl_estimate(
+                position,
+                entry_open=entry_open,
+                current_open=current_open,
+                params=params,
+            )
+            pnl = float(np.clip(pnl, -0.15, 0.15))
             max_unrealized_profit = max(max_unrealized_profit, pnl)
             min_unrealized_profit = min(min_unrealized_profit, pnl)
         else:
@@ -1479,14 +1575,26 @@ def simulate_samples(
                 position = Positions.Long
                 trade_duration = 0
                 idle_duration = 0
-                pnl = _compute_entry_unrealized_pnl_estimate(Positions.Long, params)
+                entry_open = current_open
+                pnl = _compute_unrealized_pnl_estimate(
+                    Positions.Long,
+                    entry_open=entry_open,
+                    current_open=current_open,
+                    params=params,
+                )
                 max_unrealized_profit = pnl
                 min_unrealized_profit = pnl
             elif action == Actions.Short_enter and short_allowed:
                 position = Positions.Short
                 trade_duration = 0
                 idle_duration = 0
-                pnl = _compute_entry_unrealized_pnl_estimate(Positions.Short, params)
+                entry_open = current_open
+                pnl = _compute_unrealized_pnl_estimate(
+                    Positions.Short,
+                    entry_open=entry_open,
+                    current_open=current_open,
+                    params=params,
+                )
                 max_unrealized_profit = pnl
                 min_unrealized_profit = pnl
         else:
@@ -1497,6 +1605,7 @@ def simulate_samples(
                 position = Positions.Neutral
                 trade_duration = 0
                 idle_duration = 0
+                entry_open = current_open
 
     df = pd.DataFrame(samples)
     df.attrs["reward_params"] = dict(params)
@@ -2775,79 +2884,119 @@ def _get_potential_gamma(params: RewardParams) -> float:
             stacklevel=2,
         )
         return POTENTIAL_GAMMA_DEFAULT
-    if gamma < 0.0 or gamma > 1.0:
-        original = gamma
-        gamma = float(np.clip(gamma, 0.0, 1.0))
+
+    raw_gamma = float(gamma)
+    gamma, reason_parts = _clamp_float_to_bounds("potential_gamma", raw_gamma, strict=False)
+    if reason_parts:
         warnings.warn(
-            f"potential_gamma={original} outside [0,1]; clamped to {gamma}",
+            f"potential_gamma={raw_gamma} outside [0,1]; clamped to {gamma}",
             RewardDiagnosticsWarning,
             stacklevel=2,
         )
-        return gamma
     return float(gamma)
 
 
 # === PBRS IMPLEMENTATION ===
 
 
-def _compute_entry_unrealized_pnl_estimate(next_position: Positions, params: RewardParams) -> float:
-    """Estimate immediate unrealized PnL after entry fees.
+def _get_fee_rates(params: RewardParams) -> tuple[float, float]:
+    """Return clamped `(entry_fee_rate, exit_fee_rate)`.
+
+    Semantics follow Freqtrade's `BaseEnvironment` fee helpers:
+    - Entry fee is applied as multiplication: `price * (1 + entry_fee_rate)`.
+    - Exit fee is applied as division: `price / (1 + exit_fee_rate)`.
 
-    For Long entry:
-        current_price = open * (1 - exit_fee_rate)
-        last_trade_price = open * (1 + entry_fee_rate)
-        pnl = (current_price - last_trade_price) / last_trade_price
+    Notes:
+    - Supports two tunables (`entry_fee_rate`, `exit_fee_rate`).
+    - Missing/non-finite values fall back to the min bound (usually 0.0).
+    - Values are clamped to `_PARAMETER_BOUNDS`.
 
-    For Short entry:
-        current_price = open * (1 + entry_fee_rate)
-        last_trade_price = open * (1 - exit_fee_rate)
-        pnl = (last_trade_price - current_price) / last_trade_price
+    This function intentionally clamps (never raises) so callers do not need to
+    pre-run `validate_reward_parameters()`.
     """
 
-    entry_fee_rate = _get_float_param(
+    raw_entry_fee_rate = _get_float_param(
         params,
         "entry_fee_rate",
         DEFAULT_MODEL_REWARD_PARAMETERS.get("entry_fee_rate", 0.0),
     )
-    exit_fee_rate = _get_float_param(
+    raw_exit_fee_rate = _get_float_param(
         params,
         "exit_fee_rate",
         DEFAULT_MODEL_REWARD_PARAMETERS.get("exit_fee_rate", 0.0),
     )
 
-    if not np.isfinite(entry_fee_rate):
-        entry_fee_rate = 0.0
-    if not np.isfinite(exit_fee_rate):
-        exit_fee_rate = 0.0
+    entry_fee_rate, _ = _clamp_float_to_bounds(
+        "entry_fee_rate",
+        float(raw_entry_fee_rate),
+        strict=False,
+    )
+    exit_fee_rate, _ = _clamp_float_to_bounds(
+        "exit_fee_rate",
+        float(raw_exit_fee_rate),
+        strict=False,
+    )
 
-    entry_fee_bounds = _PARAMETER_BOUNDS.get("entry_fee_rate", {"min": 0.0, "max": 1.0})
-    exit_fee_bounds = _PARAMETER_BOUNDS.get("exit_fee_rate", {"min": 0.0, "max": 1.0})
+    return entry_fee_rate, exit_fee_rate
 
-    entry_fee_min = float(entry_fee_bounds.get("min", 0.0))
-    entry_fee_max = float(entry_fee_bounds.get("max", 1.0))
-    exit_fee_min = float(exit_fee_bounds.get("min", 0.0))
-    exit_fee_max = float(exit_fee_bounds.get("max", 1.0))
 
-    entry_fee_rate = float(np.clip(entry_fee_rate, entry_fee_min, entry_fee_max))
-    exit_fee_rate = float(np.clip(exit_fee_rate, exit_fee_min, exit_fee_max))
+def _apply_entry_fee(price: float, entry_fee_rate: float) -> float:
+    return float(price * (1.0 + entry_fee_rate))
 
-    current_open = 1.0
-    next_pnl = 0.0
-
-    if next_position == Positions.Long:
-        current_price = current_open * (1.0 - exit_fee_rate)
-        last_trade_price = current_open * (1.0 + entry_fee_rate)
-        if last_trade_price != 0.0 and np.isfinite(last_trade_price):
-            next_pnl = (current_price - last_trade_price) / last_trade_price
-    elif next_position == Positions.Short:
-        current_price = current_open * (1.0 + entry_fee_rate)
-        last_trade_price = current_open * (1.0 - exit_fee_rate)
-        if last_trade_price != 0.0 and np.isfinite(last_trade_price):
-            next_pnl = (last_trade_price - current_price) / last_trade_price
-
-    if not np.isfinite(next_pnl):
+
+def _apply_exit_fee(price: float, exit_fee_rate: float) -> float:
+    denom = 1.0 + exit_fee_rate
+    if denom <= 0.0 or not np.isfinite(denom):
+        return float(price)
+    return float(price / denom)
+
+
+def _compute_unrealized_pnl_estimate(
+    position: Positions,
+    *,
+    entry_open: float,
+    current_open: float,
+    params: RewardParams,
+) -> float:
+    """Estimate unrealized PnL using fee application parity with Freqtrade.
+
+    Long:
+        entry_price = entry_open * (1 + entry_fee_rate)
+        current_price = current_open / (1 + exit_fee_rate)
+        pnl = (current_price - entry_price) / entry_price
+
+    Short:
+        entry_price = entry_open / (1 + exit_fee_rate)
+        current_price = current_open * (1 + entry_fee_rate)
+        pnl = (entry_price - current_price) / entry_price
+    """
+
+    if position not in (Positions.Long, Positions.Short):
+        return 0.0
+
+    if not np.isfinite(entry_open) or entry_open <= 0.0:
+        return 0.0
+    if not np.isfinite(current_open) or current_open <= 0.0:
+        return 0.0
+
+    entry_fee_rate, exit_fee_rate = _get_fee_rates(params)
+
+    if position == Positions.Long:
+        current_price = _apply_exit_fee(current_open, exit_fee_rate)
+        entry_price = _apply_entry_fee(entry_open, entry_fee_rate)
+        if entry_price == 0.0 or not np.isfinite(entry_price):
+            return 0.0
+        pnl = (current_price - entry_price) / entry_price
+    else:
+        current_price = _apply_entry_fee(current_open, entry_fee_rate)
+        entry_price = _apply_exit_fee(entry_open, exit_fee_rate)
+        if entry_price == 0.0 or not np.isfinite(entry_price):
+            return 0.0
+        pnl = (entry_price - current_price) / entry_price
+
+    if not np.isfinite(pnl):
         return 0.0
-    return float(next_pnl)
+    return float(pnl)
 
 
 def _compute_hold_potential(
index 2c90120358582aac2f84546f2cec6f7932cbd54c..5ded6bba5b1535663a923a1f9a835cfcc50169b0 100644 (file)
@@ -41,14 +41,16 @@ class TestTransforms(RewardSpaceTestBase):
         ]
 
         for transform_name, test_values, expected_values in test_cases:
-            for test_val, expected_val in zip(test_values, expected_values):
-                with self.subTest(transform=transform_name, input=test_val, expected=expected_val):
+            for test_val, expected_value in zip(test_values, expected_values):
+                with self.subTest(
+                    transform=transform_name, input=test_val, expected=expected_value
+                ):
                     result = apply_transform(transform_name, test_val)
                     self.assertAlmostEqualFloat(
                         result,
-                        expected_val,
+                        expected_value,
                         tolerance=1e-10,
-                        msg=f"{transform_name}({test_val}) should equal {expected_val}",
+                        msg=f"{transform_name}({test_val}) should equal {expected_value}",
                     )
 
     def test_transform_bounds_smooth(self):
index ce00aa40c774078e576321284772353162b8eb80..952bd4ce3c7fa55699caf0dece7e7c42f4cd9a70 100644 (file)
@@ -7,7 +7,6 @@ from reward_space_analysis import (
     Positions,
     RewardParams,
     _get_bool_param,
-    _get_float_param,
     calculate_reward,
 )
 
@@ -38,25 +37,6 @@ def test_get_bool_param_none_and_invalid_literal():
     assert _get_bool_param(params_invalid, "check_invariants", True) is True
 
 
-def test_get_float_param_invalid_string_returns_nan():
-    """Verify _get_float_param returns NaN for invalid string input.
-
-    Tests error handling in float parameter parsing when given
-    a non-numeric string that cannot be converted to float.
-
-    **Setup:**
-    - Invalid string: "abc"
-    - Parameter: idle_penalty_scale
-    - Default value: 0.5
-
-    **Assertions:**
-    - Result is NaN (covers float conversion ValueError path)
-    """
-    params: RewardParams = {"idle_penalty_scale": "abc"}
-    val = _get_float_param(params, "idle_penalty_scale", 0.5)
-    assert math.isnan(val)
-
-
 def test_calculate_reward_unrealized_pnl_hold_path():
     """Verify unrealized PnL branch activates during hold action.
 
index fb003250781bb248b5a83839c242ced1b4bcda31..ce47a5eea1dd6ca520a031e78c177316edbaf37a 100644 (file)
@@ -16,10 +16,10 @@ from reward_space_analysis import (
     Actions,
     Positions,
     _compute_entry_additive,
-    _compute_entry_unrealized_pnl_estimate,
     _compute_exit_additive,
     _compute_exit_potential,
     _compute_hold_potential,
+    _compute_unrealized_pnl_estimate,
     _get_float_param,
     apply_potential_shaping,
     calculate_reward,
@@ -305,6 +305,51 @@ class TestPBRS(RewardSpaceTestBase):
             msg="Hold shaping must be suppressed when hold potential disabled",
         )
 
+    def test_calculate_reward_preserves_potential_when_pbrs_disabled(self):
+        """calculate_reward() preserves stored potential when PBRS is disabled."""
+        params = self.base_params(
+            hold_potential_enabled=False,
+            entry_additive_enabled=False,
+            exit_additive_enabled=False,
+            exit_potential_mode="non_canonical",
+        )
+        ctx = self.make_ctx(position=Positions.Neutral, action=Actions.Neutral)
+
+        prev_potential = 0.37
+        breakdown = calculate_reward(
+            ctx,
+            params,
+            base_factor=PARAMS.BASE_FACTOR,
+            profit_aim=PARAMS.PROFIT_AIM,
+            risk_reward_ratio=PARAMS.RISK_REWARD_RATIO,
+            short_allowed=True,
+            action_masking=True,
+            prev_potential=prev_potential,
+        )
+
+        self.assertAlmostEqualFloat(
+            breakdown.prev_potential,
+            prev_potential,
+            tolerance=TOLERANCE.IDENTITY_STRICT,
+            msg="prev_potential must be preserved when PBRS disabled",
+        )
+        self.assertAlmostEqualFloat(
+            breakdown.next_potential,
+            prev_potential,
+            tolerance=TOLERANCE.IDENTITY_STRICT,
+            msg="next_potential must equal prev_potential when PBRS disabled",
+        )
+        self.assertPlacesEqual(
+            breakdown.reward_shaping, 0.0, places=TOLERANCE.DECIMAL_PLACES_STRICT
+        )
+        self.assertPlacesEqual(breakdown.pbrs_delta, 0.0, places=TOLERANCE.DECIMAL_PLACES_STRICT)
+        self.assertAlmostEqualFloat(
+            breakdown.total,
+            breakdown.base_reward,
+            tolerance=TOLERANCE.IDENTITY_STRICT,
+            msg="PBRS disabled total must equal base_reward",
+        )
+
     def test_exit_potential_canonical(self):
         """Verifies canonical exit resets potential (no params mutation)."""
         params = self.base_params(
@@ -437,8 +482,8 @@ class TestPBRS(RewardSpaceTestBase):
         self.assertPlacesEqual(
             next_potential, prev_potential, places=TOLERANCE.DECIMAL_PLACES_STRICT
         )
-        gamma_raw = DEFAULT_MODEL_REWARD_PARAMETERS.get("potential_gamma", 0.95)
-        gamma_fallback = 0.95 if gamma_raw is None else gamma_raw
+        raw_gamma = DEFAULT_MODEL_REWARD_PARAMETERS.get("potential_gamma", 0.95)
+        gamma_fallback = 0.95 if raw_gamma is None else raw_gamma
         try:
             gamma = float(gamma_fallback)
         except Exception:
@@ -531,10 +576,17 @@ class TestPBRS(RewardSpaceTestBase):
         ]
 
         for key, params in cases:
-            pnl_clamped = _compute_entry_unrealized_pnl_estimate(Positions.Long, params)
-            pnl_expected = _compute_entry_unrealized_pnl_estimate(
+            pnl_clamped = _compute_unrealized_pnl_estimate(
+                Positions.Long,
+                entry_open=1.0,
+                current_open=1.0,
+                params=params,
+            )
+            pnl_expected = _compute_unrealized_pnl_estimate(
                 Positions.Long,
-                {**params, key: 0.1},
+                entry_open=1.0,
+                current_open=1.0,
+                params={**params, key: 0.1},
             )
             self.assertAlmostEqualFloat(
                 pnl_clamped,
@@ -543,8 +595,40 @@ class TestPBRS(RewardSpaceTestBase):
                 msg=f"Expected {key} values above max to clamp to 0.1",
             )
 
+    def test_unrealized_pnl_estimate_uses_division_for_exit_fee(self):
+        """Exit fee uses division `open/(1+fee)`."""
+        params = self.base_params(entry_fee_rate=0.0, exit_fee_rate=0.1)
+
+        pnl_long = _compute_unrealized_pnl_estimate(
+            Positions.Long,
+            entry_open=1.0,
+            current_open=1.0,
+            params=params,
+        )
+        expected_pnl_long = (1.0 / 1.1 - 1.0) / 1.0
+        self.assertAlmostEqualFloat(
+            float(pnl_long),
+            float(expected_pnl_long),
+            tolerance=TOLERANCE.IDENTITY_STRICT,
+            msg="Long entry PnL mismatch for division-based exit fee",
+        )
+
+        pnl_short = _compute_unrealized_pnl_estimate(
+            Positions.Short,
+            entry_open=1.0,
+            current_open=1.0,
+            params=params,
+        )
+        expected_pnl_short = (1.0 / 1.1 - 1.0) / (1.0 / 1.1)
+        self.assertAlmostEqualFloat(
+            float(pnl_short),
+            float(expected_pnl_short),
+            tolerance=TOLERANCE.IDENTITY_STRICT,
+            msg="Short entry PnL mismatch for division-based exit fee",
+        )
+
     def test_simulate_samples_initializes_pnl_on_entry(self):
-        """simulate_samples() sets in-position pnl to entry fee estimate."""
+        """simulate_samples() sets in-position pnl to fee-aware entry estimate."""
         params = self.base_params(
             exit_potential_mode="non_canonical",
             hold_potential_enabled=True,
@@ -555,7 +639,7 @@ class TestPBRS(RewardSpaceTestBase):
         )
 
         df = simulate_samples(
-            num_samples=50,
+            num_samples=80,
             seed=1,
             params=params,
             base_factor=PARAMS.BASE_FACTOR,
@@ -567,15 +651,19 @@ class TestPBRS(RewardSpaceTestBase):
             pnl_duration_vol_scale=0.0,
         )
 
-        enter_rows = df[df["action"] == float(Actions.Long_enter.value)]
-        self.assertGreater(len(enter_rows), 0, "Expected at least one Long_enter in sample")
-
         enter_pos = df.reset_index(drop=True)
         enter_mask = enter_pos["action"].to_numpy() == float(Actions.Long_enter.value)
         enter_positions = np.flatnonzero(enter_mask)
+        self.assertGreater(len(enter_positions), 0, "Expected at least one Long_enter in sample")
+
         first_enter_pos = int(enter_positions[0])
-        next_pos = first_enter_pos + 1
+        self.assertEqual(
+            float(enter_pos.iloc[first_enter_pos]["position"]),
+            float(Positions.Neutral.value),
+            "Expected Neutral position on Long_enter row",
+        )
 
+        next_pos = first_enter_pos + 1
         self.assertLess(next_pos, len(enter_pos), "Sample must include post-entry step")
         self.assertEqual(
             float(enter_pos.iloc[next_pos]["position"]),
@@ -583,7 +671,12 @@ class TestPBRS(RewardSpaceTestBase):
             "Expected Long position immediately after Long_enter",
         )
 
-        expected_pnl = _compute_entry_unrealized_pnl_estimate(Positions.Long, params)
+        expected_pnl = _compute_unrealized_pnl_estimate(
+            Positions.Long,
+            entry_open=1.0,
+            current_open=1.0,
+            params=params,
+        )
         post_entry_pnl = float(enter_pos.iloc[next_pos]["pnl"])
         self.assertAlmostEqualFloat(
             post_entry_pnl,
@@ -797,6 +890,62 @@ class TestPBRS(RewardSpaceTestBase):
             msg="Canonical exit PBRS delta should be -prev_potential",
         )
 
+    def test_invalid_action_still_applies_pbrs_shaping(self):
+        """Invalid action penalties still flow through PBRS shaping."""
+
+        params = self.base_params(
+            max_trade_duration_candles=100,
+            exit_potential_mode="canonical",
+            hold_potential_enabled=True,
+            entry_additive_enabled=False,
+            exit_additive_enabled=False,
+            potential_gamma=0.9,
+        )
+        pnl_target = PARAMS.PROFIT_AIM * PARAMS.RISK_REWARD_RATIO
+        ctx = self.make_ctx(
+            pnl=0.02,
+            trade_duration=10,
+            idle_duration=0,
+            max_unrealized_profit=0.03,
+            min_unrealized_profit=0.01,
+            position=Positions.Long,
+            action=Actions.Short_exit,  # invalid for long
+        )
+
+        current_duration_ratio = ctx.trade_duration / params["max_trade_duration_candles"]
+        prev_potential = _compute_hold_potential(
+            ctx.pnl, pnl_target, current_duration_ratio, params
+        )
+        self.assertNotEqual(prev_potential, 0.0)
+
+        breakdown = calculate_reward(
+            ctx,
+            params,
+            base_factor=PARAMS.BASE_FACTOR,
+            profit_aim=PARAMS.PROFIT_AIM,
+            risk_reward_ratio=PARAMS.RISK_REWARD_RATIO,
+            short_allowed=True,
+            action_masking=False,
+            prev_potential=prev_potential,
+        )
+
+        expected_shaping = params["potential_gamma"] * prev_potential - prev_potential
+        self.assertAlmostEqualFloat(
+            breakdown.reward_shaping,
+            expected_shaping,
+            tolerance=TOLERANCE.IDENTITY_RELAXED,
+            msg="Invalid actions should still produce PBRS shaping",
+        )
+        self.assertAlmostEqualFloat(
+            breakdown.total,
+            breakdown.invalid_penalty
+            + breakdown.reward_shaping
+            + breakdown.entry_additive
+            + breakdown.exit_additive,
+            tolerance=TOLERANCE.IDENTITY_RELAXED,
+            msg="Total should decompose for invalid actions",
+        )
+
     def test_simulate_samples_retains_signals_in_canonical_mode(self):
         """simulate_samples() is not drift-corrected; it must not force Σ shaping ~ 0."""
 
index c69f63d5963915c7073edf7e0eac085f6fa26ce0..bd6e18dd2a1f967fc90ae8332e4720acf5178b11 100644 (file)
@@ -1769,6 +1769,7 @@ class MyRLEnv(Base5ActionRLEnv):
                 "PBRS: hold_potential_enabled=True and add_state_info=False is unsupported. Automatically enabling add_state_info=True."
             )
             self.add_state_info = True
+            self._set_observation_space()
 
         # === PNL TARGET VALIDATION ===
         pnl_target = self.profit_aim * self.rr
@@ -2870,6 +2871,36 @@ class MyRLEnv(Base5ActionRLEnv):
 
         return None
 
+    def _apply_terminal_pbrs_correction(self, reward: float) -> float:
+        if not (
+            self._hold_potential_enabled
+            or self._entry_additive_enabled
+            or self._exit_additive_enabled
+        ):
+            self._last_potential = 0.0
+            self._last_next_potential = 0.0
+            self._last_reward_shaping = 0.0
+            return reward
+
+        prev_potential = self._last_prev_potential
+        computed_reward_shaping = self._last_reward_shaping
+        terminal_reward_shaping = -prev_potential
+        reward_shaping_delta = terminal_reward_shaping - computed_reward_shaping
+        last_next_potential = self._last_next_potential
+
+        if np.isclose(computed_reward_shaping, terminal_reward_shaping) and np.isclose(
+            last_next_potential, 0.0
+        ):
+            self._last_potential = 0.0
+            self._last_next_potential = 0.0
+            return reward
+
+        self._last_potential = 0.0
+        self._last_next_potential = 0.0
+        self._last_reward_shaping = terminal_reward_shaping
+        self._total_reward_shaping += reward_shaping_delta
+        return reward + reward_shaping_delta
+
     def step(
         self, action: int
     ) -> Tuple[NDArray[np.float32], float, bool, bool, Dict[str, Any]]:
@@ -2881,11 +2912,15 @@ class MyRLEnv(Base5ActionRLEnv):
         pre_pnl = self.get_unrealized_profit()
         self._update_portfolio_log_returns()
         reward = self.calculate_reward(action)
-        self.total_reward += reward
         trade_type = self.execute_trade(action)
         if trade_type is not None:
             self.append_trade_history(trade_type, self.current_price(), pre_pnl)
         self._position_history.append(self._position)
+        terminated = self.is_terminated()
+        if terminated:
+            reward = self._apply_terminal_pbrs_correction(reward)
+            self._last_potential = 0.0
+        self.total_reward += reward
         pnl = self.get_unrealized_profit()
         self._update_max_unrealized_profit(pnl)
         self._update_min_unrealized_profit(pnl)
@@ -2927,18 +2962,6 @@ class MyRLEnv(Base5ActionRLEnv):
             "trade_count": len(self.trade_history) // 2,
         }
         self._update_history(info)
-        terminated = self.is_terminated()
-        if terminated:
-            # Enforce Φ(terminal)=0 for PBRS invariance (Wiewiora et al. 2003)
-            self._last_potential = 0.0
-            # eps = np.finfo(float).eps
-            # if self.is_pbrs_invariant_mode() and abs(self._total_reward_shaping) > eps:
-            #     logger.warning(
-            #         "PBRS mode %s invariance deviation: |sum Δ|=%.6f > eps=%.6f",
-            #         self._exit_potential_mode,
-            #         abs(self._total_reward_shaping),
-            #         eps,
-            #     )
         return (
             self._get_observation(),
             reward,