Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): factor out PBRS discount gamma transmission to env
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 16 Oct 2025 21:07:04 +0000 (23:07 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 16 Oct 2025 21:07:04 +0000 (23:07 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/reward_space_analysis/test_reward_space_analysis.py
ReforceXY/user_data/freqaimodels/ReforceXY.py
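
Background: potential-based reward shaping (PBRS) adds a shaping term F(s, s') = gamma * Phi(s') - Phi(s) to the raw reward, and its policy-invariance guarantee only holds when the shaping discount gamma matches the RL agent's discount factor. That is why this commit resolves the agent gamma and forwards it to the environment as the `potential_gamma` reward parameter. A minimal illustrative sketch of the shaping term (hypothetical helper names, not the repository's API):

    def pbrs_shaping(potential_prev: float, potential_next: float, potential_gamma: float) -> float:
        # F(s, s') = gamma * Phi(s') - Phi(s)
        return potential_gamma * potential_next - potential_prev

    def shaped_reward(raw_reward: float, potential_prev: float, potential_next: float,
                      potential_gamma: float) -> float:
        # With a gamma consistent with the agent's discount factor, the shaping
        # terms telescope over an episode and leave the ordering of policies
        # by expected return unchanged.
        return raw_reward + pbrs_shaping(potential_prev, potential_next, potential_gamma)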

index 173f93ab0852de6dfc81e3bd8ca8209513a80e8e..a7f82d557802256c2361418e06cd3a5d96e58596 100644 (file)
@@ -1114,28 +1114,6 @@ class TestRewardComponents(RewardSpaceTestBase):
                     self.assertNotEqual(breakdown.exit_component, 0.0)
                 self.assertFinite(breakdown.total, name="breakdown.total")
 
-    def test_basic_reward_calculation(self):
-        context = self.make_ctx(
-            pnl=self.TEST_PROFIT_TARGET,
-            trade_duration=10,
-            max_trade_duration=100,
-            max_unrealized_profit=0.025,
-            min_unrealized_profit=0.015,
-            position=Positions.Long,
-            action=Actions.Long_exit,
-        )
-        br = calculate_reward(
-            context,
-            self.DEFAULT_PARAMS,
-            base_factor=self.TEST_BASE_FACTOR,
-            profit_target=0.06,
-            risk_reward_ratio=self.TEST_RR_HIGH,
-            short_allowed=True,
-            action_masking=True,
-        )
-        self.assertFinite(br.total, name="total")
-        self.assertGreater(br.exit_component, 0)
-
     def test_efficiency_zero_policy(self):
         ctx = self.make_ctx(
             pnl=0.0,
@@ -1687,6 +1665,72 @@ class TestAPIAndHelpers(RewardSpaceTestBase):
             short_positions, 0, "Futures mode should allow short positions"
         )
 
+    def test_get_float_param(self):
+        """Test float parameter extraction."""
+        params = {"test_float": 1.5, "test_int": 2, "test_str": "hello"}
+        self.assertEqual(_get_float_param(params, "test_float", 0.0), 1.5)
+        self.assertEqual(_get_float_param(params, "test_int", 0.0), 2.0)
+        # Non parseable string -> NaN fallback in tolerant parser
+        val_str = _get_float_param(params, "test_str", 0.0)
+        if isinstance(val_str, float) and math.isnan(val_str):
+            pass
+        else:
+            self.fail("Expected NaN for non-numeric string in _get_float_param")
+        self.assertEqual(_get_float_param(params, "missing", 3.14), 3.14)
+
+    def test_get_str_param(self):
+        """Test string parameter extraction."""
+        params = {"test_str": "hello", "test_int": 2}
+        self.assertEqual(_get_str_param(params, "test_str", "default"), "hello")
+        self.assertEqual(_get_str_param(params, "test_int", "default"), "default")
+        self.assertEqual(_get_str_param(params, "missing", "default"), "default")
+
+    def test_get_bool_param(self):
+        """Test boolean parameter extraction."""
+        params = {
+            "test_true": True,
+            "test_false": False,
+            "test_int": 1,
+            "test_str": "yes",
+        }
+        self.assertTrue(_get_bool_param(params, "test_true", False))
+        self.assertFalse(_get_bool_param(params, "test_false", True))
+        # Environment coerces typical truthy numeric/string values
+        self.assertTrue(_get_bool_param(params, "test_int", False))
+        self.assertTrue(_get_bool_param(params, "test_str", False))
+        self.assertFalse(_get_bool_param(params, "missing", False))
+
+    def test_get_int_param_coercions(self):
+        """Robust coercion paths of _get_int_param (bool/int/float/str/None/unsupported)."""
+        # None with numeric default
+        self.assertEqual(_get_int_param({"k": None}, "k", 5), 5)
+        # None with non-numeric default -> 0 fallback
+        self.assertEqual(_get_int_param({"k": None}, "k", "x"), 0)
+        # Booleans
+        self.assertEqual(_get_int_param({"k": True}, "k", 0), 1)
+        self.assertEqual(_get_int_param({"k": False}, "k", 7), 0)
+        # Int passthrough
+        self.assertEqual(_get_int_param({"k": -12}, "k", 0), -12)
+        # Float truncation & negative
+        self.assertEqual(_get_int_param({"k": 9.99}, "k", 0), 9)
+        self.assertEqual(_get_int_param({"k": -3.7}, "k", 0), -3)
+        # Non-finite floats fallback
+        self.assertEqual(_get_int_param({"k": float("nan")}, "k", 4), 4)
+        self.assertEqual(_get_int_param({"k": float("inf")}, "k", 4), 4)
+        # String numerics (int, float, exponent)
+        self.assertEqual(_get_int_param({"k": "42"}, "k", 0), 42)
+        self.assertEqual(_get_int_param({"k": " 17 "}, "k", 0), 17)
+        self.assertEqual(_get_int_param({"k": "3.9"}, "k", 0), 3)
+        self.assertEqual(_get_int_param({"k": "1e2"}, "k", 0), 100)
+        # String fallbacks (empty, invalid, NaN token)
+        self.assertEqual(_get_int_param({"k": ""}, "k", 5), 5)
+        self.assertEqual(_get_int_param({"k": "abc"}, "k", 5), 5)
+        self.assertEqual(_get_int_param({"k": "NaN"}, "k", 5), 5)
+        # Unsupported type
+        self.assertEqual(_get_int_param({"k": [1, 2, 3]}, "k", 3), 3)
+        # Missing key with non-numeric default
+        self.assertEqual(_get_int_param({}, "missing", "zzz"), 0)
+
     def test_argument_parser_construction(self):
         """Test build_argument_parser function."""
 
@@ -3001,41 +3045,6 @@ class TestPBRS(RewardSpaceTestBase):
             tolerance=self.TOL_IDENTITY_RELAXED,
         )
 
-    def test_get_float_param(self):
-        """Test float parameter extraction."""
-        params = {"test_float": 1.5, "test_int": 2, "test_str": "hello"}
-        self.assertEqual(_get_float_param(params, "test_float", 0.0), 1.5)
-        self.assertEqual(_get_float_param(params, "test_int", 0.0), 2.0)
-        # Non parseable string -> NaN fallback in tolerant parser
-        val_str = _get_float_param(params, "test_str", 0.0)
-        if isinstance(val_str, float) and math.isnan(val_str):
-            pass
-        else:
-            self.fail("Expected NaN for non-numeric string in _get_float_param")
-        self.assertEqual(_get_float_param(params, "missing", 3.14), 3.14)
-
-    def test_get_str_param(self):
-        """Test string parameter extraction."""
-        params = {"test_str": "hello", "test_int": 2}
-        self.assertEqual(_get_str_param(params, "test_str", "default"), "hello")
-        self.assertEqual(_get_str_param(params, "test_int", "default"), "default")
-        self.assertEqual(_get_str_param(params, "missing", "default"), "default")
-
-    def test_get_bool_param(self):
-        """Test boolean parameter extraction."""
-        params = {
-            "test_true": True,
-            "test_false": False,
-            "test_int": 1,
-            "test_str": "yes",
-        }
-        self.assertTrue(_get_bool_param(params, "test_true", False))
-        self.assertFalse(_get_bool_param(params, "test_false", True))
-        # Environment coerces typical truthy numeric/string values
-        self.assertTrue(_get_bool_param(params, "test_int", False))
-        self.assertTrue(_get_bool_param(params, "test_str", False))
-        self.assertFalse(_get_bool_param(params, "missing", False))
-
     def test_hold_potential_basic(self):
         """Test basic hold potential calculation."""
         params = {
@@ -3313,37 +3322,6 @@ class TestPBRS(RewardSpaceTestBase):
             "Canonical shaping magnitude should exceed spike_cancel",
         )
 
-    def test_get_int_param_coercions(self):
-        """Robust coercion paths of _get_int_param (bool/int/float/str/None/unsupported)."""
-        # None with numeric default
-        self.assertEqual(_get_int_param({"k": None}, "k", 5), 5)
-        # None with non-numeric default -> 0 fallback
-        self.assertEqual(_get_int_param({"k": None}, "k", "x"), 0)
-        # Booleans
-        self.assertEqual(_get_int_param({"k": True}, "k", 0), 1)
-        self.assertEqual(_get_int_param({"k": False}, "k", 7), 0)
-        # Int passthrough
-        self.assertEqual(_get_int_param({"k": -12}, "k", 0), -12)
-        # Float truncation & negative
-        self.assertEqual(_get_int_param({"k": 9.99}, "k", 0), 9)
-        self.assertEqual(_get_int_param({"k": -3.7}, "k", 0), -3)
-        # Non-finite floats fallback
-        self.assertEqual(_get_int_param({"k": float("nan")}, "k", 4), 4)
-        self.assertEqual(_get_int_param({"k": float("inf")}, "k", 4), 4)
-        # String numerics (int, float, exponent)
-        self.assertEqual(_get_int_param({"k": "42"}, "k", 0), 42)
-        self.assertEqual(_get_int_param({"k": " 17 "}, "k", 0), 17)
-        self.assertEqual(_get_int_param({"k": "3.9"}, "k", 0), 3)
-        self.assertEqual(_get_int_param({"k": "1e2"}, "k", 0), 100)
-        # String fallbacks (empty, invalid, NaN token)
-        self.assertEqual(_get_int_param({"k": ""}, "k", 5), 5)
-        self.assertEqual(_get_int_param({"k": "abc"}, "k", 5), 5)
-        self.assertEqual(_get_int_param({"k": "NaN"}, "k", 5), 5)
-        # Unsupported type
-        self.assertEqual(_get_int_param({"k": [1, 2, 3]}, "k", 3), 3)
-        # Missing key with non-numeric default
-        self.assertEqual(_get_int_param({}, "missing", "zzz"), 0)
-
     def test_transform_bulk_monotonicity_and_bounds(self):
         """Non-decreasing monotonicity & (-1,1) bounds for smooth transforms (excluding clip)."""
         transforms = ["tanh", "softsign", "arctan", "sigmoid", "asinh"]
index cff90f5855cd746929c71f0e66452d4e2c4e5209..12664d93c16b464d4af0a1ea44b9a9a1d5b95f7a 100644 (file)
@@ -285,6 +285,45 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
             self.continual_learning = False
 
+    def pack_env_dict(
+        self, pair: str, model_params: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
+        env_info = super().pack_env_dict(pair)
+
+        config = env_info.setdefault("config", {})
+        freqai_cfg = config.setdefault("freqai", {})
+        rl_cfg = freqai_cfg.setdefault("rl_config", {})
+        model_reward_parameters = rl_cfg.setdefault("model_reward_parameters", {})
+
+        gamma: Optional[float] = None
+        best_trial_params: Optional[Dict[str, Any]] = None
+        if self.hyperopt:
+            best_trial_params = self.load_best_trial_params(pair)
+
+        if model_params and isinstance(model_params.get("gamma"), (int, float)):
+            gamma = float(model_params.get("gamma"))
+        elif best_trial_params and isinstance(
+            best_trial_params.get("gamma"), (int, float)
+        ):
+            gamma = float(best_trial_params.get("gamma"))
+        elif hasattr(self.model, "gamma") and isinstance(
+            self.model.gamma, (int, float)
+        ):
+            gamma = float(self.model.gamma)
+        else:
+            model_params_gamma = self.get_model_params().get("gamma")
+            if isinstance(model_params_gamma, (int, float)):
+                gamma = float(model_params_gamma)
+
+        if gamma is not None:
+            model_reward_parameters["potential_gamma"] = gamma
+        else:
+            logger.warning(
+                f"{pair}: No valid PBRS discount gamma resolved for environment"
+            )
+
+        return env_info
+
     def set_train_and_eval_environments(
         self,
         data_dictionary: Dict[str, DataFrame],
@@ -302,9 +341,6 @@ class ReforceXY(BaseReinforcementLearningModel):
         train_df = data_dictionary.get("train_features")
         test_df = data_dictionary.get("test_features")
         env_dict = self.pack_env_dict(dk.pair)
-        env_dict["config"]["freqai"]["rl_config"]["model_reward_parameters"][
-            "potential_gamma"
-        ] = self.get_model_params().get("gamma")
         seed = self.get_model_params().get("seed", 42)
 
         if self.check_envs:
@@ -571,7 +607,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs
     ) -> Any:
         """
-        User customizable fit method
+        Model fitting method
         :param data_dictionary: dict = common data dictionary containing all train/test features/labels/weights.
         :param dk: FreqaiDatakitchen = data kitchen for current pair.
         :return:
@@ -1080,27 +1116,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             seed += trial.number
         set_random_seed(seed)
         env_info: Dict[str, Any] = (
-            self.pack_env_dict(dk.pair) if env_info is None else env_info
+            self.pack_env_dict(dk.pair, model_params) if env_info is None else env_info
         )
-        gamma: Optional[float] = None
-        best_trial_params: Optional[Dict[str, Any]] = None
-        if self.hyperopt:
-            best_trial_params = self.load_best_trial_params(dk.pair)
-        if model_params and isinstance(model_params.get("gamma"), (int, float)):
-            gamma = model_params.get("gamma")
-        elif best_trial_params:
-            gamma = best_trial_params.get("gamma")
-        elif hasattr(self.model, "gamma") and isinstance(
-            self.model.gamma, (int, float)
-        ):
-            gamma = self.model.gamma
-        elif isinstance(self.get_model_params().get("gamma"), (int, float)):
-            gamma = self.get_model_params().get("gamma")
-        if gamma is not None:
-            # Align RL agent gamma with PBRS gamma for consistent discount factor
-            env_info["config"]["freqai"]["rl_config"]["model_reward_parameters"][
-                "potential_gamma"
-            ] = float(gamma)
         env_prefix = f"trial_{trial.number}_" if trial is not None else ""
 
         train_fns = [