import random
import time
import warnings
+from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
-from typing import AbstractSet, Any, Callable, Final, Literal, Optional, Union, cast
+from typing import AbstractSet, Any, Callable, ClassVar, Final, Literal, Optional, Union, assert_never, cast
import numpy as np
import optuna
from LabelTransformer import (
CUSTOM_THRESHOLD_METHODS,
EXTREMA_SELECTION_METHODS,
+ LABEL_WEIGHT_SUPPORT_POLICIES,
PREDICTION_METHODS,
SKIMAGE_THRESHOLD_METHODS,
THRESHOLD_METHODS,
CustomThresholdMethod,
ExtremaSelectionMethod,
LabelTransformer,
+ LabelWeightSupportPolicy,
SkimageThresholdMethod,
ThresholdMethod,
get_label_column_config,
DEFAULT_FIT_LIVE_PREDICTIONS_CANDLES,
DEFAULTS_LABEL_PREDICTION,
LABEL_COLUMNS,
+ LabelWeightSupportError,
REGRESSORS,
Regressor,
compose_sample_weights,
get_label_horizon_candles,
get_label_pipeline_config,
get_label_prediction_config,
+ get_label_weighting_config,
get_min_max_label_period_candles,
get_optuna_study_model_parameters,
label_known_at_column_name,
optuna_save_best_params,
sanitize_and_renormalize,
safe_distribution_fit,
+ summarize_label_weight_support,
soft_extremum,
zigzag,
)
SelectionMethod = Union[DistanceMethod, ClusterMethod, DensityMethod]
ValidationMode = Literal["warn", "raise", "none"]
SplitFn = Callable[
- [pd.DataFrame, pd.DataFrame, NDArray[np.floating], pd.DataFrame], dict[str, Any]
+ [pd.DataFrame, pd.DataFrame, "SampleWeightInputs", pd.DataFrame], dict[str, Any]
]
warnings.simplefilter(action="ignore", category=FutureWarning)
)
+@dataclass(frozen=True, slots=True)
+class SampleWeightInputs:
+ base: NDArray[np.floating]
+ label: NDArray[np.floating] | None
+ label_weighting_config: dict[str, Any]
+
+ _REQUIRED_LABEL_WEIGHTING_KEYS: ClassVar[frozenset[str]] = frozenset(
+ {
+ "support_policy",
+ "min_pivot_equivalent_count",
+ "min_positive_label_weight_fraction",
+ "min_effective_sample_size",
+ }
+ )
+
+ def __post_init__(self) -> None:
+ if self.base.ndim != 1:
+ raise ValueError(
+ f"SampleWeightInputs.base: must be 1-D (ndim={self.base.ndim})"
+ )
+ if self.label is not None and self.base.shape != self.label.shape:
+ raise ValueError(
+ f"SampleWeightInputs.label: shape {self.label.shape} "
+ f"!= base shape {self.base.shape}"
+ )
+ missing = (
+ self._REQUIRED_LABEL_WEIGHTING_KEYS - self.label_weighting_config.keys()
+ )
+ if missing:
+ raise KeyError(
+ f"SampleWeightInputs.label_weighting_config: missing required keys "
+ f"{sorted(missing)}"
+ )
+ policy = self.label_weighting_config["support_policy"]
+ if policy not in LABEL_WEIGHT_SUPPORT_POLICIES:
+ raise ValueError(
+ f"SampleWeightInputs.label_weighting_config.support_policy: "
+ f"{policy!r} not in {LABEL_WEIGHT_SUPPORT_POLICIES}"
+ )
+
+
class QuickAdapterRegressorV3(BaseRegressionModel):
"""
The following freqaimodel is released to sponsors of the non-profit FreqAI open-source project.
def _power_mean_metrics_set() -> set[str]:
return set(QuickAdapterRegressorV3._POWER_MEAN_MAP.keys())
- @staticmethod
- def _shuffle_in_unison(
- features: pd.DataFrame,
- labels: pd.DataFrame,
- weights: NDArray[np.floating],
- seed: int,
- ) -> tuple[pd.DataFrame, pd.DataFrame, NDArray[np.floating]]:
- features = features.sample(frac=1, random_state=seed).reset_index(drop=True)
- labels = labels.sample(frac=1, random_state=seed).reset_index(drop=True)
- weights = (
- pd.DataFrame(weights)
- .sample(frac=1, random_state=seed)
- .reset_index(drop=True)
- .to_numpy()[:, 0]
- )
- return features, labels, weights
-
@staticmethod
def _coerce_int(value: Any, name: str, *, minimum: int) -> int:
if isinstance(value, bool) or not isinstance(value, int) or value < minimum:
train_weights: NDArray[np.floating],
keep_mask: NDArray[np.bool_],
context: str,
- ) -> tuple[pd.DataFrame, pd.DataFrame, NDArray[np.floating]]:
+ train_label_weights: NDArray[np.floating] | None = None,
+ ) -> tuple[
+ pd.DataFrame,
+ pd.DataFrame,
+ NDArray[np.floating],
+ NDArray[np.floating] | None,
+ ]:
removed = int((~keep_mask).sum())
if removed:
logger.info(f"{context}: removed {removed} causal-unsafe train rows")
train_features.loc[keep_mask],
train_labels.loc[keep_mask],
train_weights[keep_mask],
+ None if train_label_weights is None else train_label_weights[keep_mask],
)
+ @staticmethod
+ def _shuffle_split_rows(
+ features: pd.DataFrame,
+ labels: pd.DataFrame,
+ base_weights: NDArray[np.floating],
+ label_weights: NDArray[np.floating] | None,
+ seed: int,
+ ) -> tuple[
+ pd.DataFrame,
+ pd.DataFrame,
+ NDArray[np.floating],
+ NDArray[np.floating] | None,
+ ]:
+ shuffled_features = features.sample(frac=1, random_state=seed)
+ order = features.index.get_indexer(shuffled_features.index)
+ if (order < 0).any():
+ raise ValueError(
+ f"_shuffle_split_rows: unable to align shuffled feature rows "
+ f"to sample weights (missing={int((order < 0).sum())} rows)"
+ )
+ shuffled_labels = labels.loc[shuffled_features.index]
+ shuffled_label_weights = None if label_weights is None else label_weights[order]
+ return shuffled_features, shuffled_labels, base_weights[order], shuffled_label_weights
+
@staticmethod
def _compose_eval_weights(
    base_weights: NDArray[np.floating],
    label_weights: NDArray[np.floating] | None,
    *,
    context: str,
) -> NDArray[np.floating]:
    """Compose sample weights for an eval (test/val) split.

    Unlike the training path, ``support_policy`` is not consulted here:
    support thresholds are training-fit invariants, and routine test/val
    splits trip them by construction.

    Composition runs with ``on_collapse="fallback"``, so a collapse on the
    surviving rows still carries the label-derived ``drop_mask``; eval
    weights consumed by framework-side early-stopping therefore reflect the
    same row-survival pattern as the training-time label weighting. When
    every row is dropped, ``compose_sample_weights`` raises
    ``LabelWeightSupportError``; that case is logged and answered with
    base-only weights, as no survival pattern exists to propagate.

    A shape-parity ``ValueError`` is deliberately not caught: it is a hard
    contract failure rather than a support condition.

    Args:
        base_weights: per-row base weights of the eval split.
        label_weights: per-row label weights, or ``None``.
        context: prefix used in the warning log line.

    Returns:
        The composed eval sample weights.
    """
    try:
        eval_weights = compose_sample_weights(
            base_weights,
            label_weights,
            logger=logger,
            on_collapse="fallback",
        )
    except LabelWeightSupportError as exc:
        logger.warning(
            "%s: label-weighted eval weights failed (%s); using base weights",
            context,
            exc,
        )
        eval_weights = compose_sample_weights(base_weights, None, logger=logger)
    return eval_weights
+
@staticmethod
def _apply_support_policy(
    base_weights: NDArray[np.floating],
    *,
    context: str,
    policy: LabelWeightSupportPolicy,
    reasons: list[str],
) -> NDArray[np.floating]:
    """React to a failed label-weighting support check according to ``policy``.

    Args:
        base_weights: per-row base weights used for the fallback result.
        context: prefix for the error / warning message.
        policy: ``"raise"`` aborts with ``ValueError``; ``"fallback"`` logs
            a warning and returns weights composed from ``base_weights``
            alone.
        reasons: failure descriptions, joined with ``"; "`` in the message.

    Returns:
        Base-only composed weights (``"fallback"`` policy only).

    Raises:
        ValueError: when ``policy`` is ``"raise"``.
    """
    joined_reasons = "; ".join(reasons)
    if policy == "raise":
        raise ValueError(
            f"{context}: label weighting support failed ({joined_reasons}); "
            "support_policy='raise'"
        )
    if policy == "fallback":
        logger.warning(
            "%s: label weighting support failed (%s); "
            "falling back to sanitized base weights (support_policy='fallback')",
            context,
            joined_reasons,
        )
        return compose_sample_weights(base_weights, None, logger=logger)
    # Any other value is outside the declared policy set.
    assert_never(policy)
+
@staticmethod
def _compose_train_weights_with_support(
    base_weights: NDArray[np.floating],
    label_weights: NDArray[np.floating] | None,
    label_weighting_config: dict[str, Any],
    *,
    context: str,
) -> NDArray[np.floating]:
    """Compose training sample weights, gated by label-weighting support checks.

    Without label weights the result is the base-only composition. Otherwise
    the label-weighted composition is summarised with
    ``summarize_label_weight_support`` and compared against the three
    ``min_*`` thresholds in ``label_weighting_config``. A composition error
    (``LabelWeightSupportError``) or any threshold miss is handed to
    ``_apply_support_policy`` together with the configured ``support_policy``.

    Args:
        base_weights: per-row base weights of the training split.
        label_weights: per-row label weights, or ``None``.
        label_weighting_config: must provide ``support_policy``,
            ``min_pivot_equivalent_count``,
            ``min_positive_label_weight_fraction`` and
            ``min_effective_sample_size``.
        context: prefix for log and error messages.

    Returns:
        The label-weighted composition when support passes, otherwise
        whatever ``_apply_support_policy`` returns.
    """
    if label_weights is None:
        return compose_sample_weights(base_weights, None, logger=logger)

    policy = cast(LabelWeightSupportPolicy, label_weighting_config["support_policy"])
    try:
        composed = compose_sample_weights(base_weights, label_weights, logger=logger)
    except LabelWeightSupportError as exc:
        return QuickAdapterRegressorV3._apply_support_policy(
            base_weights,
            context=context,
            policy=policy,
            reasons=[str(exc)],
        )

    summary = summarize_label_weight_support(label_weights, composed)
    pivot_floor = label_weighting_config["min_pivot_equivalent_count"]
    fraction_floor = label_weighting_config["min_positive_label_weight_fraction"]
    ess_floor = label_weighting_config["min_effective_sample_size"]

    failures: list[str] = []
    if summary.pivot_equivalent_count < pivot_floor:
        failures.append(
            f"pivot_equivalent_count={summary.pivot_equivalent_count} "
            f"< min_pivot_equivalent_count={pivot_floor}"
        )
    if summary.positive_label_weight_fraction < fraction_floor:
        failures.append(
            f"positive_label_weight_fraction={summary.positive_label_weight_fraction:.6g} "
            f"< min_positive_label_weight_fraction={fraction_floor:.6g} "
            f"({summary.positive_label_weight_count}/{summary.total_rows} rows)"
        )
    if summary.effective_sample_size < ess_floor:
        failures.append(
            f"effective_sample_size={summary.effective_sample_size:.6g} "
            f"< min_effective_sample_size={ess_floor:.6g}"
        )

    if not failures:
        logger.debug(
            "%s: label weighting support passed "
            "(pivot_equivalent_count=%d, positive_label_weight_fraction=%.6g, "
            "effective_sample_size=%.6g)",
            context,
            summary.pivot_equivalent_count,
            summary.positive_label_weight_fraction,
            summary.effective_sample_size,
        )
        return composed

    return QuickAdapterRegressorV3._apply_support_policy(
        base_weights, context=context, policy=policy, reasons=failures
    )
+
@staticmethod
def _get_selection_category(method: str) -> Optional[str]:
for (
return label_frequency_candles
@property
def label_weighting(self) -> dict[str, Any]:
    """Label-weighting configuration built from ``freqai_info["label_weighting"]``.

    A missing or non-dict entry is replaced by an empty dict before being
    passed to ``get_label_weighting_config``.
    """
    configured = self.freqai_info.get("label_weighting")
    source = configured if isinstance(configured, dict) else {}
    return get_label_weighting_config(source, logger)
+
@property
def label_pipeline(self) -> dict[str, Any]:
label_pipeline_raw = self.freqai_info.get("label_pipeline")
Dispatches on ``data_split_parameters.method``:
- ``train_test_split``: random sklearn split.
- ``timeseries_split``: chronological final-fold split.
- Both paths compose per-row weights via ``_compose_per_row_weights``
- before splitting and feed them to ``model.fit(sample_weight=...)``
- through ``_train_common``.
+ Both paths build per-row weights via ``_build_sample_weight_inputs``
+ before splitting. After split + causal-guard filtering, train weights
+ compose through ``_compose_train_weights_with_support`` (gated by
+ ``support_policy``) and eval weights through ``_compose_eval_weights``
+ (bypasses ``support_policy``). ``_train_common`` then feeds them to
+ ``model.fit(sample_weight=...)``.
"""
method = self.data_split_parameters.get(
"method", QuickAdapterRegressorV3.DATA_SPLIT_METHOD_DEFAULT
def split_fn(
features: pd.DataFrame,
labels: pd.DataFrame,
- weights: NDArray[np.floating],
+ weights: SampleWeightInputs,
unfiltered: pd.DataFrame,
) -> dict[str, Any]:
return split_builder(features, labels, weights, dk, unfiltered)
self,
features: pd.DataFrame,
labels: pd.DataFrame,
- weights: NDArray[np.floating],
+ weights: SampleWeightInputs,
dk: FreqaiDataKitchen,
unfiltered_df: pd.DataFrame,
) -> dict[str, Any]:
)
if test_size != 0:
- (
- train_features,
- test_features,
- train_labels,
- test_labels,
- train_weights,
- test_weights,
- ) = train_test_split(features, labels, weights, **sklearn_kwargs)
+ if weights.label is None:
+ (
+ train_features,
+ test_features,
+ train_labels,
+ test_labels,
+ train_base_weights,
+ test_base_weights,
+ ) = train_test_split(features, labels, weights.base, **sklearn_kwargs)
+ train_label_weights = None
+ test_label_weights = None
+ else:
+ (
+ train_features,
+ test_features,
+ train_labels,
+ test_labels,
+ train_base_weights,
+ test_base_weights,
+ train_label_weights,
+ test_label_weights,
+ ) = train_test_split(
+ features, labels, weights.base, weights.label, **sklearn_kwargs
+ )
if causal_mode:
row_positions = QuickAdapterRegressorV3._row_positions(
features, unfiltered_df
)
else:
_log_known_at_none_once(dk.pair, "train_test_split causal guard")
- train_features, train_labels, train_weights = (
- QuickAdapterRegressorV3._filter_train_by_mask(
- train_features,
- train_labels,
- train_weights,
- keep_mask,
- f"[{dk.pair}] train_test_split causal guard",
- )
+ (
+ train_features,
+ train_labels,
+ train_base_weights,
+ train_label_weights,
+ ) = QuickAdapterRegressorV3._filter_train_by_mask(
+ train_features,
+ train_labels,
+ train_base_weights,
+ keep_mask,
+ f"[{dk.pair}] train_test_split causal guard",
+ train_label_weights=train_label_weights,
)
else:
train_features = features
train_labels = labels
- train_weights = weights
+ train_base_weights = weights.base
+ train_label_weights = weights.label
test_features = features.iloc[:0]
test_labels = labels.iloc[:0]
- test_weights = weights[:0]
+ test_base_weights = weights.base[:0]
+ test_label_weights = None if weights.label is None else weights.label[:0]
if feat_dict.get("shuffle_after_split", False):
parent_seed = sklearn_kwargs.get("random_state")
if parent_seed is not None
else random.Random()
)
- train_features, train_labels, train_weights = (
- QuickAdapterRegressorV3._shuffle_in_unison(
+ train_features, train_labels, train_base_weights, train_label_weights = (
+ QuickAdapterRegressorV3._shuffle_split_rows(
train_features,
train_labels,
- train_weights,
+ train_base_weights,
+ train_label_weights,
shuffle_rng.randint(0, 2**31 - 1),
)
)
if test_size != 0:
- test_features, test_labels, test_weights = (
- QuickAdapterRegressorV3._shuffle_in_unison(
+ test_features, test_labels, test_base_weights, test_label_weights = (
+ QuickAdapterRegressorV3._shuffle_split_rows(
test_features,
test_labels,
- test_weights,
+ test_base_weights,
+ test_label_weights,
shuffle_rng.randint(0, 2**31 - 1),
)
)
- train_weights = sanitize_and_renormalize(
- train_weights, logger=logger, context="train_test_split:train"
+ train_weights = QuickAdapterRegressorV3._compose_train_weights_with_support(
+ train_base_weights,
+ train_label_weights,
+ weights.label_weighting_config,
+ context=f"[{dk.pair}] train_test_split:train",
)
if test_size != 0:
- test_weights = sanitize_and_renormalize(
- test_weights, logger=logger, context="train_test_split:test"
+ test_weights = QuickAdapterRegressorV3._compose_eval_weights(
+ test_base_weights,
+ test_label_weights,
+ context=f"[{dk.pair}] train_test_split:test",
)
+ else:
+ test_weights = test_base_weights
if feat_dict.get("reverse_train_test_order", False):
return dk.build_data_dictionary(
test_weights,
)
- def _compose_per_row_weights(
+ def _build_sample_weight_inputs(
self,
features_filtered: pd.DataFrame,
unfiltered_df: pd.DataFrame,
dk: FreqaiDataKitchen,
- ) -> NDArray[np.floating]:
- """Build a per-row sample weight vector aligned to features_filtered.index.
+ ) -> SampleWeightInputs:
+ """Build per-row base and label weight vectors aligned to features_filtered.index.
Multiplies freqtrade's per-row base weights (recency-decayed via
``dk.set_weights_higher_recent`` when ``feature_parameters.weight_factor > 0``,
any shuffle/split on ``features_filtered.index`` (a subset of
``unfiltered_df.index``) to avoid post-hoc reindex against shuffled
data. The weight column is absent when ``label_weighting.strategy``
- is ``'none'`` (no per-label importance applied); in that case
- ``label_weights=None`` is forwarded to ``compose_sample_weights``
- and only the base weights contribute.
+ is ``'none'`` (no per-label importance applied); in that case the
+ final split stage composes base-only sample weights.
"""
if not unfiltered_df.index.is_unique:
raise ValueError(
else:
base_weights = np.ones(n_rows, dtype=float)
+ label_weighting = self.label_weighting
+ label_weighting_config = get_label_column_config(
+ LABEL_COLUMNS[0], label_weighting["default"], label_weighting["columns"]
+ )
weight_col = label_weight_column_name(LABEL_COLUMNS[0])
if weight_col in unfiltered_df.columns:
label_weights = unfiltered_df.loc[
logger.debug(
f"label weight column absent ({weight_col!r}); using base weights only"
)
- return compose_sample_weights(
- base_weights,
- label_weights,
- logger=logger,
+ return SampleWeightInputs(
+ base=base_weights,
+ label=label_weights,
+ label_weighting_config=label_weighting_config,
)
def _train_common(
dk.label_list,
training_filter=True,
)
- weights = self._compose_per_row_weights(features_filtered, unfiltered_df, dk)
+ weights = self._build_sample_weight_inputs(features_filtered, unfiltered_df, dk)
dates = ensure_datetime_series(unfiltered_df["date"])
start_date = dates.iloc[0].strftime("%Y-%m-%d")
end_date = dates.iloc[-1].strftime("%Y-%m-%d")
self,
filtered_dataframe: pd.DataFrame,
labels: pd.DataFrame,
- weights: NDArray[np.floating],
+ weights: SampleWeightInputs,
dk: FreqaiDataKitchen,
unfiltered_df: pd.DataFrame,
) -> dict:
and 0 < test_size < 1
):
test_size = int(len(filtered_dataframe) * test_size)
- elif (
+ elif not (
not isinstance(test_size, bool)
and isinstance(test_size, int)
and test_size >= 1
):
- pass
- else:
raise ValueError(
f"Invalid data_split_parameters.test_size value {test_size!r}: "
f"must be float in (0, 1) as fraction, int >= 1 as count, or None"
test_features = filtered_dataframe.iloc[test_idx]
train_labels = labels.iloc[train_idx]
test_labels = labels.iloc[test_idx]
- train_weights = weights[train_idx]
- test_weights = sanitize_and_renormalize(
- weights[test_idx], logger=logger, context="timeseries_split:test"
+ train_base_weights = weights.base[train_idx]
+ test_base_weights = weights.base[test_idx]
+ train_label_weights = None if weights.label is None else weights.label[train_idx]
+ test_label_weights = None if weights.label is None else weights.label[test_idx]
+ test_weights = QuickAdapterRegressorV3._compose_eval_weights(
+ test_base_weights,
+ test_label_weights,
+ context=f"[{dk.pair}] timeseries_split:test",
)
if causal_mode:
keep_mask = (
known_at_train.to_numpy(dtype=np.int64) < first_test_position
)
- train_features, train_labels, train_weights = (
- QuickAdapterRegressorV3._filter_train_by_mask(
- train_features,
- train_labels,
- train_weights,
- keep_mask,
- f"[{dk.pair}] timeseries_split causal guard",
- )
+ (
+ train_features,
+ train_labels,
+ train_base_weights,
+ train_label_weights,
+ ) = QuickAdapterRegressorV3._filter_train_by_mask(
+ train_features,
+ train_labels,
+ train_base_weights,
+ keep_mask,
+ f"[{dk.pair}] timeseries_split causal guard",
+ train_label_weights=train_label_weights,
)
else:
_log_known_at_none_once(dk.pair, "timeseries_split causal guard")
- train_weights = sanitize_and_renormalize(
- train_weights, logger=logger, context="timeseries_split:train"
+ train_weights = QuickAdapterRegressorV3._compose_train_weights_with_support(
+ train_base_weights,
+ train_label_weights,
+ weights.label_weighting_config,
+ context=f"[{dk.pair}] timeseries_split:train",
)
if feat_dict.get("reverse_train_test_order", False):
Final,
Literal,
TypeVar,
+ assert_never,
)
import numpy as np
FILL_BANDWIDTHS,
FILL_EPSILON_BASELINES,
FILL_METHODS,
+ LABEL_WEIGHT_SUPPORT_POLICIES,
NORMALIZATION_TYPES,
PREDICTION_METHODS,
SMOOTHING_METHODS,
"fill_bandwidth_alpha": _ParamSpec(
_NumericValidator(min_value=0, min_exclusive=True), output_type=float
),
+ "support_policy": _ParamSpec(_EnumValidator(LABEL_WEIGHT_SUPPORT_POLICIES)),
+ "min_pivot_equivalent_count": _ParamSpec(
+ _NumericValidator(min_value=1, require_int=True), output_type=int
+ ),
+ "min_positive_label_weight_fraction": _ParamSpec(
+ _NumericValidator(min_value=0.0, max_value=1.0), output_type=float
+ ),
+ "min_effective_sample_size": _ParamSpec(
+ _NumericValidator(min_value=1), output_type=float
+ ),
}
_PIPELINE_SPECS: Final[dict[str, _ParamSpec]] = {
return int((survivors >= threshold).sum())
+@dataclass(frozen=True, slots=True)
+class LabelWeightSupportSummary:
+ """Diagnostics for label-weighting support on a training split.
+
+ - ``total_rows``: filtered training row count
+ - ``positive_label_weight_count``/``positive_label_weight_fraction``:
+ rows with finite positive **label** weights (pre-composition)
+ - ``pivot_equivalent_count``: rows whose label weight is at least
+ ``_PIVOT_EQUIVALENT_MAX_FRACTION`` (10%) of the surviving maximum
+ - ``effective_sample_size``: Kish's ESS computed on the final
+ composed **sample** weights, ``(Sigma w)^2 / Sigma(w^2)``
+ """
+ total_rows: int
+ positive_label_weight_count: int
+ positive_label_weight_fraction: float
+ pivot_equivalent_count: int
+ effective_sample_size: float
+
+
+def _effective_sample_size(weights: NDArray[np.floating]) -> float:
+ """Kish's effective sample size ``(Sigma w)^2 / Sigma(w^2)`` over
+ finite strictly-positive entries. Returns 0.0 on empty/degenerate input.
+ """
+ arr = np.asarray(weights, dtype=float)
+ positive = arr[np.isfinite(arr) & (arr > 0.0)]
+ if positive.size == 0:
+ return 0.0
+ total = float(positive.sum())
+ sum_squares = float(np.square(positive).sum())
+ if total <= 0.0 or sum_squares <= 0.0 or not np.isfinite(total + sum_squares):
+ return 0.0
+ return float((total * total) / sum_squares)
+
+
def summarize_label_weight_support(
    label_weights: NDArray[np.floating],
    sample_weights: NDArray[np.floating],
) -> LabelWeightSupportSummary:
    """Build the support diagnostics for a single training split.

    The positive-label count/fraction and ``pivot_equivalent_count`` come
    from ``label_weights``; ``effective_sample_size`` is Kish's ESS of
    ``sample_weights`` (the composed output of ``compose_sample_weights``).

    Args:
        label_weights: per-row label weights.
        sample_weights: per-row composed sample weights; same shape as
            ``label_weights``.

    Returns:
        A ``LabelWeightSupportSummary`` for the split.

    Raises:
        ValueError: if the two arrays differ in shape.
    """
    label_arr = np.asarray(label_weights, dtype=float)
    sample_arr = np.asarray(sample_weights, dtype=float)
    if label_arr.shape != sample_arr.shape:
        raise ValueError(
            f"summarize_label_weight_support: label_weights shape {label_arr.shape} "
            f"!= sample_weights shape {sample_arr.shape}"
        )

    row_count = int(label_arr.size)
    is_positive = np.isfinite(label_arr) & (label_arr > 0.0)
    positive_rows = int(is_positive.sum())
    if row_count:
        fraction = float(positive_rows / row_count)
    else:
        # Empty split: report a zero fraction instead of dividing by zero.
        fraction = 0.0

    # The helper receives the mask of rows that are NOT finite-positive.
    pivot_equivalents = _pivot_equivalent_count(label_arr, ~is_positive)
    ess = _effective_sample_size(sample_arr)
    return LabelWeightSupportSummary(
        total_rows=row_count,
        positive_label_weight_count=positive_rows,
        positive_label_weight_fraction=fraction,
        pivot_equivalent_count=pivot_equivalents,
        effective_sample_size=ess,
    )
+
+
class LabelWeightSupportError(ValueError):
    """Support failure raised by ``compose_sample_weights``.

    Signals that label-weighted composition could not be supported: either
    every row was dropped, or the composition collapsed while
    ``on_collapse="raise"``. Callers may route this through a
    ``support_policy``. Shape-parity violations are not reported with this
    type; they stay plain ``ValueError`` and propagate as hard contract
    failures.
    """
+
+
def compose_sample_weights(
base_weights: NDArray[np.floating],
label_weights: NDArray[np.floating] | None,
*,
logger: Logger,
+ on_collapse: Literal["raise", "fallback"] = "raise",
) -> NDArray[np.floating]:
"""Combine base sample weights with the label importance weights.
Returns ``w in R+^N`` with ``mean(w) == 1``. Rows where
``label_weights[i]`` is non-finite or ``<= 0`` are dropped
(``out[i] == 0``); surviving rows carry ``base_weights * label_weights``
- rescaled to global ``mean == 1``. On collapse of the label-weighted
- product, falls back to ``base_weights`` (with the label-derived
- drop_mask) so the recency signal is preserved.
-
- Raises ValueError on shape mismatch or when every row is dropped.
+ rescaled to global ``mean == 1``.
+
+ ``on_collapse`` controls the response when the label-weighted product
+ collapses on every surviving row: ``"raise"`` (default) surfaces the
+ collapse as ``LabelWeightSupportError`` so callers can route it through
+ their support policy; ``"fallback"`` warns and returns ``base_weights``
+ sanitized with the label-derived ``drop_mask`` so the recency signal
+ is preserved (used by eval splits that bypass support thresholds).
+
+ Raises ``ValueError`` on shape mismatch (hard contract failure).
+ Raises ``LabelWeightSupportError`` when every row is dropped or when
+ collapse occurs with ``on_collapse="raise"``.
"""
base_weights = np.asarray(base_weights, dtype=float)
if label_weights is None:
)
drop_mask = ~np.isfinite(arr) | (arr <= 0.0)
if drop_mask.all():
- raise ValueError(
+ raise LabelWeightSupportError(
"compose_sample_weights: all rows dropped by zero or non-finite "
"label weights; no surviving training samples"
)
100.0 * SPARSE_TRAINING_MASS_THRESHOLD,
)
combined = base_weights * arr
- # Detect collapse on surviving rows up front so the fallback can route
- # to base weights rather than the uniform fallback inside sanitize.
survivor_mask = ~(drop_mask | ~np.isfinite(combined) | (combined <= 0.0))
survivor_total = float(np.where(survivor_mask, combined, 0.0).sum())
if survivor_total > 0.0 and np.isfinite(survivor_total):
logger=logger,
context="compose:label_weighted",
)
- logger.warning(
- "compose_sample_weights: composed weights collapsed on surviving "
- "rows (survivor_total=%g); falling back to base weights",
- survivor_total,
- )
- return sanitize_and_renormalize(
- base_weights,
- drop_mask=drop_mask,
- logger=logger,
- context="compose:base_fallback",
- )
+ match on_collapse:
+ case "raise":
+ raise LabelWeightSupportError(
+ f"compose_sample_weights: composed weights collapsed on "
+ f"surviving rows (survivor_total={survivor_total:.6g})"
+ )
+ case "fallback":
+ logger.warning(
+ "compose_sample_weights: composed weights collapsed on surviving "
+ "rows (survivor_total=%.6g); falling back to base weights",
+ survivor_total,
+ )
+ return sanitize_and_renormalize(
+ base_weights,
+ drop_mask=drop_mask,
+ logger=logger,
+ context="compose:base_fallback",
+ )
+ case _:
+ assert_never(on_collapse)
def nan_average(
return escaped
-def _format_collection(
- value: list | tuple | set,
- ctx: _FormatContext,
- depth: int,
- brackets: tuple[str, str],
- empty: str,
- trailing_comma: bool = False,
-) -> str:
- if not value:
- return empty
- obj_id = id(value)
- if obj_id in ctx.seen:
- return f"{brackets[0]}<circular>{brackets[1]}"
- if depth >= _MAX_DEPTH:
- return f"{brackets[0]}...{brackets[1]}"
- ctx.seen.add(obj_id)
- items_iter = sorted(value, key=str) if isinstance(value, set) else value
- items = [_format_value(v, ctx, depth + 1) for v in list(items_iter)[:_MAX_ITEMS]]
- if len(value) > _MAX_ITEMS:
- items.append(f"...+{len(value) - _MAX_ITEMS}")
- content = ", ".join(items)
- if trailing_comma and len(value) == 1 and len(items) == 1:
- content += ","
- ctx.seen.discard(obj_id)
- return f"{brackets[0]}{content}{brackets[1]}"
-
-
@_format_value.register(list)
def _(value: list, ctx: _FormatContext, depth: int) -> str:
return _format_collection(value, ctx, depth, ("[", "]"), "[]")
return f"array{value.shape}"
def _format_collection(
    value: list | tuple | set,
    ctx: _FormatContext,
    depth: int,
    brackets: tuple[str, str],
    empty: str,
    trailing_comma: bool = False,
) -> str:
    """Render a list, tuple or set as a bracketed, comma-separated string.

    Args:
        value: the collection to render.
        ctx: formatting context; ``ctx.seen`` holds the ids of containers
            currently being rendered and is used to detect cycles.
        depth: current nesting depth; at ``_MAX_DEPTH`` or deeper the
            content is elided as ``...``.
        brackets: opening and closing delimiter.
        empty: text returned for an empty collection.
        trailing_comma: append ``,`` when the collection has exactly one
            item.

    Returns:
        The rendered text. At most ``_MAX_ITEMS`` elements are rendered,
        followed by a ``...+N`` marker for the remainder; sets are ordered
        by ``str`` of their elements.
    """
    if not value:
        return empty

    opener = brackets[0]
    closer = brackets[1]
    marker = id(value)
    if marker in ctx.seen:
        return f"{opener}<circular>{closer}"
    if depth >= _MAX_DEPTH:
        return f"{opener}...{closer}"

    ctx.seen.add(marker)
    ordered = sorted(value, key=str) if isinstance(value, set) else value
    rendered = []
    for element in list(ordered)[:_MAX_ITEMS]:
        rendered.append(_format_value(element, ctx, depth + 1))
    size = len(value)
    if size > _MAX_ITEMS:
        rendered.append(f"...+{size - _MAX_ITEMS}")
    body = ", ".join(rendered)
    if trailing_comma and size == 1 and len(rendered) == 1:
        body += ","
    ctx.seen.discard(marker)
    return f"{opener}{body}{closer}"
+
+
def format_dict(
d: dict[str, Any],
style: Literal["dict", "params"] = "dict",