]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor: raise error is no pairs are defined main
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 22 Feb 2025 12:13:40 +0000 (13:13 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 22 Feb 2025 12:13:40 +0000 (13:13 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py

index 3f0bd927b7b71bff86879aebc8e812aa5e434588..d250b95940e5e74e0df37cd3ce1d52c3d0275d6c 100644 (file)
@@ -100,6 +100,10 @@ class ReforceXY(BaseReinforcementLearningModel):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.pairs = self.config.get("exchange", {}).get("pair_whitelist")
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.pairs = self.config.get("exchange", {}).get("pair_whitelist")
+        if not self.pairs:
+            raise ValueError(
+                "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
+            )
         self.is_maskable: bool = (
             self.model_type == "MaskablePPO"
         )  # Enable action masking
         self.is_maskable: bool = (
             self.model_type == "MaskablePPO"
         )  # Enable action masking
@@ -697,8 +701,10 @@ class ReforceXY(BaseReinforcementLearningModel):
             elif trade_duration > max_trade_duration:
                 factor *= 0.5
             if pnl > self.profit_aim * self.rr:
             elif trade_duration > max_trade_duration:
                 factor *= 0.5
             if pnl > self.profit_aim * self.rr:
-                factor *= self.rl_config.get("model_reward_parameters", {}).get(
-                    "win_reward_factor", 2
+                factor *= float(
+                    self.rl_config.get("model_reward_parameters", {}).get(
+                        "win_reward_factor", 2.0
+                    )
                 )
             return factor
 
                 )
             return factor
 
@@ -719,7 +725,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             """
             # first, penalize if the action is not valid
             if not self._force_action and not self._is_valid(action):
             """
             # first, penalize if the action is not valid
             if not self._force_action and not self._is_valid(action):
-                return -2
+                return -2.0
 
             pnl = self.get_unrealized_profit()
             # mrr = self.get_most_recent_return()
 
             pnl = self.get_unrealized_profit()
             # mrr = self.get_most_recent_return()
@@ -763,16 +769,20 @@ class ReforceXY(BaseReinforcementLearningModel):
                 action == Actions.Long_enter.value
                 and self._position == Positions.Neutral
             ):
                 action == Actions.Long_enter.value
                 and self._position == Positions.Neutral
             ):
-                return 25
+                return 25.0
             if (
                 action == Actions.Short_enter.value
                 and self._position == Positions.Neutral
             ):
             if (
                 action == Actions.Short_enter.value
                 and self._position == Positions.Neutral
             ):
-                return 25
+                return 25.0
 
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
 
             # discourage agent from not entering trades
             if action == Actions.Neutral.value and self._position == Positions.Neutral:
-                return -1
+                return float(
+                    self.rl_config.get("model_reward_parameters", {}).get(
+                        "inaction", -1.0
+                    )
+                )
 
             # discourage sitting in position
             if (
 
             # discourage sitting in position
             if (
index f213deca48e35bfdb1e7fa695691bc25b7d3499a..7ca72fba300bb74ba85d4649b2df8280f720aef9 100644 (file)
@@ -46,6 +46,10 @@ class LightGBMRegressorQuickAdapterV35(BaseRegressionModel):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.pairs = self.config.get("exchange", {}).get("pair_whitelist")
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.pairs = self.config.get("exchange", {}).get("pair_whitelist")
+        if not self.pairs:
+            raise ValueError(
+                "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
+            )
         self.__optuna_config = self.freqai_info.get("optuna_hyperopt", {})
         self.__optuna_hyperopt: bool = (
             self.freqai_info.get("enabled", False)
         self.__optuna_config = self.freqai_info.get("optuna_hyperopt", {})
         self.__optuna_hyperopt: bool = (
             self.freqai_info.get("enabled", False)
index 8e4a3a209a520be49891a17c9e8c7bacf4320ef2..ebcb25b9f0a716b2daf64b731ac74f2e14839da3 100644 (file)
@@ -46,6 +46,10 @@ class XGBoostRegressorQuickAdapterV35(BaseRegressionModel):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.pairs = self.config.get("exchange", {}).get("pair_whitelist")
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.pairs = self.config.get("exchange", {}).get("pair_whitelist")
+        if not self.pairs:
+            raise ValueError(
+                "FreqAI model requires StaticPairList method defined in pairlists configuration and pair_whitelist defined in exchange section configuration"
+            )
         self.__optuna_config = self.freqai_info.get("optuna_hyperopt", {})
         self.__optuna_hyperopt: bool = (
             self.freqai_info.get("enabled", False)
         self.__optuna_config = self.freqai_info.get("optuna_hyperopt", {})
         self.__optuna_hyperopt: bool = (
             self.freqai_info.get("enabled", False)