From 42dda4ed870848916cdf392ac54dd792b543638e Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Sat, 2 Aug 2025 15:29:46 +0200 Subject: [PATCH] refactor(qav3): add a few input guards MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Jérôme Benoit --- .../freqaimodels/QuickAdapterRegressorV3.py | 12 +++++++----- quickadapter/user_data/strategies/QuickAdapterV3.py | 4 ++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py index e96ee61..96562a7 100644 --- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py +++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py @@ -572,6 +572,8 @@ class QuickAdapterRegressorV3(BaseRegressionModel): def soft_extremum_min_max( pred_extrema: pd.Series, alpha: float ) -> tuple[float, float]: + if alpha < 0: + raise ValueError("alpha must be non-negative") pred_minima, pred_maxima = QuickAdapterRegressorV3.get_pred_min_max( pred_extrema ) @@ -1791,11 +1793,11 @@ def soft_extremum(series: pd.Series, alpha: float) -> float: return np.nan if np.isclose(alpha, 0): return np.mean(np_array) - scaled_data = alpha * np_array - max_scaled_data = np.max(scaled_data) - if np.isinf(max_scaled_data): - return np_array[np.argmax(scaled_data)] - shifted_exponentials = np.exp(scaled_data - max_scaled_data) + scaled_np_array = alpha * np_array + max_scaled_np_array = np.max(scaled_np_array) + if np.isinf(max_scaled_np_array): + return np_array[np.argmax(scaled_np_array)] + shifted_exponentials = np.exp(scaled_np_array - max_scaled_np_array) numerator = np.sum(np_array * shifted_exponentials) denominator = np.sum(shifted_exponentials) if denominator == 0: diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py index 7265b36..a4bf4c0 100644 --- a/quickadapter/user_data/strategies/QuickAdapterV3.py +++ b/quickadapter/user_data/strategies/QuickAdapterV3.py @@ -406,6 +406,10 @@ class QuickAdapterV3(IStrategy): self._label_params[pair]["label_natr_ratio"] = label_natr_ratio def get_label_natr_ratio_percent(self, pair: str, percent: float) -> float: + if not isinstance(percent, float) or not (0.0 <= percent <= 1.0): + raise ValueError( + f"Invalid percent value: {percent}. It should be a float between 0 and 1" + ) return self.get_label_natr_ratio(pair) * percent @staticmethod -- 2.43.0