DEFAULTS_LABEL_SMOOTHING,
DEFAULTS_LABEL_WEIGHTING,
EXTREMA_SELECTION_METHODS,
+ FILL_EPSILON_BASELINES,
+ FILL_METHODS,
NORMALIZATION_TYPES,
PREDICTION_METHODS,
SMOOTHING_METHODS,
"softmax_temperature": _ParamSpec(
_NumericValidator(min_value=0, min_exclusive=True)
),
+ "fill_method": _ParamSpec(_EnumValidator(FILL_METHODS)),
+ "fill_epsilon": _ParamSpec(
+ _NumericValidator(min_value=0.0, max_value=1.0), output_type=float
+ ),
+ "fill_epsilon_baseline": _ParamSpec(_EnumValidator(FILL_EPSILON_BASELINES)),
+ "fill_sigma_candles": _ParamSpec(
+ _NumericValidator(min_value=0.5), output_type=float
+ ),
}
_PIPELINE_SPECS: Final[dict[str, _ParamSpec]] = {
return arr
safe = np.where(np.isfinite(arr) & (arr > 0.0), arr, 0.0)
if drop_mask is not None:
+ drop_mask = np.asarray(drop_mask)
+ if drop_mask.shape != arr.shape:
+ raise ValueError(
+ f"sanitize_and_renormalize: drop_mask shape "
+ f"{drop_mask.shape} != arr shape {arr.shape}"
+ )
+ if not np.issubdtype(drop_mask.dtype, np.bool_):
+ raise ValueError(
+ f"sanitize_and_renormalize: drop_mask dtype "
+ f"{drop_mask.dtype} is not boolean"
+ )
safe = np.where(drop_mask, 0.0, safe)
total = safe.sum()
+ rescale_overflow = False
if total > 0.0 and np.isfinite(total):
- return safe * (n / total)
+ c = n / total
+ if np.isfinite(c):
+ return safe * c
+ rescale_overflow = True
if logger is not None:
- logger.warning(
- "sanitize_and_renormalize: weights collapsed (context=%s, "
- "total=%r, n=%d); falling back to uniform weights",
- context or "unspecified",
- total,
- n,
- )
+ if rescale_overflow:
+ logger.warning(
+ "sanitize_and_renormalize: rescale factor non-finite "
+ "(context=%s, n=%d, total=%r); falling back to uniform "
+ "weights",
+ context or "unspecified",
+ n,
+ total,
+ )
+ else:
+ logger.warning(
+ "sanitize_and_renormalize: weights collapsed (context=%s, "
+ "total=%r, n=%d); falling back to uniform weights",
+ context or "unspecified",
+ total,
+ n,
+ )
fallback = np.ones(n, dtype=float)
if drop_mask is not None:
masked = np.where(drop_mask, 0.0, fallback)
return fallback
+_PIVOT_EQUIVALENT_MAX_FRACTION: Final[float] = 0.1
+
+
+def _pivot_equivalent_count(
+ label_weights: NDArray[np.floating],
+ drop_mask: NDArray[np.bool_],
+) -> int:
+ """Count rows whose label weight is at least a fraction of the surviving max.
+
+ A max-relative threshold (``_PIVOT_EQUIVALENT_MAX_FRACTION``) separates
+ pivot-class rows from off-pivot fill across the bimodal regimes that
+ ``fill_method`` introduces (where a median-based threshold would
+ saturate at ``N`` once the off-pivot floor dominates the median).
+ """
+ survivors = label_weights[~drop_mask]
+ if survivors.size == 0:
+ return 0
+ threshold = _PIVOT_EQUIVALENT_MAX_FRACTION * float(survivors.max())
+ if threshold <= 0.0:
+ return 0
+ return int((survivors >= threshold).sum())
+
+
def compose_sample_weights(
base_weights: NDArray[np.floating],
label_weights: NDArray[np.floating] | None,
"compose_sample_weights: all rows dropped by zero or non-finite "
"label weights; no surviving training samples"
)
- nonzero = int((~drop_mask).sum())
+ nonzero = _pivot_equivalent_count(arr, drop_mask)
if nonzero / n < SPARSE_TRAINING_MASS_THRESHOLD:
logger.warning(
"compose_sample_weights: sparse training mass "
- "(%d/%d rows = %.2f%% nonzero, threshold=%.2f%%)",
+ "(%d/%d rows above %.0f%% of surviving max = %.2f%%, "
+ "threshold=%.2f%%)",
nonzero,
n,
+ 100.0 * _PIVOT_EQUIVALENT_MAX_FRACTION,
100.0 * nonzero / n,
100.0 * SPARSE_TRAINING_MASS_THRESHOLD,
)
return weights
+_GAUSSIAN_FILL_CHUNK_BUDGET: Final[int] = 50_000_000
+_GAUSSIAN_FILL_DENSITY_WARN: Final[float] = 0.1
+
+
+def _gaussian_fill_weights(
+ n_values: int,
+ pivot_indices: NDArray[np.integer],
+ pivot_weights: NDArray[np.floating],
+ sigma_candles: float,
+ *,
+ logger: Logger | None = None,
+) -> NDArray[np.floating]:
+ """Per-row max of Gaussian-decayed pivot weights.
+
+ Out[i] = max over p of ``w_p * exp(-(i - p)**2 / (2 * sigma**2))``.
+ With clustered pivots within ``~sigma_candles``, the per-row max
+ lets a stronger neighbor dominate weaker ones; pick
+ ``sigma_candles <= label_period_candles / 2`` to preserve pivot
+ identity.
+ """
+ if sigma_candles < 0.5:
+ raise ValueError(
+ f"Invalid sigma_candles value {sigma_candles!r}: must be >= 0.5"
+ )
+ if pivot_indices.size == 0:
+ return np.zeros(n_values, dtype=float)
+ if np.any(pivot_weights < 0.0):
+ raise ValueError(
+ f"Invalid pivot_weights min={float(pivot_weights.min())!r}: "
+ f"must be >= 0"
+ )
+ pivot_indices_array = pivot_indices.astype(float)
+ pivot_weights_row = pivot_weights.astype(float)[np.newaxis, :]
+ inv_two_sigma_sq = 0.5 / (sigma_candles * sigma_candles)
+ M = pivot_indices_array.size
+ if (
+ logger is not None
+ and n_values > 0
+ and M / n_values > _GAUSSIAN_FILL_DENSITY_WARN
+ ):
+ logger.warning(
+ "gaussian_fill: pivot density M/N=%.3f > %.2f (M=%d, N=%d); "
+ "consider tightening zigzag detection",
+ M / n_values,
+ _GAUSSIAN_FILL_DENSITY_WARN,
+ M,
+ n_values,
+ )
+ chunk = max(1, _GAUSSIAN_FILL_CHUNK_BUDGET // max(M, 1))
+ if logger is not None and chunk < n_values:
+ logger.debug(
+ "gaussian_fill: N=%d, M=%d, chunk=%d, ~%.0f MB peak buffer",
+ n_values,
+ M,
+ chunk,
+ chunk * M * 8 / 1e6,
+ )
+ out = np.zeros(n_values, dtype=float)
+ for start in range(0, n_values, chunk):
+ stop = min(start + chunk, n_values)
+ positions = np.arange(start, stop, dtype=float)
+ buf = positions[:, np.newaxis] - pivot_indices_array[np.newaxis, :]
+ np.multiply(buf, buf, out=buf)
+ np.multiply(buf, -inv_two_sigma_sq, out=buf)
+ np.exp(buf, out=buf)
+ np.multiply(buf, pivot_weights_row, out=buf)
+ np.max(buf, axis=1, out=out[start:stop])
+ return out
+
+
def _scatter_weights(
n_values: int,
indices: list[int],
weights: NDArray[np.floating],
- default_weight: float,
+ fill_weights: NDArray[np.floating],
+ *,
+ indices_array: NDArray[np.integer] | None = None,
+ valid_mask: NDArray[np.bool_] | None = None,
) -> NDArray[np.floating]:
"""Scatter per-pivot weights into a full-length array.
- Non-pivot rows are filled with ``default_weight``. Callers pass ``0.0``
- to exclude non-pivot rows from training (pivot-only weighting), or a
- positive value to give them a baseline weight.
+ Pivot rows (validated via ``valid_mask``) receive ``weights``; off-pivot
+ rows receive the corresponding entry of ``fill_weights`` (shape
+ ``(n_values,)``). Callers may pre-compute ``indices_array`` and
+ ``valid_mask`` and pass them in to avoid recomputation when the dispatch
+ needs the same mask for both filtered pivot extraction and the scatter.
"""
+ if fill_weights.shape != (n_values,):
+ raise ValueError(
+ f"Invalid fill_weights shape {fill_weights.shape!r}: "
+ f"must be ({n_values},)"
+ )
+ # Empty-input early return precedes the length-mismatch check on purpose.
if len(indices) == 0 or weights.size == 0:
- return np.full(n_values, default_weight, dtype=float)
-
+ return fill_weights.astype(float, copy=True)
if len(indices) != weights.size:
raise ValueError(
- f"Invalid indices/weights values: length mismatch, got {len(indices)} indices but {weights.size} weights"
- )
-
- weights_array = np.full(n_values, default_weight, dtype=float)
-
- indices_array = np.array(indices)
- mask = (indices_array >= 0) & (indices_array < n_values)
-
- if not np.any(mask):
+ f"Invalid indices/weights values: length mismatch, "
+ f"got {len(indices)} indices but {weights.size} weights"
+ )
+ if indices_array is None:
+ indices_array = np.asarray(indices, dtype=int)
+ if valid_mask is None:
+ valid_mask = (indices_array >= 0) & (indices_array < n_values)
+ weights_array = fill_weights.astype(float, copy=True)
+ if not np.any(valid_mask):
return weights_array
-
- valid_indices = indices_array[mask]
- weights_array[valid_indices] = weights[mask]
+ weights_array[indices_array[valid_mask]] = weights[valid_mask]
return weights_array
indices: list[int],
metrics: dict[str, list[float]],
weighting_config: dict[str, Any],
+ *,
+ logger: Logger,
) -> NDArray[np.floating]:
"""Compute per-row label importance weights.
Returns an array with positive values at pivot ``indices`` (scaled by
- strategy) and ``0.0`` elsewhere. Callers must skip invocation when
- strategy is ``'none'``; this raises ValueError otherwise.
+ strategy) and off-pivot values controlled by ``fill_method``. Callers
+ must skip invocation when strategy is ``'none'``; this raises
+ ValueError otherwise.
"""
label_weighting = {**DEFAULTS_LABEL_WEIGHTING, **weighting_config}
strategy = label_weighting["strategy"]
weights=weights,
)
+ if weights.size == 0:
+ return np.zeros(n_values, dtype=float)
+
+ indices_array = np.asarray(indices, dtype=int)
+ valid_mask = (indices_array >= 0) & (indices_array < n_values)
+
+ fill_method = label_weighting["fill_method"]
+
+ if fill_method == FILL_METHODS[0]: # "zero"
+ fill_weights = np.zeros(n_values, dtype=float)
+ elif fill_method == FILL_METHODS[1]: # "epsilon"
+ eps = label_weighting["fill_epsilon"]
+ baseline = label_weighting["fill_epsilon_baseline"]
+ if valid_mask.any():
+ pivot_values = weights[valid_mask]
+ if baseline == FILL_EPSILON_BASELINES[0]: # "mean"
+ pivot_baseline = float(np.nanmean(pivot_values))
+ elif baseline == FILL_EPSILON_BASELINES[1]: # "median"
+ pivot_baseline = float(np.nanmedian(pivot_values))
+ else:
+ raise ValueError(
+ f"Invalid fill_epsilon_baseline value {baseline!r}"
+ )
+ if not np.isfinite(pivot_baseline):
+ pivot_baseline = 0.0
+ else:
+ pivot_baseline = 0.0
+ fill_weights = np.full(n_values, eps * pivot_baseline, dtype=float)
+ elif fill_method == FILL_METHODS[2]: # "gaussian"
+ fill_weights = _gaussian_fill_weights(
+ n_values=n_values,
+ pivot_indices=indices_array[valid_mask],
+ pivot_weights=weights[valid_mask],
+ sigma_candles=label_weighting["fill_sigma_candles"],
+ logger=logger,
+ )
+ else:
+ raise ValueError(f"Invalid fill_method value {fill_method!r}")
+
return _scatter_weights(
n_values=n_values,
indices=indices,
weights=weights,
- default_weight=0.0,
+ fill_weights=fill_weights,
+ indices_array=indices_array,
+ valid_mask=valid_mask,
)