From: Jérôme Benoit Date: Sat, 2 Aug 2025 13:29:46 +0000 (+0200) Subject: refactor(qav3): add a few input guards X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=42dda4ed870848916cdf392ac54dd792b543638e;p=freqai-strategies.git refactor(qav3): add a few input guards Signed-off-by: Jérôme Benoit --- diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index e96ee61..96562a7 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -572,6 +572,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel): def soft_extremum_min_max( pred_extrema: pd.Series, alpha: float ) -> tuple[float, float]: + if alpha < 0: + raise ValueError("alpha must be non-negative") pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max( pred_extrema ) @@ -1791,11 +1793,11 @@ def soft_extremum(series: pd.Series, alpha: float) -> float: return np.nan if np.isclose(alpha, 0): return np.mean(np_array) - scaled_data = alpha * np_array - max_scaled_data = np.max(scaled_data) - if np.isinf(max_scaled_data): - return np_array[np.argmax(scaled_data)] - shifted_exponentials = np.exp(scaled_data - max_scaled_data) + scaled_np_array = alpha * np_array + max_scaled_np_array = np.max(scaled_np_array) + if np.isinf(max_scaled_np_array): + return np_array[np.argmax(scaled_np_array)] + shifted_exponentials = np.exp(scaled_np_array - max_scaled_np_array) numerator = np.sum(np_array * shifted_exponentials) denominator = np.sum(shifted_exponentials) if denominator == 0: diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index 7265b36..a4bf4c0 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -406,6 +406,10 @@ class QuickAdapterV3(IStrategy): self._label_params[pair]["label_natr_ratio"] = label_natr_ratio def get_label_natr_ratio_percent(self, pair: str, percent: float) -> float: + if not isinstance(percent, float) or not (0.0 <= percent <= 1.0): + raise ValueError( + f"Invalid percent value: {percent}. It should be a float between 0 and 1" + ) return self.get_label_natr_ratio(pair) * percent @staticmethod