From: Jérôme Benoit
Date: Mon, 25 May 2026 15:52:05 +0000 (+0200)
Subject: feat(quickadapter): add soft off-pivot weighting (epsilon, gaussian) to label_weighti...
X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=fa6712826381e1014b7c556ce55a5c519aef58cc;p=freqai-strategies.git

feat(quickadapter): add soft off-pivot weighting (epsilon, gaussian) to label_weighting (#74)

* feat(quickadapter): add soft off-pivot weighting (epsilon, gaussian) to label_weighting

Adds three off-pivot weighting modes behind a new fill_method tunable in freqai.label_weighting:

- zero (default): current hard-zero behavior, retained for backward compatibility.
- epsilon: off-pivot rows receive a flat baseline fill_epsilon * stat(pivot_weights), where stat is mean or median, controlled by fill_epsilon_baseline.
- gaussian: off-pivot rows receive a per-row weight from a heatmap-style decay max_p w_p * exp(-(i-p)^2 / (2 sigma^2)), controlled by fill_sigma_candles (>= 0.5).

The default is zero so existing configs without the new keys behave identically. Switching fill_method materially changes per-leaf weight mass and may require GBM hyperparameter retuning; flagged in the README description column.

Implementation:

- Adds FillMethod/FILL_METHODS and FillEpsilonBaseline/FILL_EPSILON_BASELINES Literal types and tuples in LabelTransformer.py.
- Extends DEFAULTS_LABEL_WEIGHTING with the four new keys and their defaults.
- Extends _WEIGHTING_SPECS in Utils.py with corresponding _EnumValidator and _NumericValidator entries (epsilon in [0, 1], sigma_candles >= 0.5).
- Refactors _scatter_weights to accept fill_weights as a precomputed array plus optional indices_array/valid_mask kwargs; preserves pre-existing length-mismatch ValueError and empty-input early-return semantics.
- Adds _gaussian_fill_weights helper with in-place pipeline keeping peak memory at one (chunk, M) buffer; chunk-by-N keyed on _GAUSSIAN_FILL_CHUNK_BUDGET = 50_000_000 cells (~400 MB peak); emits a density warning when M / N > 0.1; rejects negative pivot weights.
- Adds *, logger: Logger keyword-only parameter to compute_label_weights and updates the single call site in QuickAdapterV3.py.
- Replaces the raw nonzero count in compose_sample_weights with a pivot-equivalent count helper (_pivot_equivalent_count) so the sparse training mass warning stays meaningful under epsilon / gaussian.

Documentation:

- Four new rows added to the README configuration tunables table under Label weighting; fill_method flagged as requiring trained-model deletion when changed.
- Four new keys added to config-template.json under label_weighting.

Verified manually on host via AST extraction harness (no automated test infrastructure exists in quickadapter/):

- STATIC_OK: defaults + tuples assertions pass.
- SPOTCHECK_4..9: cluster amplification (out[50] ~= 8.0), sigma < 0.5 rejected, negative pivot weights rejected, density warning emitted at M/N=0.2, empty pivots return zeros, mean/median epsilon ratio = 20.8x.
- SPARSE_4A..C: sparse-mass warning fires under zero mode + sparse pivots and gaussian sigma=0.5 underflow regime; silent under broad gaussian fills.

* fix(quickadapter): harden sanitize_and_renormalize against rescale overflow and drop_mask contract violations

Two production-quality safeguards on the load-bearing primitive used by the new fill_method dispatch in PR #74, plus one cosmetic comment cleanup.

1. Subnormal-total rescale overflow guard: When the sum of sanitized weights falls into a subnormal range (e.g. a single 1e-310 survivor among zeros, n=1000), n/total overflows to +Inf and safe * Inf propagates Inf to every nonzero entry, producing mean(out) = NaN and silently violating the documented mean=1 invariant.
The fix computes the rescale factor into a local c, checks np.isfinite(c), and falls through to the existing uniform-fallback path with a distinct warning message ('rescale factor non-finite') so operators can distinguish this regime from the existing 'weights collapsed' case. Bit-identical on all common paths; c -> 0 underflow is unreachable (min c = 1/DBL_MAX > 0).

2. drop_mask shape and dtype assertions: sanitize_and_renormalize is now load-bearing for compose_sample_weights under all three fill_method modes (zero/epsilon/gaussian). NumPy broadcasts a (k, n)-shaped mask silently, breaking the (n,) output contract. Shape and dtype precondition checks raise ValueError early with prefixed messages matching the function's existing logger style. Dtype check uses np.issubdtype(..., np.bool_) so any boolean alias (bool, np.bool_, 'bool') is accepted; integer masks are rejected.

3. LabelTransformer.py: replace 'current behavior' comment with 'default' on FILL_METHODS[0] since the comparison no longer makes sense once the PR is merged.

Verified manually:

- REVIEW_FIX_1A..C: bool, np.bool_, 'bool' all accepted; int rejected.
- REVIEW_FIX_2A..B: subnormal-overflow path emits the new distinct warning; real collapse path emits the original warning.
- REVIEW_FIX_3_OK: docstring contradiction removed.
- REGRESSION_OK: bit-identical common path.
- All original PR #74 verifications still pass (SPARSE_4A..C, SPOTCHECK_4..9).

* fix(quickadapter): short-circuit compute_label_weights on empty pivot weights

When metrics[strategy] is empty but indices is non-empty, the new fill_method dispatch in epsilon/gaussian arms slices weights[valid_mask] before _scatter_weights can short-circuit, raising IndexError on a size-0 / N-mask shape mismatch. Pre-PR _scatter_weights returned the default-filled array silently in this case (preserved invariant noted inline at the empty-input early return).
Add a short-circuit before the dispatch so the contract is consistent across all three fill methods.

Also trim _gaussian_fill_weights docstring to match the codebase style (neighboring private helpers carry no docstring or a single short paragraph) and drop a redundant in-line comment that the in-place np.multiply(out=buf) pattern already conveys.

Verified on the AST-extraction harness (pre-fix reproduction → fix verification): 12 contract assertions across 4 edge cases x 3 fill methods, plus crossmode + non-empty differentiation, all pass; PR #74 SPOTCHECK_4..9, SPARSE_4A..C, REVIEW_FIX_*, REGRESSION_OK still pass.

* chore(quickadapter): bump strategy and regressor version 3.11.10 -> 3.11.11

* refactor(quickadapter): polish label_weighting docs, comments, and sparse-mass diagnostic

Three coordinated polish edits following final-review feedback:

1. _pivot_equivalent_count: replace the 0.5 * median threshold with _PIVOT_EQUIVALENT_MAX_FRACTION * surviving max (default 0.1). The median-based heuristic saturated at N under epsilon mode (off-pivot floor dominates the median once N >> M), silencing the warning the docstring claimed to provide. The max-relative threshold separates pivot-class rows from off-pivot fill across the bimodal regimes fill_method introduces. Constant is module-level and named so the choice is auditable; warning text now self-describes the threshold ('rows above 10% of surviving max').

2. _scatter_weights: trim the 'Order matters...' comment from 3 lines to 1 line. The shorter form pins the intentional ordering without paraphrasing git history; future 'validate inputs first' refactors are still flagged.

3. README: extend the fill_method row with a concise retuning hint (per-leaf regularization + Optuna study reset) so the operator guidance surfaces in user docs, not only in the planning artefact. Tighten fill_sigma_candles description to match neighboring-row density.

Verified manually:

- SPARSE_4A..C: original PR cases still pass.
- SPARSE_4D: epsilon+sparse pivots (M=20, N=1000) now correctly fires the sparse-mass warning (was silenced with median-based threshold).
- SPARSE_4E: zero+skewed pivots ([1,1,...,10]) still fire under the new threshold (no regression on the skew case).
- SPOTCHECK_4..9, BUG_74_FIX_*, REVIEW_FIX_*, REGRESSION_OK: all unchanged.
---

diff --git a/README.md b/README.md
index 8c59bd2..cb0cb38 100644
--- a/README.md
+++ b/README.md
@@ -79,6 +79,10 @@ docker compose up -d --build
 | freqai.label_weighting.metric_coefficients | {} | dict[str, float] | Per-metric coefficients for `combined` strategy. Keys: `amplitude`, `amplitude_threshold_ratio`, `volume_rate`, `speed`, `efficiency_ratio`, `volume_weighted_efficiency_ratio`. |
 | freqai.label_weighting.aggregation | `arithmetic_mean` | enum {`arithmetic_mean`,`geometric_mean`,`harmonic_mean`,`quadratic_mean`,`weighted_median`,`softmax`} | Metric aggregation method for `combined` strategy. `arithmetic_mean`=(Σ(w·m)/Σ(w)), `geometric_mean`=(∏(m^w))^(1/Σw), `harmonic_mean`=Σ(w)/(Σ(w/m)), `quadratic_mean`=(Σ(w·m²)/Σ(w))^(1/2), `weighted_median`=Q₀.₅(m,w), `softmax`=Σ(m·s_i) where s_i=w_i·exp(m_i/T)/Σ(w_j·exp(m_j/T)). |
 | freqai.label_weighting.softmax_temperature | 1.0 | float > 0 | Temperature T for `softmax` aggregation, controls distribution sharpness. |
+| freqai.label_weighting.fill_method | `zero` | enum {`zero`,`epsilon`,`gaussian`} | Off-pivot weighting scheme. `zero` hard-zeros off-pivot rows; `epsilon` applies a flat baseline `fill_epsilon * (pivot_weights)`; `gaussian` applies heatmap-style decay around each pivot. Switching away from `zero` may require retuning tree-leaf regularization (`min_child_weight`, `lambda`) and resetting any prior Optuna study. Changing this parameter requires deleting trained models. |
+| freqai.label_weighting.fill_epsilon | 0.001 | float [0,1] | Off-pivot fraction of the pivot baseline. Ignored when `fill_method != "epsilon"`. |
+| freqai.label_weighting.fill_epsilon_baseline | `mean` | enum {`mean`,`median`} | Pivot baseline statistic. `mean` tracks central tendency; `median` is robust against pivot-weight skew. Ignored when `fill_method != "epsilon"`. |
+| freqai.label_weighting.fill_sigma_candles | 3.0 | float >= 0.5 | Gaussian standard deviation in candles for `fill_method == "gaussian"`. Lower bound 0.5 prevents underflow that silently degrades to `zero` mode. Ignored when `fill_method != "gaussian"`. |
 | _Label pipeline_ | | | |
 | freqai.label_pipeline.standardization | `none` | enum {`none`,`zscore`,`robust`,`mmad`,`power_yj`} | Standardization method applied to labels before normalization. `none`=w, `zscore`=(w-μ)/σ, `robust`=(w-median)/(Q₃-Q₁), `mmad`=(w-median)/(MAD·k), `power_yj`=YJ(w). |
 | freqai.label_pipeline.robust_quantiles | [0.25, 0.75] | list[float] where 0 <= Q1 < Q3 <= 1 | Quantile range for robust standardization, Q1 and Q3. |
diff --git a/quickadapter/user_data/config-template.json b/quickadapter/user_data/config-template.json
index acca7c9..d105c6a 100644
--- a/quickadapter/user_data/config-template.json
+++ b/quickadapter/user_data/config-template.json
@@ -98,7 +98,11 @@
     "data_kitchen_thread_count": 6, // set to number of CPU threads / 4
     "track_performance": false,
     "label_weighting": {
-      "strategy": "none"
+      "strategy": "none",
+      "fill_method": "zero",
+      "fill_epsilon": 0.001,
+      "fill_epsilon_baseline": "mean",
+      "fill_sigma_candles": 3.0
       // Per-label format:
       // "default": {
       //   "strategy": "none"
diff --git a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
index 8fc62d3..df540b4 100644
--- a/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
+++ b/quickadapter/user_data/freqaimodels/QuickAdapterRegressorV3.py
@@ -102,7 +102,7 @@ class QuickAdapterRegressorV3(BaseRegressionModel):
     https://github.com/sponsors/robcaulk
     """

-    version = "3.11.10"
+    version = "3.11.11"

     _TEST_SIZE: Final[float] = 0.1

diff --git a/quickadapter/user_data/strategies/LabelTransformer.py b/quickadapter/user_data/strategies/LabelTransformer.py
index 11f93b8..d4c9caf 100644
--- a/quickadapter/user_data/strategies/LabelTransformer.py
+++ b/quickadapter/user_data/strategies/LabelTransformer.py
@@ -63,6 +63,19 @@ WEIGHT_STRATEGIES: Final[tuple[WeightStrategy, ...]] = (
     "combined",
 )

+FillMethod = Literal["zero", "epsilon", "gaussian"]
+FILL_METHODS: Final[tuple[FillMethod, ...]] = (
+    "zero",  # 0 - hard zero (default)
+    "epsilon",  # 1 - flat fraction of pivot baseline
+    "gaussian",  # 2 - per-row Gaussian decay around each pivot
+)
+
+FillEpsilonBaseline = Literal["mean", "median"]
+FILL_EPSILON_BASELINES: Final[tuple[FillEpsilonBaseline, ...]] = (
+    "mean",  # 0 - arithmetic mean (default)
+    "median",  # 1 - robust against pivot-weight skew
+)
+
 StandardizationType = Literal["none", "zscore", "robust", "mmad", "power_yj"]
 STANDARDIZATION_TYPES: Final[tuple[StandardizationType, ...]] = (
     "none",  # 0 - w
@@ -85,6 +98,10 @@ DEFAULTS_LABEL_WEIGHTING: Final[dict[str, Any]] = {
     "metric_coefficients": {},
     "aggregation": COMBINED_AGGREGATIONS[0],  # "arithmetic_mean"
     "softmax_temperature": 1.0,
+    "fill_method": FILL_METHODS[0],  # "zero"
+    "fill_epsilon": 1e-3,
+    "fill_epsilon_baseline": FILL_EPSILON_BASELINES[0],  # "mean"
+    "fill_sigma_candles": 3.0,
 }

 DEFAULTS_LABEL_PIPELINE: Final[dict[str, Any]] = {
diff --git a/quickadapter/user_data/strategies/QuickAdapterV3.py b/quickadapter/user_data/strategies/QuickAdapterV3.py
index 1c6d576..64b168c 100644
--- a/quickadapter/user_data/strategies/QuickAdapterV3.py
+++ b/quickadapter/user_data/strategies/QuickAdapterV3.py
@@ -114,7 +114,7 @@ class QuickAdapterV3(IStrategy):
     _ANNOTATION_LINE_OFFSET_CANDLES: Final[int] = 10

     def version(self) -> str:
-        return "3.11.10"
+        return "3.11.11"

     timeframe = "5m"
     timeframe_minutes = timeframe_to_minutes(timeframe)
@@ -875,6 +875,7 @@ class QuickAdapterV3(IStrategy):
                     indices=label_data.indices,
                     metrics=label_data.metrics,
                     weighting_config=col_weighting_config,
+                    logger=logger,
                 )

                 if label_col == EXTREMA_COLUMN:
diff --git a/quickadapter/user_data/strategies/Utils.py b/quickadapter/user_data/strategies/Utils.py
index 4befeb5..72edcdf 100644
--- a/quickadapter/user_data/strategies/Utils.py
+++ b/quickadapter/user_data/strategies/Utils.py
@@ -33,6 +33,8 @@ from LabelTransformer import (
     DEFAULTS_LABEL_SMOOTHING,
     DEFAULTS_LABEL_WEIGHTING,
     EXTREMA_SELECTION_METHODS,
+    FILL_EPSILON_BASELINES,
+    FILL_METHODS,
     NORMALIZATION_TYPES,
     PREDICTION_METHODS,
     SMOOTHING_METHODS,
@@ -195,6 +197,14 @@ _WEIGHTING_SPECS: Final[dict[str, _ParamSpec]] = {
     "softmax_temperature": _ParamSpec(
         _NumericValidator(min_value=0, min_exclusive=True)
     ),
+    "fill_method": _ParamSpec(_EnumValidator(FILL_METHODS)),
+    "fill_epsilon": _ParamSpec(
+        _NumericValidator(min_value=0.0, max_value=1.0), output_type=float
+    ),
+    "fill_epsilon_baseline": _ParamSpec(_EnumValidator(FILL_EPSILON_BASELINES)),
+    "fill_sigma_candles": _ParamSpec(
+        _NumericValidator(min_value=0.5), output_type=float
+    ),
 }

 _PIPELINE_SPECS: Final[dict[str, _ParamSpec]] = {
@@ -734,18 +744,43 @@ def sanitize_and_renormalize(
         return arr
     safe = np.where(np.isfinite(arr) & (arr > 0.0), arr, 0.0)
     if drop_mask is not None:
+        drop_mask = np.asarray(drop_mask)
+        if drop_mask.shape != arr.shape:
+            raise ValueError(
+                f"sanitize_and_renormalize: drop_mask shape "
+                f"{drop_mask.shape} != arr shape {arr.shape}"
+            )
+        if not np.issubdtype(drop_mask.dtype, np.bool_):
+            raise ValueError(
+                f"sanitize_and_renormalize: drop_mask dtype "
+                f"{drop_mask.dtype} is not boolean"
+            )
         safe = np.where(drop_mask, 0.0, safe)
     total = safe.sum()
+    rescale_overflow = False
     if total > 0.0 and np.isfinite(total):
-        return safe * (n / total)
+        c = n / total
+        if np.isfinite(c):
+            return safe * c
+        rescale_overflow = True
     if logger is not None:
-        logger.warning(
-            "sanitize_and_renormalize: weights collapsed (context=%s, "
-            "total=%r, n=%d); falling back to uniform weights",
-            context or "unspecified",
-            total,
-            n,
-        )
+        if rescale_overflow:
+            logger.warning(
+                "sanitize_and_renormalize: rescale factor non-finite "
+                "(context=%s, n=%d, total=%r); falling back to uniform "
+                "weights",
+                context or "unspecified",
+                n,
+                total,
+            )
+        else:
+            logger.warning(
+                "sanitize_and_renormalize: weights collapsed (context=%s, "
+                "total=%r, n=%d); falling back to uniform weights",
+                context or "unspecified",
+                total,
+                n,
+            )
     fallback = np.ones(n, dtype=float)
     if drop_mask is not None:
         masked = np.where(drop_mask, 0.0, fallback)
@@ -761,6 +796,29 @@ def sanitize_and_renormalize(
     return fallback


+_PIVOT_EQUIVALENT_MAX_FRACTION: Final[float] = 0.1
+
+
+def _pivot_equivalent_count(
+    label_weights: NDArray[np.floating],
+    drop_mask: NDArray[np.bool_],
+) -> int:
+    """Count rows whose label weight is at least a fraction of the surviving max.
+
+    A max-relative threshold (``_PIVOT_EQUIVALENT_MAX_FRACTION``) separates
+    pivot-class rows from off-pivot fill across the bimodal regimes that
+    ``fill_method`` introduces (where a median-based threshold would
+    saturate at ``N`` once the off-pivot floor dominates the median).
+    """
+    survivors = label_weights[~drop_mask]
+    if survivors.size == 0:
+        return 0
+    threshold = _PIVOT_EQUIVALENT_MAX_FRACTION * float(survivors.max())
+    if threshold <= 0.0:
+        return 0
+    return int((survivors >= threshold).sum())
+
+
 def compose_sample_weights(
     base_weights: NDArray[np.floating],
     label_weights: NDArray[np.floating] | None,
@@ -796,13 +854,15 @@ def compose_sample_weights(
             "compose_sample_weights: all rows dropped by zero or non-finite "
             "label weights; no surviving training samples"
         )
-    nonzero = int((~drop_mask).sum())
+    nonzero = _pivot_equivalent_count(arr, drop_mask)
     if nonzero / n < SPARSE_TRAINING_MASS_THRESHOLD:
         logger.warning(
             "compose_sample_weights: sparse training mass "
-            "(%d/%d rows = %.2f%% nonzero, threshold=%.2f%%)",
+            "(%d/%d rows above %.0f%% of surviving max = %.2f%%, "
+            "threshold=%.2f%%)",
             nonzero,
             n,
+            100.0 * _PIVOT_EQUIVALENT_MAX_FRACTION,
             100.0 * nonzero / n,
             100.0 * SPARSE_TRAINING_MASS_THRESHOLD,
         )
@@ -1041,36 +1101,114 @@ def _impute_weights(
     return weights


+_GAUSSIAN_FILL_CHUNK_BUDGET: Final[int] = 50_000_000
+_GAUSSIAN_FILL_DENSITY_WARN: Final[float] = 0.1
+
+
+def _gaussian_fill_weights(
+    n_values: int,
+    pivot_indices: NDArray[np.integer],
+    pivot_weights: NDArray[np.floating],
+    sigma_candles: float,
+    *,
+    logger: Logger | None = None,
+) -> NDArray[np.floating]:
+    """Per-row max of Gaussian-decayed pivot weights.
+
+    Out[i] = max over p of ``w_p * exp(-(i - p)**2 / (2 * sigma**2))``.
+    With clustered pivots within ``~sigma_candles``, the per-row max
+    lets a stronger neighbor dominate weaker ones; pick
+    ``sigma_candles <= label_period_candles / 2`` to preserve pivot
+    identity.
+    """
+    if sigma_candles < 0.5:
+        raise ValueError(
+            f"Invalid sigma_candles value {sigma_candles!r}: must be >= 0.5"
+        )
+    if pivot_indices.size == 0:
+        return np.zeros(n_values, dtype=float)
+    if np.any(pivot_weights < 0.0):
+        raise ValueError(
+            f"Invalid pivot_weights min={float(pivot_weights.min())!r}: "
+            f"must be >= 0"
+        )
+    pivot_indices_array = pivot_indices.astype(float)
+    pivot_weights_row = pivot_weights.astype(float)[np.newaxis, :]
+    inv_two_sigma_sq = 0.5 / (sigma_candles * sigma_candles)
+    M = pivot_indices_array.size
+    if (
+        logger is not None
+        and n_values > 0
+        and M / n_values > _GAUSSIAN_FILL_DENSITY_WARN
+    ):
+        logger.warning(
+            "gaussian_fill: pivot density M/N=%.3f > %.2f (M=%d, N=%d); "
+            "consider tightening zigzag detection",
+            M / n_values,
+            _GAUSSIAN_FILL_DENSITY_WARN,
+            M,
+            n_values,
+        )
+    chunk = max(1, _GAUSSIAN_FILL_CHUNK_BUDGET // max(M, 1))
+    if logger is not None and chunk < n_values:
+        logger.debug(
+            "gaussian_fill: N=%d, M=%d, chunk=%d, ~%.0f MB peak buffer",
+            n_values,
+            M,
+            chunk,
+            chunk * M * 8 / 1e6,
+        )
+    out = np.zeros(n_values, dtype=float)
+    for start in range(0, n_values, chunk):
+        stop = min(start + chunk, n_values)
+        positions = np.arange(start, stop, dtype=float)
+        buf = positions[:, np.newaxis] - pivot_indices_array[np.newaxis, :]
+        np.multiply(buf, buf, out=buf)
+        np.multiply(buf, -inv_two_sigma_sq, out=buf)
+        np.exp(buf, out=buf)
+        np.multiply(buf, pivot_weights_row, out=buf)
+        np.max(buf, axis=1, out=out[start:stop])
+    return out
+
+
 def _scatter_weights(
     n_values: int,
     indices: list[int],
     weights: NDArray[np.floating],
-    default_weight: float,
+    fill_weights: NDArray[np.floating],
+    *,
+    indices_array: NDArray[np.integer] | None = None,
+    valid_mask: NDArray[np.bool_] | None = None,
 ) -> NDArray[np.floating]:
     """Scatter per-pivot weights into a full-length array.

-    Non-pivot rows are filled with ``default_weight``. Callers pass ``0.0``
-    to exclude non-pivot rows from training (pivot-only weighting), or a
-    positive value to give them a baseline weight.
+    Pivot rows (validated via ``valid_mask``) receive ``weights``; off-pivot
+    rows receive the corresponding entry of ``fill_weights`` (shape
+    ``(n_values,)``). Callers may pre-compute ``indices_array`` and
+    ``valid_mask`` and pass them in to avoid recomputation when the dispatch
+    needs the same mask for both filtered pivot extraction and the scatter.
     """
+    if fill_weights.shape != (n_values,):
+        raise ValueError(
+            f"Invalid fill_weights shape {fill_weights.shape!r}: "
+            f"must be ({n_values},)"
+        )
+    # Empty-input early return precedes the length-mismatch check on purpose.
     if len(indices) == 0 or weights.size == 0:
-        return np.full(n_values, default_weight, dtype=float)
-
+        return fill_weights.astype(float, copy=True)
     if len(indices) != weights.size:
         raise ValueError(
-            f"Invalid indices/weights values: length mismatch, got {len(indices)} indices but {weights.size} weights"
-        )
-
-    weights_array = np.full(n_values, default_weight, dtype=float)
-
-    indices_array = np.array(indices)
-    mask = (indices_array >= 0) & (indices_array < n_values)
-
-    if not np.any(mask):
+            f"Invalid indices/weights values: length mismatch, "
+            f"got {len(indices)} indices but {weights.size} weights"
+        )
+    if indices_array is None:
+        indices_array = np.asarray(indices, dtype=int)
+    if valid_mask is None:
+        valid_mask = (indices_array >= 0) & (indices_array < n_values)
+    weights_array = fill_weights.astype(float, copy=True)
+    if not np.any(valid_mask):
         return weights_array
-
-    valid_indices = indices_array[mask]
-    weights_array[valid_indices] = weights[mask]
+    weights_array[indices_array[valid_mask]] = weights[valid_mask]
     return weights_array


@@ -1181,12 +1319,15 @@ def compute_label_weights(
     indices: list[int],
     metrics: dict[str, list[float]],
     weighting_config: dict[str, Any],
+    *,
+    logger: Logger,
 ) -> NDArray[np.floating]:
     """Compute per-row label importance weights.

     Returns an array with positive values at pivot ``indices`` (scaled by
-    strategy) and ``0.0`` elsewhere. Callers must skip invocation when
-    strategy is ``'none'``; this raises ValueError otherwise.
+    strategy) and off-pivot values controlled by ``fill_method``. Callers
+    must skip invocation when strategy is ``'none'``; this raises
+    ValueError otherwise.
     """
     label_weighting = {**DEFAULTS_LABEL_WEIGHTING, **weighting_config}
     strategy = label_weighting["strategy"]
@@ -1218,11 +1359,52 @@ def compute_label_weights(
             weights=weights,
         )

+    if weights.size == 0:
+        return np.zeros(n_values, dtype=float)
+
+    indices_array = np.asarray(indices, dtype=int)
+    valid_mask = (indices_array >= 0) & (indices_array < n_values)
+
+    fill_method = label_weighting["fill_method"]
+
+    if fill_method == FILL_METHODS[0]:  # "zero"
+        fill_weights = np.zeros(n_values, dtype=float)
+    elif fill_method == FILL_METHODS[1]:  # "epsilon"
+        eps = label_weighting["fill_epsilon"]
+        baseline = label_weighting["fill_epsilon_baseline"]
+        if valid_mask.any():
+            pivot_values = weights[valid_mask]
+            if baseline == FILL_EPSILON_BASELINES[0]:  # "mean"
+                pivot_baseline = float(np.nanmean(pivot_values))
+            elif baseline == FILL_EPSILON_BASELINES[1]:  # "median"
+                pivot_baseline = float(np.nanmedian(pivot_values))
+            else:
+                raise ValueError(
+                    f"Invalid fill_epsilon_baseline value {baseline!r}"
+                )
+            if not np.isfinite(pivot_baseline):
+                pivot_baseline = 0.0
+        else:
+            pivot_baseline = 0.0
+        fill_weights = np.full(n_values, eps * pivot_baseline, dtype=float)
+    elif fill_method == FILL_METHODS[2]:  # "gaussian"
+        fill_weights = _gaussian_fill_weights(
+            n_values=n_values,
+            pivot_indices=indices_array[valid_mask],
+            pivot_weights=weights[valid_mask],
+            sigma_candles=label_weighting["fill_sigma_candles"],
+            logger=logger,
+        )
+    else:
+        raise ValueError(f"Invalid fill_method value {fill_method!r}")
+
     return _scatter_weights(
         n_values=n_values,
         indices=indices,
         weights=weights,
-        default_weight=0.0,
+        fill_weights=fill_weights,
+        indices_array=indices_array,
+        valid_mask=valid_mask,
     )
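
Reviewer note: the three fill_method modes described in the commit message can be sketched in pure Python as below. This is a minimal illustration under stated assumptions, not the NumPy implementation in Utils.py: the function name fill_weights_sketch and the dict-based pivot input are hypothetical and exist only for this example. The final loop mirrors the scatter step, which writes each pivot's own weight over the fill.

```python
import math
from statistics import mean, median


def fill_weights_sketch(n, pivots, method="zero", epsilon=1e-3,
                        baseline="mean", sigma=3.0):
    """Hypothetical helper: per-row weights for n rows.

    `pivots` maps row index -> pivot weight (an assumption of this sketch;
    the patch itself passes separate index and weight arrays).
    """
    # Keep only in-range pivots, like the valid_mask filter in the patch.
    valid = {p: w for p, w in pivots.items() if 0 <= p < n}
    if method == "zero":
        out = [0.0] * n
    elif method == "epsilon":
        # Flat baseline: epsilon times the mean or median pivot weight.
        stat = mean if baseline == "mean" else median
        level = epsilon * stat(list(valid.values())) if valid else 0.0
        out = [level] * n
    elif method == "gaussian":
        if sigma < 0.5:
            raise ValueError(f"Invalid sigma value {sigma!r}: must be >= 0.5")
        # Per-row max over pivots of w_p * exp(-(i - p)^2 / (2 sigma^2)).
        out = [
            max(
                (w * math.exp(-((i - p) ** 2) / (2.0 * sigma**2))
                 for p, w in valid.items()),
                default=0.0,
            )
            for i in range(n)
        ]
    else:
        raise ValueError(f"Invalid method value {method!r}")
    # Scatter step: pivot rows always receive their own weight.
    for p, w in valid.items():
        out[p] = w
    return out


if __name__ == "__main__":
    for mode in ("zero", "epsilon", "gaussian"):
        row = fill_weights_sketch(7, {3: 2.0}, method=mode, sigma=1.0)
        print(mode, [round(v, 4) for v in row])
```

With a single pivot of weight 2.0 at row 3, `zero` leaves every other row at 0.0, `epsilon` lifts them to a small constant, and `gaussian` decays symmetrically away from the pivot. The per-row max, rather than a sum, is what the docstring of _gaussian_fill_weights describes; the sketch does not reproduce the chunked in-place buffer or the density warning.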