Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(reforcexy): add spot trading support at model training
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 26 Sep 2025 18:29:59 +0000 (20:29 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Fri, 26 Sep 2025 18:29:59 +0000 (20:29 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
ReforceXY/user_data/strategies/RLAgentStrategy.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index 9fa7d48ab625a892d834f54f219c3e81cc578eba..60cd86a1457569d66fde7846ec93639fdb43a290 100644
@@ -159,6 +159,15 @@ class ReforceXY(BaseReinforcementLearningModel):
         self._model_params_cache: Optional[Dict[str, Any]] = None
         self.unset_unsupported()
 
+    @staticmethod
+    def is_short_allowed(trading_mode: str) -> bool:
+        if trading_mode in {"margin", "futures"}:
+            return True
+        elif trading_mode == "spot":
+            return False
+        else:
+            raise ValueError(f"Invalid trading_mode: {trading_mode}")
+
     @staticmethod
     def _normalize_position(position: Any) -> Positions:
         if isinstance(position, Positions):
@@ -175,8 +184,11 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     @staticmethod
     def get_action_masks(
-        position: Positions, force_action: Optional[ForceActions] = None
+        trading_mode: str,
+        position: Positions,
+        force_action: Optional[ForceActions] = None,
     ) -> NDArray[np.bool_]:
+        is_short_allowed = ReforceXY.is_short_allowed(trading_mode)
         position = ReforceXY._normalize_position(position)
 
         action_masks = np.zeros(len(Actions), dtype=np.bool_)
@@ -191,7 +203,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         action_masks[Actions.Neutral.value] = True
         if position == Positions.Neutral:
             action_masks[Actions.Long_enter.value] = True
-            action_masks[Actions.Short_enter.value] = True
+            if is_short_allowed:
+                action_masks[Actions.Short_enter.value] = True
         elif position == Positions.Long:
             action_masks[Actions.Long_exit.value] = True
         elif position == Positions.Short:
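
A minimal sketch of the new masking behaviour (Actions, Positions and ReforceXY as in the hunks above; the assertions are illustrative):

    # Neutral position: spot mode masks out short entries, margin/futures do not.
    masks_spot = ReforceXY.get_action_masks("spot", Positions.Neutral)
    masks_futures = ReforceXY.get_action_masks("futures", Positions.Neutral)
    assert masks_spot[Actions.Long_enter.value]
    assert not masks_spot[Actions.Short_enter.value]
    assert masks_futures[Actions.Short_enter.value]

    # An unrecognized mode fails fast instead of silently permitting shorts.
    try:
        ReforceXY.get_action_masks("options", Positions.Neutral)
    except ValueError:
        pass
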
@@ -615,9 +628,9 @@ class ReforceXY(BaseReinforcementLearningModel):
         :param model: Any = the trained model used to run inference on the features.
         """
 
-        simulated_position: Positions = Positions.Neutral
+        virtual_position: Positions = Positions.Neutral
 
-        def _update_simulated_position(action: int, position: Positions) -> Positions:
+        def _update_virtual_position(action: int, position: Positions) -> Positions:
             if action == Actions.Long_enter.value and position == Positions.Neutral:
                 return Positions.Long
             if action == Actions.Short_enter.value and position == Positions.Neutral:
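
Only the entry transitions are visible in this hunk; a plausible complete transition function (the exit branches below are an assumption, not shown in the diff):

    def _update_virtual_position(action: int, position: Positions) -> Positions:
        if action == Actions.Long_enter.value and position == Positions.Neutral:
            return Positions.Long
        if action == Actions.Short_enter.value and position == Positions.Neutral:
            return Positions.Short
        # Assumed: exit actions flatten the matching position back to neutral.
        if action == Actions.Long_exit.value and position == Positions.Long:
            return Positions.Neutral
        if action == Actions.Short_exit.value and position == Positions.Short:
            return Positions.Neutral
        return position
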
@@ -668,7 +681,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             if self.action_masking and self.inference_masking:
                 action_masks_param["action_masks"] = ReforceXY.get_action_masks(
-                    simulated_position
+                    self.config.get("trading_mode"), virtual_position
                 )
 
             action, _ = model.predict(
@@ -681,7 +694,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             window = dataframe.iloc[window_end - self.CONV_WIDTH : window_end]
             action = _predict(window)
             predicted_actions.append(action)
-            simulated_position = _update_simulated_position(action, simulated_position)
+            virtual_position = _update_virtual_position(action, virtual_position)
 
         pad = [np.nan] * (self.CONV_WIDTH - 1)
         actions_list = pad + predicted_actions
@@ -740,7 +753,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             else self.get_storage()
         )
         if "PPO" in self.model_type:
-            resource_eval_freq = max(PPO_N_STEPS)
+            resource_eval_freq = min(PPO_N_STEPS)
         else:
             resource_eval_freq = self.get_eval_freq(total_timesteps, hyperopt=True)
         reduction_factor = 3
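
A plausible rationale for switching max() to min() (PPO_N_STEPS is assumed to hold the candidate n_steps values sampled during hyperopt): eval_freq is a period in timesteps, so the smaller it is, the more intermediate reports the pruner receives.

    # Hypothetical candidate grid; the real PPO_N_STEPS lives elsewhere in the module.
    PPO_N_STEPS = (512, 1024, 2048)

    resource_eval_freq = min(PPO_N_STEPS)  # evaluate every 512 steps
    # max(PPO_N_STEPS) would have evaluated only every 2048 steps, starving
    # the pruner of intermediate values on trials that sample small n_steps.
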
@@ -955,13 +968,13 @@ class ReforceXY(BaseReinforcementLearningModel):
             train_env = DummyVecEnv(train_fns)
             eval_env = DummyVecEnv(eval_fns)
 
-        train_env = VecMonitor(train_env)
-        eval_env = VecMonitor(eval_env)
-
         if self.frame_stacking:
             train_env = VecFrameStack(train_env, n_stack=self.frame_stacking)
             eval_env = VecFrameStack(eval_env, n_stack=self.frame_stacking)
 
+        train_env = VecMonitor(train_env)
+        eval_env = VecMonitor(eval_env)
+
         return train_env, eval_env
 
     def objective(
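
The reorder makes VecMonitor the outermost wrapper, presumably so that episode statistics are recorded on the same wrapped env the agent actually steps. A sketch of the resulting wrapping order, assuming stable-baselines3 (env_fns and n_stack are illustrative):

    from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecMonitor

    env = DummyVecEnv(env_fns)           # base vectorized env
    env = VecFrameStack(env, n_stack=4)  # stack observations first
    env = VecMonitor(env)                # monitor last: callbacks such as
                                         # EvalCallback read episode info here
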
@@ -1159,7 +1172,9 @@ class MyRLEnv(Base5ActionRLEnv):
         )
 
     def _is_valid(self, action: int) -> bool:
-        return ReforceXY.get_action_masks(self._position, self._force_action)[action]
+        return ReforceXY.get_action_masks(
+            self.config.get("trading_mode"), self._position, self._force_action
+        )[action]
 
     def reset_env(
         self,
@@ -1558,7 +1573,9 @@ class MyRLEnv(Base5ActionRLEnv):
         )
 
     def action_masks(self) -> NDArray[np.bool_]:
-        return ReforceXY.get_action_masks(self._position, self._force_action)
+        return ReforceXY.get_action_masks(
+            self.config.get("trading_mode"), self._position, self._force_action
+        )
 
     def get_feature_value(
         self,
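
For context, sb3-contrib's maskable rollouts retrieve masks by calling the env's action_masks() method, which is why the env exposes one; a minimal sketch of that contract (env, model and obs are illustrative):

    from sb3_contrib.common.maskable.utils import get_action_masks

    # Rollout collection calls env.action_masks() under the hood; the same
    # helper can be invoked explicitly at predict time.
    masks = get_action_masks(env)
    action, _ = model.predict(obs, action_masks=masks)
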
@@ -1719,25 +1736,12 @@ class MyRLEnv(Base5ActionRLEnv):
                 right_index=True,
                 how="left",
             )
-        except Exception:
-            try:
-                _price_history = (
-                    self.prices.iloc[_rollout_history.tick]
-                    .copy()
-                    .reset_index(drop=True)
-                )
-                history = merge(
-                    _rollout_history,
-                    _price_history,
-                    left_index=True,
-                    right_index=True,
-                )
-            except Exception as e:
-                logger.error(
-                    f"Failed to merge history with prices: {repr(e)}",
-                    exc_info=True,
-                )
-                return DataFrame()
+        except Exception as e:
+            logger.error(
+                f"Failed to merge history with prices: {repr(e)}",
+                exc_info=True,
+            )
+            return DataFrame()
         return history
 
     def get_env_plot(self) -> plt.Figure:
index 4139cb362d47163cd0d61584bbd3149c4151f760..4555f1bbfe3772e65776659baeea8100817f4a01 100644
@@ -109,7 +109,7 @@ class RLAgentStrategy(IStrategy):
 
     def is_short_allowed(self) -> bool:
         trading_mode = self.config.get("trading_mode")
-        if trading_mode == "margin" or trading_mode == "futures":
+        if trading_mode in {"margin", "futures"}:
             return True
         elif trading_mode == "spot":
             return False
index dbac8dd7f0d1f41c272216276c43baef2fb69c79..e48c2a3c236524ce0991f79af185b60e6ad2402f 100644
@@ -1574,7 +1574,7 @@ class QuickAdapterV3(IStrategy):
 
     def is_short_allowed(self) -> bool:
         trading_mode = self.config.get("trading_mode")
-        if trading_mode == "margin" or trading_mode == "futures":
+        if trading_mode in {"margin", "futures"}:
             return True
         elif trading_mode == "spot":
             return False
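
After this refactor, both strategy methods presumably read in full like the ReforceXY staticmethod in the first hunk (the else branch below is an assumption; it is not shown in either strategy hunk):

    def is_short_allowed(self) -> bool:
        trading_mode = self.config.get("trading_mode")
        if trading_mode in {"margin", "futures"}:
            return True
        elif trading_mode == "spot":
            return False
        else:
            # Assumed to mirror ReforceXY.is_short_allowed above.
            raise ValueError(f"Invalid trading_mode: {trading_mode}")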