Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): more typing
author Jérôme Benoit <jerome.benoit@sap.com>
Fri, 21 Feb 2025 23:08:34 +0000 (00:08 +0100)
committer Jérôme Benoit <jerome.benoit@sap.com>
Fri, 21 Feb 2025 23:08:34 +0000 (00:08 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@sap.com>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index e57270f212b6f6069ccb49e6e918abf017e13c5d..3f0bd927b7b71bff86879aebc8e812aa5e434588 100644
@@ -467,7 +467,9 @@ class ReforceXY(BaseReinforcementLearningModel):
             )
         return storage
 
-    def study(self, train_df, total_timesteps: int, dk: FreqaiDataKitchen) -> Dict:
+    def study(
+        self, train_df: DataFrame, total_timesteps: int, dk: FreqaiDataKitchen
+    ) -> Dict:
         """
         Runs hyperparameter optimization using Optuna and
         returns the best hyperparameters found
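The docstring above describes the Optuna pattern that study() wraps. A minimal standalone sketch of that pattern, with a hypothetical search space standing in for whatever ReforceXY actually tunes:

import optuna
from optuna.trial import Trial

def objective(trial: Trial) -> float:
    # Hypothetical search space; the real one lives in ReforceXY.objective.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    gamma = trial.suggest_float("gamma", 0.9, 0.9999)
    # A real trial would train the RL model and return its evaluation score.
    return learning_rate * gamma

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)
print(study.best_params)  # the "best hyperparameters found"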
@@ -549,7 +551,11 @@ class ReforceXY(BaseReinforcementLearningModel):
         return None
 
     def objective(
-        self, trial: Trial, train_df, total_timesteps: int, dk: FreqaiDataKitchen
+        self,
+        trial: Trial,
+        train_df: DataFrame,
+        total_timesteps: int,
+        dk: FreqaiDataKitchen,
     ) -> float:
         """
         Defines a single trial for hyperparameter optimization using Optuna
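Optuna's Study.optimize expects a callable that takes only a Trial, so a multi-argument objective like the one above is typically bound with functools.partial (or a lambda) before being handed to Optuna. A sketch under that assumption, with placeholder arguments in place of the real train_df and dk:

from functools import partial

import optuna
from optuna.trial import Trial

def objective(trial: Trial, train_df, total_timesteps: int) -> float:
    # Placeholder body; the real objective trains on train_df for
    # total_timesteps and returns the trial's score.
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    return lr

study = optuna.create_study(direction="maximize")
# Bind everything except `trial`, mirroring the signature in the diff.
study.optimize(partial(objective, train_df=None, total_timesteps=1_000), n_trials=5)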
@@ -676,7 +682,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             self._non_profit_steps: int = 0
             return self._get_observation(), history
 
-        def get_reward_factor_at_trade_exit(
+        def _get_reward_factor_at_trade_exit(
             self,
             factor: float,
             pnl: float,
@@ -730,7 +736,7 @@ class ReforceXY(BaseReinforcementLearningModel):
                 ForceActions.Stop_loss,
                 ForceActions.Timeout,
             ):
-                return pnl * self.get_reward_factor_at_trade_exit(
+                return pnl * self._get_reward_factor_at_trade_exit(
                     factor, pnl, trade_duration, max_trade_duration
                 )
 
@@ -784,13 +790,13 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             # close long
             if action == Actions.Long_exit.value and self._position == Positions.Long:
-                return pnl * self.get_reward_factor_at_trade_exit(
+                return pnl * self._get_reward_factor_at_trade_exit(
                     factor, pnl, trade_duration, max_trade_duration
                 )
 
             # close short
             if action == Actions.Short_exit.value and self._position == Positions.Short:
-                return pnl * self.get_reward_factor_at_trade_exit(
+                return pnl * self._get_reward_factor_at_trade_exit(
                     factor, pnl, trade_duration, max_trade_duration
                 )
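All three exit paths (force exit, long exit, short exit) scale pnl by the same helper, whose leading-underscore rename marks it as internal to the environment class. The diff does not show the helper's body; the following is a purely illustrative sketch of a duration-based exit factor, where the attenuation rule is an assumption, not the repository's formula:

def get_reward_factor_at_trade_exit(
    factor: float, pnl: float, trade_duration: int, max_trade_duration: int
) -> float:
    # Illustrative only: damp the base factor as the trade approaches its
    # maximum duration; losing exits keep the unscaled factor.
    duration_ratio = trade_duration / max(max_trade_duration, 1)
    if pnl <= 0.0:
        return factor
    return factor * max(1.0 - duration_ratio, 0.0)

# Mirrors the call sites above: reward = pnl * exit_factor
reward = 0.02 * get_reward_factor_at_trade_exit(
    factor=100.0, pnl=0.02, trade_duration=50, max_trade_duration=100
)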