Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): cleanup frame stacking implementation
author Jérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 17 Sep 2025 18:38:11 +0000 (20:38 +0200)
committer Jérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 17 Sep 2025 18:38:11 +0000 (20:38 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py
quickadapter/user_data/strategies/QuickAdapterV3.py

index ee4cf360e7ec439206b6d34c87bdb1e2f24b1efa..0bb58121f3fc89c3c835276854a26ec27cb6545c 100644 (file)
@@ -543,6 +543,8 @@ class ReforceXY(BaseReinforcementLearningModel):
         def _action_masks(position: Any) -> list[bool]:
             return [_is_valid(action.value, position) for action in Actions]
 
+        frame_buffer: list[np.ndarray] = []
+
         def _predict(window) -> int:
             observation: DataFrame = dataframe.iloc[window.index]
             action_masks_param: Dict[str, Any] = {}
@@ -565,17 +567,16 @@ class ReforceXY(BaseReinforcementLearningModel):
 
             np_observation = observation.to_numpy(dtype=np.float32)
 
+            fb: list[np.ndarray] = frame_buffer
             frame_stacking = self.frame_stacking
             if frame_stacking and frame_stacking > 1:
-                if not hasattr(_predict, "_frame_buffer"):
-                    _predict._frame_buffer = []
-                fb: list[np.ndarray] = getattr(_predict, "_frame_buffer")
-                fb.append(np_observation)
+                fb.append(np_observation.copy())
                 if len(fb) > frame_stacking:
                     del fb[0 : len(fb) - frame_stacking]
                 if len(fb) < frame_stacking:
-                    pad_needed = frame_stacking - len(fb)
-                    fb_padded = [fb[0]] * pad_needed + fb
+                    pad_count = frame_stacking - len(fb)
+                    pad_frame = fb[0] if fb else np_observation
+                    fb_padded = [pad_frame] * pad_count + fb
                 else:
                     fb_padded = fb
                 stacked_observations = np.concatenate(fb_padded, axis=1)
@@ -838,8 +839,8 @@ class ReforceXY(BaseReinforcementLearningModel):
             logger.warning("Optuna encountered NaN (AssertionError)")
             nan_encountered = True
         except ValueError as e:
-            if "nan" in str(e).lower():
-                logger.warning("Optuna encountered NaN (ValueError)")
+            if any(x in str(e).lower() for x in ("nan", "inf")):
+                logger.warning("Optuna encountered NaN/Inf (ValueError): %s", e)
                 nan_encountered = True
             else:
                 raise
@@ -847,7 +848,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             logger.warning("Optuna encountered NaN/Inf (FloatingPointError): %s", e)
             nan_encountered = True
         except RuntimeError as e:
-            if "nan" in str(e).lower() or "inf" in str(e).lower():
+            if any(x in str(e).lower() for x in ("nan", "inf")):
                 logger.warning("Optuna encountered NaN/Inf (RuntimeError): %s", e)
                 nan_encountered = True
             else:
@@ -1557,6 +1558,19 @@ class InfoMetricsCallback(TensorboardCallback):
         )
 
     def _on_step(self) -> bool:
+        def _is_numeric_non_bool(x: Any) -> bool:
+            return isinstance(
+                x, (int, float, np.integer, np.floating)
+            ) and not isinstance(x, bool)
+
+        def _is_finite_number(x: Any) -> bool:
+            if not _is_numeric_non_bool(x):
+                return False
+            try:
+                return np.isfinite(float(x))
+            except Exception:
+                return False
+
         infos_list: list[Dict[str, Any]] | None = self.locals.get("infos")
         aggregated_info: Dict[str, Any] = {}
 
@@ -1567,19 +1581,6 @@ class InfoMetricsCallback(TensorboardCallback):
             )
             filtered_values: int = 0
 
-            def _is_numeric_non_bool(x: Any) -> bool:
-                return isinstance(
-                    x, (int, float, np.integer, np.floating)
-                ) and not isinstance(x, bool)
-
-            def _is_finite_number(x: Any) -> bool:
-                if not _is_numeric_non_bool(x):
-                    return False
-                try:
-                    return np.isfinite(float(x))
-                except Exception:
-                    return False
-
             for info_dict in infos_list:
                 if not isinstance(info_dict, dict):
                     continue
@@ -1608,15 +1609,16 @@ class InfoMetricsCallback(TensorboardCallback):
                 values = numeric_acc.get(key)
                 if values:
                     try:
-                        aggregated_info[f"{key}_min"] = float(min(values))
-                        aggregated_info[f"{key}_max"] = float(max(values))
-                        percentiles = np.percentile(values, [25, 50, 75, 90])
+                        np_values = np.asarray(values, dtype=float)
+                        aggregated_info[f"{key}_min"] = float(np.min(np_values))
+                        aggregated_info[f"{key}_max"] = float(np.max(np_values))
+                        percentiles = np.percentile(np_values, [25, 50, 75, 90])
                         aggregated_info[f"{key}_p25"] = float(percentiles[0])
                         aggregated_info[f"{key}_p50"] = float(percentiles[1])
                         aggregated_info[f"{key}_p75"] = float(percentiles[2])
                         aggregated_info[f"{key}_p90"] = float(percentiles[3])
                         med = float(percentiles[1])
-                        mad = float(np.median(np.abs(np.array(values) - med)))
+                        mad = float(np.median(np.abs(np_values - med)))
                         aggregated_info[f"{key}_mad"] = mad
                     except Exception:
                         pass
@@ -1680,12 +1682,7 @@ class InfoMetricsCallback(TensorboardCallback):
                         cat_dict[metric] = base + v
                         cnt_dict[metric] = cnt_dict.get(metric, 0) + 1
                     else:
-                        if (
-                            aggregated_tensorboard_metric_counts.get(category, {}).get(
-                                metric, 0
-                            )
-                            == 0
-                        ):
+                        if cnt_dict.get(metric, 0) == 0:
                             cat_dict[metric] = value
 
         for metric, value in aggregated_info.items():
@@ -1893,6 +1890,7 @@ def deepmerge(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
     return dst_copy
 
 
+@lru_cache(maxsize=64)
 def linear_schedule(initial_value: float) -> Callable[[float], float]:
     def func(progress_remaining: float) -> float:
         return progress_remaining * initial_value
@@ -1900,6 +1898,13 @@ def linear_schedule(initial_value: float) -> Callable[[float], float]:
     return func
 
 
+@lru_cache(maxsize=128)
+def _compute_gradient_steps(tf: int, ss: int) -> int:
+    if tf > 0 and ss > 0:
+        return min(tf, max(tf // ss, 1))
+    return -1
+
+
 def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int:
     tf: Optional[int] = None
     if isinstance(train_freq, (tuple, list)) and train_freq:
@@ -1909,8 +1914,8 @@ def compute_gradient_steps(train_freq: Any, subsample_steps: Any) -> int:
 
     ss: Optional[int] = subsample_steps if isinstance(subsample_steps, int) else None
 
-    if isinstance(tf, int) and tf > 0 and isinstance(ss, int) and ss > 0:
-        return min(tf, max(tf // ss, 1))
+    if isinstance(tf, int) and isinstance(ss, int):
+        return _compute_gradient_steps(tf, ss)
     return -1
 
 
index e50a6f50bda28f91453b3767c3502defb24c27b3..3b682e604feae6b08b52c2212da0f02d40d64f28 100644 (file)
@@ -1117,8 +1117,8 @@ class QuickAdapterV3(IStrategy):
         side: str,
         order: Literal["entry", "exit"],
         rate: float,
-        min_natr_ratio_percent: float = 0.0085,
-        max_natr_ratio_percent: float = 0.085,
+        min_natr_ratio_percent: float = 0.0075,
+        max_natr_ratio_percent: float = 0.075,
         lookback_period: int = 1,
         decay_ratio: float = 0.5,
     ) -> bool: