self.assertNotEqual(breakdown.exit_component, 0.0)
self.assertFinite(breakdown.total, name="breakdown.total")
- def test_basic_reward_calculation(self):
- context = self.make_ctx(
- pnl=self.TEST_PROFIT_TARGET,
- trade_duration=10,
- max_trade_duration=100,
- max_unrealized_profit=0.025,
- min_unrealized_profit=0.015,
- position=Positions.Long,
- action=Actions.Long_exit,
- )
- br = calculate_reward(
- context,
- self.DEFAULT_PARAMS,
- base_factor=self.TEST_BASE_FACTOR,
- profit_target=0.06,
- risk_reward_ratio=self.TEST_RR_HIGH,
- short_allowed=True,
- action_masking=True,
- )
- self.assertFinite(br.total, name="total")
- self.assertGreater(br.exit_component, 0)
-
def test_efficiency_zero_policy(self):
ctx = self.make_ctx(
pnl=0.0,
short_positions, 0, "Futures mode should allow short positions"
)
+ def test_get_float_param(self):
+ """Test float parameter extraction."""
+ params = {"test_float": 1.5, "test_int": 2, "test_str": "hello"}
+ self.assertEqual(_get_float_param(params, "test_float", 0.0), 1.5)
+ self.assertEqual(_get_float_param(params, "test_int", 0.0), 2.0)
+ # Non parseable string -> NaN fallback in tolerant parser
+ val_str = _get_float_param(params, "test_str", 0.0)
+ if isinstance(val_str, float) and math.isnan(val_str):
+ pass
+ else:
+ self.fail("Expected NaN for non-numeric string in _get_float_param")
+ self.assertEqual(_get_float_param(params, "missing", 3.14), 3.14)
+
+ def test_get_str_param(self):
+ """Test string parameter extraction."""
+ params = {"test_str": "hello", "test_int": 2}
+ self.assertEqual(_get_str_param(params, "test_str", "default"), "hello")
+ self.assertEqual(_get_str_param(params, "test_int", "default"), "default")
+ self.assertEqual(_get_str_param(params, "missing", "default"), "default")
+
+ def test_get_bool_param(self):
+ """Test boolean parameter extraction."""
+ params = {
+ "test_true": True,
+ "test_false": False,
+ "test_int": 1,
+ "test_str": "yes",
+ }
+ self.assertTrue(_get_bool_param(params, "test_true", False))
+ self.assertFalse(_get_bool_param(params, "test_false", True))
+ # Environment coerces typical truthy numeric/string values
+ self.assertTrue(_get_bool_param(params, "test_int", False))
+ self.assertTrue(_get_bool_param(params, "test_str", False))
+ self.assertFalse(_get_bool_param(params, "missing", False))
+
+ def test_get_int_param_coercions(self):
+ """Robust coercion paths of _get_int_param (bool/int/float/str/None/unsupported)."""
+ # None with numeric default
+ self.assertEqual(_get_int_param({"k": None}, "k", 5), 5)
+ # None with non-numeric default -> 0 fallback
+ self.assertEqual(_get_int_param({"k": None}, "k", "x"), 0)
+ # Booleans
+ self.assertEqual(_get_int_param({"k": True}, "k", 0), 1)
+ self.assertEqual(_get_int_param({"k": False}, "k", 7), 0)
+ # Int passthrough
+ self.assertEqual(_get_int_param({"k": -12}, "k", 0), -12)
+ # Float truncation & negative
+ self.assertEqual(_get_int_param({"k": 9.99}, "k", 0), 9)
+ self.assertEqual(_get_int_param({"k": -3.7}, "k", 0), -3)
+ # Non-finite floats fallback
+ self.assertEqual(_get_int_param({"k": float("nan")}, "k", 4), 4)
+ self.assertEqual(_get_int_param({"k": float("inf")}, "k", 4), 4)
+ # String numerics (int, float, exponent)
+ self.assertEqual(_get_int_param({"k": "42"}, "k", 0), 42)
+ self.assertEqual(_get_int_param({"k": " 17 "}, "k", 0), 17)
+ self.assertEqual(_get_int_param({"k": "3.9"}, "k", 0), 3)
+ self.assertEqual(_get_int_param({"k": "1e2"}, "k", 0), 100)
+ # String fallbacks (empty, invalid, NaN token)
+ self.assertEqual(_get_int_param({"k": ""}, "k", 5), 5)
+ self.assertEqual(_get_int_param({"k": "abc"}, "k", 5), 5)
+ self.assertEqual(_get_int_param({"k": "NaN"}, "k", 5), 5)
+ # Unsupported type
+ self.assertEqual(_get_int_param({"k": [1, 2, 3]}, "k", 3), 3)
+ # Missing key with non-numeric default
+ self.assertEqual(_get_int_param({}, "missing", "zzz"), 0)
+
def test_argument_parser_construction(self):
"""Test build_argument_parser function."""
tolerance=self.TOL_IDENTITY_RELAXED,
)
- def test_get_float_param(self):
- """Test float parameter extraction."""
- params = {"test_float": 1.5, "test_int": 2, "test_str": "hello"}
- self.assertEqual(_get_float_param(params, "test_float", 0.0), 1.5)
- self.assertEqual(_get_float_param(params, "test_int", 0.0), 2.0)
- # Non parseable string -> NaN fallback in tolerant parser
- val_str = _get_float_param(params, "test_str", 0.0)
- if isinstance(val_str, float) and math.isnan(val_str):
- pass
- else:
- self.fail("Expected NaN for non-numeric string in _get_float_param")
- self.assertEqual(_get_float_param(params, "missing", 3.14), 3.14)
-
- def test_get_str_param(self):
- """Test string parameter extraction."""
- params = {"test_str": "hello", "test_int": 2}
- self.assertEqual(_get_str_param(params, "test_str", "default"), "hello")
- self.assertEqual(_get_str_param(params, "test_int", "default"), "default")
- self.assertEqual(_get_str_param(params, "missing", "default"), "default")
-
- def test_get_bool_param(self):
- """Test boolean parameter extraction."""
- params = {
- "test_true": True,
- "test_false": False,
- "test_int": 1,
- "test_str": "yes",
- }
- self.assertTrue(_get_bool_param(params, "test_true", False))
- self.assertFalse(_get_bool_param(params, "test_false", True))
- # Environment coerces typical truthy numeric/string values
- self.assertTrue(_get_bool_param(params, "test_int", False))
- self.assertTrue(_get_bool_param(params, "test_str", False))
- self.assertFalse(_get_bool_param(params, "missing", False))
-
def test_hold_potential_basic(self):
"""Test basic hold potential calculation."""
params = {
"Canonical shaping magnitude should exceed spike_cancel",
)
- def test_get_int_param_coercions(self):
- """Robust coercion paths of _get_int_param (bool/int/float/str/None/unsupported)."""
- # None with numeric default
- self.assertEqual(_get_int_param({"k": None}, "k", 5), 5)
- # None with non-numeric default -> 0 fallback
- self.assertEqual(_get_int_param({"k": None}, "k", "x"), 0)
- # Booleans
- self.assertEqual(_get_int_param({"k": True}, "k", 0), 1)
- self.assertEqual(_get_int_param({"k": False}, "k", 7), 0)
- # Int passthrough
- self.assertEqual(_get_int_param({"k": -12}, "k", 0), -12)
- # Float truncation & negative
- self.assertEqual(_get_int_param({"k": 9.99}, "k", 0), 9)
- self.assertEqual(_get_int_param({"k": -3.7}, "k", 0), -3)
- # Non-finite floats fallback
- self.assertEqual(_get_int_param({"k": float("nan")}, "k", 4), 4)
- self.assertEqual(_get_int_param({"k": float("inf")}, "k", 4), 4)
- # String numerics (int, float, exponent)
- self.assertEqual(_get_int_param({"k": "42"}, "k", 0), 42)
- self.assertEqual(_get_int_param({"k": " 17 "}, "k", 0), 17)
- self.assertEqual(_get_int_param({"k": "3.9"}, "k", 0), 3)
- self.assertEqual(_get_int_param({"k": "1e2"}, "k", 0), 100)
- # String fallbacks (empty, invalid, NaN token)
- self.assertEqual(_get_int_param({"k": ""}, "k", 5), 5)
- self.assertEqual(_get_int_param({"k": "abc"}, "k", 5), 5)
- self.assertEqual(_get_int_param({"k": "NaN"}, "k", 5), 5)
- # Unsupported type
- self.assertEqual(_get_int_param({"k": [1, 2, 3]}, "k", 3), 3)
- # Missing key with non-numeric default
- self.assertEqual(_get_int_param({}, "missing", "zzz"), 0)
-
def test_transform_bulk_monotonicity_and_bounds(self):
"""Non-decreasing monotonicity & (-1,1) bounds for smooth transforms (excluding clip)."""
transforms = ["tanh", "softsign", "arctan", "sigmoid", "asinh"]
)
self.continual_learning = False
+ def pack_env_dict(
+ self, pair: str, model_params: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
+ env_info = super().pack_env_dict(pair)
+
+ config = env_info.setdefault("config", {})
+ freqai_cfg = config.setdefault("freqai", {})
+ rl_cfg = freqai_cfg.setdefault("rl_config", {})
+ model_reward_parameters = rl_cfg.setdefault("model_reward_parameters", {})
+
+ gamma: Optional[float] = None
+ best_trial_params: Optional[Dict[str, Any]] = None
+ if self.hyperopt:
+ best_trial_params = self.load_best_trial_params(pair)
+
+ if model_params and isinstance(model_params.get("gamma"), (int, float)):
+ gamma = float(model_params.get("gamma"))
+ elif best_trial_params and isinstance(
+ best_trial_params.get("gamma"), (int, float)
+ ):
+ gamma = float(best_trial_params.get("gamma"))
+ elif hasattr(self.model, "gamma") and isinstance(
+ self.model.gamma, (int, float)
+ ):
+ gamma = float(self.model.gamma)
+ else:
+ model_params_gamma = self.get_model_params().get("gamma")
+ if isinstance(model_params_gamma, (int, float)):
+ gamma = float(model_params_gamma)
+
+ if gamma is not None:
+ model_reward_parameters["potential_gamma"] = gamma
+ else:
+ logger.warning(
+ f"{pair}: No valid PBRS discount gamma resolved for environment"
+ )
+
+ return env_info
+
def set_train_and_eval_environments(
self,
data_dictionary: Dict[str, DataFrame],
train_df = data_dictionary.get("train_features")
test_df = data_dictionary.get("test_features")
env_dict = self.pack_env_dict(dk.pair)
- env_dict["config"]["freqai"]["rl_config"]["model_reward_parameters"][
- "potential_gamma"
- ] = self.get_model_params().get("gamma")
seed = self.get_model_params().get("seed", 42)
if self.check_envs:
self, data_dictionary: Dict[str, Any], dk: FreqaiDataKitchen, **kwargs
) -> Any:
"""
- User customizable fit method
+ Model fitting method
:param data_dictionary: dict = common data dictionary containing all train/test features/labels/weights.
:param dk: FreqaiDatakitchen = data kitchen for current pair.
:return:
seed += trial.number
set_random_seed(seed)
env_info: Dict[str, Any] = (
- self.pack_env_dict(dk.pair) if env_info is None else env_info
+ self.pack_env_dict(dk.pair, model_params) if env_info is None else env_info
)
- gamma: Optional[float] = None
- best_trial_params: Optional[Dict[str, Any]] = None
- if self.hyperopt:
- best_trial_params = self.load_best_trial_params(dk.pair)
- if model_params and isinstance(model_params.get("gamma"), (int, float)):
- gamma = model_params.get("gamma")
- elif best_trial_params:
- gamma = best_trial_params.get("gamma")
- elif hasattr(self.model, "gamma") and isinstance(
- self.model.gamma, (int, float)
- ):
- gamma = self.model.gamma
- elif isinstance(self.get_model_params().get("gamma"), (int, float)):
- gamma = self.get_model_params().get("gamma")
- if gamma is not None:
- # Align RL agent gamma with PBRS gamma for consistent discount factor
- env_info["config"]["freqai"]["rl_config"]["model_reward_parameters"][
- "potential_gamma"
- ] = float(gamma)
env_prefix = f"trial_{trial.number}_" if trial is not None else ""
train_fns = [