]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(qav3): add a few input guards
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 2 Aug 2025 13:29:46 +0000 (15:29 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sat, 2 Aug 2025 13:29:46 +0000 (15:29 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index e96ee617f48c2fa38ff9381a648b488e73b8b14e..96562a71078c83b9400620f0122c9203e483f2fc 100644 (file)
@@ -572,6 +572,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     def soft_extremum_min_max(
         pred_extrema: pd.Series, alpha: float
     ) -> tuple[float, float]:
+        if alpha < 0:
+            raise ValueError("alpha must be non-negative")
         pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max(
             pred_extrema
         )
@@ -1791,11 +1793,11 @@ def soft_extremum(series: pd.Series, alpha: float) -> float:
         return np.nan
     if np.isclose(alpha, 0):
         return np.mean(np_array)
-    scaled_data = alpha * np_array
-    max_scaled_data = np.max(scaled_data)
-    if np.isinf(max_scaled_data):
-        return np_array[np.argmax(scaled_data)]
-    shifted_exponentials = np.exp(scaled_data - max_scaled_data)
+    scaled_np_array = alpha * np_array
+    max_scaled_np_array = np.max(scaled_np_array)
+    if np.isinf(max_scaled_np_array):
+        return np_array[np.argmax(scaled_np_array)]
+    shifted_exponentials = np.exp(scaled_np_array - max_scaled_np_array)
     numerator = np.sum(np_array * shifted_exponentials)
     denominator = np.sum(shifted_exponentials)
     if denominator == 0:
index 7265b362a86673b5508a291c6562ccabc29e6782..a4bf4c045e16459b6e07430bdc5d47fd3ebb1271 100644 (file)
@@ -406,6 +406,10 @@ class QuickAdapterV3(IStrategy):
             self._label_params[pair]["label_natr_ratio"] = label_natr_ratio
 
     def get_label_natr_ratio_percent(self, pair: str, percent: float) -> float:
+        if not isinstance(percent, float) or not (0.0 <= percent <= 1.0):
+            raise ValueError(
+                f"Invalid percent value: {percent}. It should be a float between 0 and 1"
+            )
         return self.get_label_natr_ratio(pair) * percent
 
     @staticmethod