Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): idle duration computation
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 16 Sep 2025 19:02:07 +0000 (21:02 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Tue, 16 Sep 2025 19:02:07 +0000 (21:02 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
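
The core of the idle-duration fix is in _exit_trade(): _last_closed_trade_tick now records the tick at which the position is actually closed (self._current_tick) instead of the entry tick of the closed trade (self._last_trade_tick). A minimal sketch of the intended effect, assuming a hypothetical get_idle_duration() helper that is not part of this diff:

    # Hypothetical helper (not shown in this diff): idle duration counts the ticks
    # spent flat (Positions.Neutral) since the last trade was closed.
    def get_idle_duration(self) -> int:
        if self._last_closed_trade_tick is None:
            return 0
        # Before this fix, _last_closed_trade_tick held the entry tick of the
        # closed trade, inflating the idle duration by the whole trade duration.
        return self._current_tick - self._last_closed_trade_tick

The step() reordering below follows from the same concern: the trade is now executed before the info dict is built, so position, pnl, trade_duration and trade_count are reported once, from post-trade state, instead of being written and then overwritten.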
ReforceXY/user_data/freqaimodels/ReforceXY.py

index d8b099bea70a8097602348e5d825cec92b4e227f..d8417d50a56f04c50904e586020f13751110de5d 100644
@@ -549,7 +549,6 @@ class ReforceXY(BaseReinforcementLearningModel):
                 if self.live:
                     position, pnl, trade_duration = self.get_state_info(dk.pair)
                 else:
-                    # TODO: state info handling for backtests
                     position = Positions.Neutral
                     pnl = 0.0
                     trade_duration = 0
@@ -1123,21 +1122,21 @@ class ReforceXY(BaseReinforcementLearningModel):
                 return ForceActions.Stop_loss
             return None
 
-        def _get_new_position(self, action: int) -> Positions:
+        def _get_position(self, action: int) -> Positions:
             return {
                 Actions.Long_enter.value: Positions.Long,
                 Actions.Short_enter.value: Positions.Short,
             }[action]
 
         def _enter_trade(self, action: int) -> None:
-            self._position = self._get_new_position(action)
+            self._position = self._get_position(action)
             self._last_trade_tick = self._current_tick
 
         def _exit_trade(self) -> None:
             self._update_total_profit()
             self._last_closed_position = self._position
             self._position = Positions.Neutral
-            self._last_closed_trade_tick = self._last_trade_tick
+            self._last_closed_trade_tick = self._current_tick
             self._last_trade_tick = None
 
         def execute_trade(self, action: int) -> None:
@@ -1172,12 +1171,14 @@ class ReforceXY(BaseReinforcementLearningModel):
             """
             Take a step in the environment based on the provided action
             """
-            self.tensorboard_log(Actions._member_names_[action], category="actions")
             self._current_tick += 1
             self._update_unrealized_total_profit()
             self._force_action = self._get_force_action()
             reward = self.calculate_reward(action)
             self.total_reward += reward
+            self.tensorboard_log(Actions._member_names_[action], category="actions")
+            self.execute_trade(action)
+            self._position_history.append(self._position)
             info = {
                 "tick": self._current_tick,
                 "position": self._position.value,
@@ -1193,12 +1194,6 @@ class ReforceXY(BaseReinforcementLearningModel):
                 "trade_duration": self.get_trade_duration(),
                 "trade_count": len(self.trade_history),
             }
-            self.execute_trade(action)
-            info["position"] = self._position.value
-            info["pnl"] = self.get_unrealized_profit()
-            info["trade_duration"] = self.get_trade_duration()
-            info["trade_count"] = len(self.trade_history)
-            self._position_history.append(self._position)
             self._update_history(info)
             return (
                 self._get_observation(),
@@ -1482,15 +1477,15 @@ class InfoMetricsCallback(TensorboardCallback):
             float(_lr) if isinstance(_lr, (int, float, np.floating)) else "lr_schedule"
         )
         n_stack = 1
-        training_env = getattr(self, "training_env", None)
-        while training_env is not None:
-            if hasattr(training_env, "n_stack"):
+        env = getattr(self, "training_env", None)
+        while env is not None:
+            if hasattr(env, "n_stack"):
                 try:
-                    n_stack = int(getattr(training_env, "n_stack"))
+                    n_stack = int(getattr(env, "n_stack"))
                 except Exception:
                     pass
                 break
-            training_env = getattr(training_env, "venv", None)
+            env = getattr(env, "venv", None)
         hparam_dict: Dict[str, Any] = {
             "algorithm": self.model.__class__.__name__,
             "n_envs": int(self.model.n_envs),