)
pivots_sizes = [trial.values[1] for trial in best_trials]
- quantile_pivots_size = np.quantile(
- pivots_sizes, self.ft_params.get("label_quantile", 0.5)
- )
+ label_quantile = float(self.ft_params.get("label_quantile", 0.5))
+ if not (0.0 <= label_quantile <= 1.0):
+ raise ValueError("label_quantile must be between 0.0 and 1.0")
+ quantile_pivots_size = np.quantile(pivots_sizes, label_quantile)
direction0 = study.directions[0]
nearest_below_quantile = None
for trial in best_trials:
pivots_size = trial.values[1]
-
if pivots_size >= quantile_pivots_size:
if is_better_above_candidate(
trial, nearest_above_quantile, direction0, quantile_pivots_size
if not nearest_below_quantile:
return nearest_above_quantile
- if direction0 == optuna.study.StudyDirection.MAXIMIZE:
- if np.isclose(
- nearest_above_quantile.values[0], nearest_below_quantile.values[0]
- ):
- above_quantile_distance = (
- nearest_above_quantile.values[1] - quantile_pivots_size
- )
- below_quantile_distance = (
- quantile_pivots_size - nearest_below_quantile.values[1]
- )
-
- if abs(above_quantile_distance) < abs(below_quantile_distance):
- return nearest_above_quantile
- elif abs(above_quantile_distance) > abs(below_quantile_distance):
- return nearest_below_quantile
- else:
- direction1 = study.directions[1]
- if direction1 == optuna.study.StudyDirection.MAXIMIZE:
- return max(
- [nearest_above_quantile, nearest_below_quantile],
- key=lambda trial: trial.values[1],
- )
- else:
- return min(
- [nearest_above_quantile, nearest_below_quantile],
- key=lambda trial: trial.values[1],
- )
+ def tie_break_selection(
+ above: optuna.trial.FrozenTrial,
+ below: optuna.trial.FrozenTrial,
+ direction: optuna.study.StudyDirection,
+ ) -> optuna.trial.FrozenTrial:
+ above_quantile_distance = abs(above.values[1] - quantile_pivots_size)
+ below_quantile_distance = abs(quantile_pivots_size - below.values[1])
+
+ if above_quantile_distance < below_quantile_distance:
+ return above
+ elif above_quantile_distance > below_quantile_distance:
+ return below
else:
- return (
- nearest_above_quantile
- if nearest_above_quantile.values[0]
- > nearest_below_quantile.values[0]
- else nearest_below_quantile
- )
- else:
- if np.isclose(
- nearest_above_quantile.values[0], nearest_below_quantile.values[0]
- ):
- above_quantile_distance = (
- nearest_above_quantile.values[1] - quantile_pivots_size
- )
- below_quantile_distance = (
- quantile_pivots_size - nearest_below_quantile.values[1]
- )
-
- if abs(above_quantile_distance) < abs(below_quantile_distance):
- return nearest_above_quantile
- elif abs(above_quantile_distance) > abs(below_quantile_distance):
- return nearest_below_quantile
+ if direction == optuna.study.StudyDirection.MAXIMIZE:
+ return max([above, below], key=lambda trial: trial.values[1])
else:
- direction1 = study.directions[1]
- if direction1 == optuna.study.StudyDirection.MAXIMIZE:
- return max(
- [nearest_above_quantile, nearest_below_quantile],
- key=lambda trial: trial.values[1],
- )
- else:
- return min(
- [nearest_above_quantile, nearest_below_quantile],
- key=lambda trial: trial.values[1],
- )
+ return min([above, below], key=lambda trial: trial.values[1])
+
+ def final_selection(
+ above: optuna.trial.FrozenTrial,
+ below: optuna.trial.FrozenTrial,
+ direction: optuna.study.StudyDirection,
+ ) -> optuna.trial.FrozenTrial:
+ if direction == optuna.study.StudyDirection.MAXIMIZE:
+ return above if above.values[0] > below.values[0] else below
else:
- return (
- nearest_above_quantile
- if nearest_above_quantile.values[0]
- < nearest_below_quantile.values[0]
- else nearest_below_quantile
- )
+ return above if above.values[0] < below.values[0] else below
+
+ if np.isclose(
+ nearest_above_quantile.values[0], nearest_below_quantile.values[0]
+ ):
+ return tie_break_selection(
+ nearest_above_quantile,
+ nearest_below_quantile,
+ study.directions[1],
+ )
+ else:
+ return final_selection(
+ nearest_above_quantile, nearest_below_quantile, direction0
+ )
elif label_trials_selection == "chebyshev":
objective_values = np.array([trial.values for trial in best_trials]).T
normalized_values_list = []