From 0865636f9801380a038caaef70e88e34a28c95a2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Thu, 16 Oct 2025 23:07:04 +0200 Subject: [PATCH] refactor(reforxexy): factor out PBRS discount gamma transmission to env MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../test_reward_space_analysis.py | 154 ++++++++---------- ReforceXY/user_data/freqaimodels/ReforceXY.py | 65 +++++--- 2 files changed, 107 insertions(+), 112 deletions(-) diff --git a/ReforceXY/reward_space_analysis/test_reward_space_analysis.py b/ReforceXY/reward_space_analysis/test_reward_space_analysis.py index 173f93a..a7f82d5 100644 --- a/ReforceXY/reward_space_analysis/test_reward_space_analysis.py +++ b/ReforceXY/reward_space_analysis/test_reward_space_analysis.py @@ -1114,28 +1114,6 @@ class TestRewardComponents(RewardSpaceTestBase): self.assertNotEqual(breakdown.exit_component, 0.0) self.assertFinite(breakdown.total, name="breakdown.total") - def test_basic_reward_calculation(self): - context = self.make_ctx( - pnl=self.TEST_PROFIT_TARGET, - trade_duration=10, - max_trade_duration=100, - max_unrealized_profit=0.025, - min_unrealized_profit=0.015, - position=Positions.Long, - action=Actions.Long_exit, - ) - br = calculate_reward( - context, - self.DEFAULT_PARAMS, - base_factor=self.TEST_BASE_FACTOR, - profit_target=0.06, - risk_reward_ratio=self.TEST_RR_HIGH, - short_allowed=True, - action_masking=True, - ) - self.assertFinite(br.total, name="total") - self.assertGreater(br.exit_component, 0) - def test_efficiency_zero_policy(self): ctx = self.make_ctx( pnl=0.0, @@ -1687,6 +1665,72 @@ class TestAPIAndHelpers(RewardSpaceTestBase): short_positions, 0, "Futures mode should allow short positions" ) + def test_get_float_param(self): + """Test float parameter extraction.""" + params = {"test_float": 1.5, "test_int": 2, "test_str": "hello"} + self.assertEqual(_get_float_param(params, "test_float", 0.0), 1.5) + self.assertEqual(_get_float_param(params, "test_int", 0.0), 2.0) + # Non parseable string -> NaN fallback in tolerant parser + val_str = _get_float_param(params, "test_str", 0.0) + if isinstance(val_str, float) and math.isnan(val_str): + pass + else: + self.fail("Expected NaN for non-numeric string in _get_float_param") + self.assertEqual(_get_float_param(params, "missing", 3.14), 3.14) + + def test_get_str_param(self): + """Test string parameter extraction.""" + params = {"test_str": "hello", "test_int": 2} + self.assertEqual(_get_str_param(params, "test_str", "default"), "hello") + self.assertEqual(_get_str_param(params, "test_int", "default"), "default") + self.assertEqual(_get_str_param(params, "missing", "default"), "default") + + def test_get_bool_param(self): + """Test boolean parameter extraction.""" + params = { + "test_true": True, + "test_false": False, + "test_int": 1, + "test_str": "yes", + } + self.assertTrue(_get_bool_param(params, "test_true", False)) + self.assertFalse(_get_bool_param(params, "test_false", True)) + # Environment coerces typical truthy numeric/string values + self.assertTrue(_get_bool_param(params, "test_int", False)) + self.assertTrue(_get_bool_param(params, "test_str", False)) + self.assertFalse(_get_bool_param(params, "missing", False)) + + def test_get_int_param_coercions(self): + """Robust coercion paths of _get_int_param (bool/int/float/str/None/unsupported).""" + # None with numeric default + self.assertEqual(_get_int_param({"k": None}, "k", 5), 5) + # None with non-numeric default -> 0 fallback + self.assertEqual(_get_int_param({"k": None}, "k", "x"), 0) + # Booleans + self.assertEqual(_get_int_param({"k": True}, "k", 0), 1) + self.assertEqual(_get_int_param({"k": False}, "k", 7), 0) + # Int passthrough + self.assertEqual(_get_int_param({"k": -12}, "k", 0), -12) + # Float truncation & negative + self.assertEqual(_get_int_param({"k": 9.99}, "k", 0), 9) + self.assertEqual(_get_int_param({"k": -3.7}, "k", 0), -3) + # Non-finite floats fallback + self.assertEqual(_get_int_param({"k": float("nan")}, "k", 4), 4) + self.assertEqual(_get_int_param({"k": float("inf")}, "k", 4), 4) + # String numerics (int, float, exponent) + self.assertEqual(_get_int_param({"k": "42"}, "k", 0), 42) + self.assertEqual(_get_int_param({"k": " 17 "}, "k", 0), 17) + self.assertEqual(_get_int_param({"k": "3.9"}, "k", 0), 3) + self.assertEqual(_get_int_param({"k": "1e2"}, "k", 0), 100) + # String fallbacks (empty, invalid, NaN token) + self.assertEqual(_get_int_param({"k": ""}, "k", 5), 5) + self.assertEqual(_get_int_param({"k": "abc"}, "k", 5), 5) + self.assertEqual(_get_int_param({"k": "NaN"}, "k", 5), 5) + # Unsupported type + self.assertEqual(_get_int_param({"k": [1, 2, 3]}, "k", 3), 3) + # Missing key with non-numeric default + self.assertEqual(_get_int_param({}, "missing", "zzz"), 0) + def test_argument_parser_construction(self): """Test build_argument_parser function.""" @@ -3001,41 +3045,6 @@ class TestPBRS(RewardSpaceTestBase): tolerance=self.TOL_IDENTITY_RELAXED, ) - def test_get_float_param(self): - """Test float parameter extraction.""" - params = {"test_float": 1.5, "test_int": 2, "test_str": "hello"} - self.assertEqual(_get_float_param(params, "test_float", 0.0), 1.5) - self.assertEqual(_get_float_param(params, "test_int", 0.0), 2.0) - # Non parseable string -> NaN fallback in tolerant parser - val_str = _get_float_param(params, "test_str", 0.0) - if isinstance(val_str, float) and math.isnan(val_str): - pass - else: - self.fail("Expected NaN for non-numeric string in _get_float_param") - self.assertEqual(_get_float_param(params, "missing", 3.14), 3.14) - - def test_get_str_param(self): - """Test string parameter extraction.""" - params = {"test_str": "hello", "test_int": 2} - self.assertEqual(_get_str_param(params, "test_str", "default"), "hello") - self.assertEqual(_get_str_param(params, "test_int", "default"), "default") - self.assertEqual(_get_str_param(params, "missing", "default"), "default") - - def test_get_bool_param(self): - """Test boolean parameter extraction.""" - params = { - "test_true": True, - "test_false": False, - "test_int": 1, - "test_str": "yes", - } - self.assertTrue(_get_bool_param(params, "test_true", False)) - self.assertFalse(_get_bool_param(params, "test_false", True)) - # Environment coerces typical truthy numeric/string values - self.assertTrue(_get_bool_param(params, "test_int", False)) - self.assertTrue(_get_bool_param(params, "test_str", False)) - self.assertFalse(_get_bool_param(params, "missing", False)) - def test_hold_potential_basic(self): """Test basic hold potential calculation.""" params = { @@ -3313,37 +3322,6 @@ class TestPBRS(RewardSpaceTestBase): "Canonical shaping magnitude should exceed spike_cancel", ) - def test_get_int_param_coercions(self): - """Robust coercion paths of _get_int_param (bool/int/float/str/None/unsupported).""" - # None with numeric default - self.assertEqual(_get_int_param({"k": None}, "k", 5), 5) - # None with non-numeric default -> 0 fallback - self.assertEqual(_get_int_param({"k": None}, "k", "x"), 0) - # Booleans - self.assertEqual(_get_int_param({"k": True}, "k", 0), 1) - self.assertEqual(_get_int_param({"k": False}, "k", 7), 0) - # Int passthrough - self.assertEqual(_get_int_param({"k": -12}, "k", 0), -12) - # Float truncation & negative - self.assertEqual(_get_int_param({"k": 9.99}, "k", 0), 9) - self.assertEqual(_get_int_param({"k": -3.7}, "k", 0), -3) - # Non-finite floats fallback - self.assertEqual(_get_int_param({"k": float("nan")}, "k", 4), 4) - self.assertEqual(_get_int_param({"k": float("inf")}, "k", 4), 4) - # String numerics (int, float, exponent) - self.assertEqual(_get_int_param({"k": "42"}, "k", 0), 42) - self.assertEqual(_get_int_param({"k": " 17 "}, "k", 0), 17) - self.assertEqual(_get_int_param({"k": "3.9"}, "k", 0), 3) - self.assertEqual(_get_int_param({"k": "1e2"}, "k", 0), 100) - # String fallbacks (empty, invalid, NaN token) - self.assertEqual(_get_int_param({"k": ""}, "k", 5), 5) - self.assertEqual(_get_int_param({"k": "abc"}, "k", 5), 5) - self.assertEqual(_get_int_param({"k": "NaN"}, "k", 5), 5) - # Unsupported type - self.assertEqual(_get_int_param({"k": [1, 2, 3]}, "k", 3), 3) - # Missing key with non-numeric default - self.assertEqual(_get_int_param({}, "missing", "zzz"), 0) - def test_transform_bulk_monotonicity_and_bounds(self): """Non-decreasing monotonicity & (-1,1) bounds for smooth transforms (excluding clip).""" transforms = ["tanh", "softsign", "arctan", "sigmoid", "asinh"] diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index cff90f5..12664d9 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -285,6 +285,45 @@ class ReforceXY(BaseReinforcementLearningModel): ) self.continual_learning = False + def pack_env_dict( + self, pair: str, model_params: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + env_info = super().pack_env_dict(pair) + + config = env_info.setdefault("config", {}) + freqai_cfg = config.setdefault("freqai", {}) + rl_cfg = freqai_cfg.setdefault("rl_config", {}) + model_reward_parameters = rl_cfg.setdefault("model_reward_parameters", {}) + + gamma: Optional[float] = None + best_trial_params: Optional[Dict[str, Any]] = None + if self.hyperopt: + best_trial_params = self.load_best_trial_params(pair) + + if model_params and isinstance(model_params.get("gamma"), (int, float)): + gamma = float(model_params.get("gamma")) + elif best_trial_params and isinstance( + best_trial_params.get("gamma"), (int, float) + ): + gamma = float(best_trial_params.get("gamma")) + elif hasattr(self.model, "gamma") and isinstance( + self.model.gamma, (int, float) + ): + gamma = float(self.model.gamma) + else: + model_params_gamma = self.get_model_params().get("gamma") + if isinstance(model_params_gamma, (int, float)): + gamma = float(model_params_gamma) + + if gamma is not None: + model_reward_parameters["potential_gamma"] = gamma + else: + logger.warning( + f"{pair}: No valid PBRS discount gamma resolved for environment" + ) + + return env_info + def set_train_and_eval_environments( self, data_dictionary: Dict[str, DataFrame], @@ -302,9 +341,6 @@ class ReforceXY(BaseReinforcementLearningModel): train_df = data_dictionary.get("train_features") test_df = data_dictionary.get("test_features") env_dict = self.pack_env_dict(dk.pair) - env_dict["config"]["freqai"]["rl_config"]["model_reward_parameters"][ - "potential_gamma" - ] = self.get_model_params().get("gamma") seed = self.get_model_params().get("seed", 42) if self.check_envs: @@ -571,7 +607,7 @@ class ReforceXY(BaseReinforcementLearningModel): self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs ) -> Any: """ - User customizable fit method + Model fitting method :param data_dictionary: dict = common data dictionary containing all train/test features/labels/weights. :param dk: FreqaiDatakitchen = data kitchen for current pair. :return: @@ -1080,27 +1116,8 @@ class ReforceXY(BaseReinforcementLearningModel): seed += trial.number set_random_seed(seed) env_info: Dict[str, Any] = ( - self.pack_env_dict(dk.pair) if env_info is None else env_info + self.pack_env_dict(dk.pair, model_params) if env_info is None else env_info ) - gamma: Optional[float] = None - best_trial_params: Optional[Dict[str, Any]] = None - if self.hyperopt: - best_trial_params = self.load_best_trial_params(dk.pair) - if model_params and isinstance(model_params.get("gamma"), (int, float)): - gamma = model_params.get("gamma") - elif best_trial_params: - gamma = best_trial_params.get("gamma") - elif hasattr(self.model, "gamma") and isinstance( - self.model.gamma, (int, float) - ): - gamma = self.model.gamma - elif isinstance(self.get_model_params().get("gamma"), (int, float)): - gamma = self.get_model_params().get("gamma") - if gamma is not None: - # Align RL agent gamma with PBRS gamma for consistent discount factor - env_info["config"]["freqai"]["rl_config"]["model_reward_parameters"][ - "potential_gamma" - ] = float(gamma) env_prefix = f"trial_{trial.number}_" if trial is not None else "" train_fns = [ -- 2.43.0