isna(trade_duration) or trade_duration <= 0
)
- def get_trade_interpolation_natr(
+ def get_trade_weighted_interpolation_natr(
self, df: DataFrame, trade: Trade
) -> Optional[float]:
label_natr = df.get("natr_label_period_candles")
if isna(current_natr) or current_natr < 0:
return None
median_natr = trade_label_natr.median()
- interpolation_values = [current_natr, median_natr, entry_natr]
+
+ entry_quantile = calculate_quantile(trade_label_natr.to_numpy(), entry_natr)
+ current_quantile = calculate_quantile(trade_label_natr.to_numpy(), current_natr)
+ median_quantile = calculate_quantile(trade_label_natr.to_numpy(), median_natr)
+
+ if isna(entry_quantile) or isna(current_quantile) or isna(median_quantile):
+ return None
+
+ def calculate_weight(
+ quantile: float,
+ min_weight: float = 0.0,
+ max_weight: float = 1.0,
+ steepness: float = 1.5,
+ ) -> float:
+ normalized_distance_from_center = abs(quantile - 0.5) * 2.0
+ return (
+ min_weight
+ + (max_weight - min_weight) * normalized_distance_from_center**steepness
+ )
+
+ entry_weight = calculate_weight(entry_quantile)
+ current_weight = calculate_weight(current_quantile)
+ median_weight = calculate_weight(median_quantile)
+
+ total_weight = entry_weight + current_weight + median_weight
+ if np.isclose(total_weight, 0):
+ return None
+ entry_weight /= total_weight
+ current_weight /= total_weight
+ median_weight /= total_weight
+
+ return (
+ entry_natr * entry_weight
+ + current_natr * current_weight
+ + median_natr * median_weight
+ )
+
+ def get_trade_interpolation_natr(
+ self, df: DataFrame, trade: Trade
+ ) -> Optional[float]:
+ label_natr = df.get("natr_label_period_candles")
+ if label_natr is None or label_natr.empty:
+ return None
+ dates = df.get("date")
+ if dates is None or dates.empty:
+ return None
+ entry_date = self.get_trade_entry_date(trade)
+ trade_label_natr = label_natr[dates >= entry_date]
+ if trade_label_natr.empty:
+ return None
+ entry_natr = trade_label_natr.iloc[0]
+ if isna(entry_natr) or entry_natr < 0:
+ return None
+ if len(trade_label_natr) == 1:
+ return entry_natr
+ current_natr = trade_label_natr.iloc[-1]
+ if isna(current_natr) or current_natr < 0:
+ return None
trade_volatility_quantile = calculate_quantile(
trade_label_natr.to_numpy(), entry_natr
)
return None
return np.interp(
trade_volatility_quantile,
- np.linspace(0.0, 1.0, len(interpolation_values)),
- interpolation_values,
+ [0.0, 1.0],
+ [current_natr, entry_natr],
)
def get_trade_moving_average_natr(
)
if trade_price_target == "interpolation":
return self.get_trade_interpolation_natr(df, trade)
+ elif trade_price_target == "weighted_interpolation":
+ return self.get_trade_weighted_interpolation_natr(df, trade)
elif trade_price_target == "moving_average":
return self.get_trade_moving_average_natr(
df, trade.pair, trade_duration_candles
)
else:
raise ValueError(
- f"Invalid trade_price_target: {trade_price_target}. Expected 'interpolation' or 'moving_average'."
+ f"Invalid trade_price_target: {trade_price_target}. Expected 'interpolation', 'weighted_interpolation' or 'moving_average'."
)
def get_stoploss_distance(