]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
fix(qav3): compute RMSE on a label period basis
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 12 Mar 2025 00:51:08 +0000 (01:51 +0100)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Wed, 12 Mar 2025 00:51:08 +0000 (01:51 +0100)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
quickadapter/user_data/freqaimodels/LightGBMRegressorQuickAdapterV35.py
quickadapter/user_data/freqaimodels/XGBoostRegressorQuickAdapterV35.py

index 8a2c9c7f43407a3aa1867469de1fd41289cc1a68..d813d3b878a4fe32a4fdc224212685a2a845b39e 100644 (file)
@@ -580,10 +580,6 @@ def period_objective(
         step=candles_step,
     )
     test_window = (test_window // label_period_candles) * label_period_candles
-    if test_window < min_label_period_candles:
-        raise optuna.TrialPruned(
-            f"Adjusted test window {test_window} is too small for minimum label period {min_label_period_candles}."
-        )
     X_test = X_test.iloc[-test_window:]
     y_test = y_test.iloc[-test_window:]
     test_weights = test_weights[-test_window:]
@@ -602,25 +598,23 @@ def period_objective(
     y_pred = model.predict(X_test)
 
     n_windows = len(y_test) // label_period_candles
-    y_test_windows = [
+    y_test = [
         y_test.iloc[i : i + label_period_candles].to_numpy()
         for i in np.arange(0, label_period_candles * n_windows, label_period_candles)
     ]
-    test_weights_windows = [
+    test_weights = [
         test_weights[i : i + label_period_candles]
         for i in np.arange(0, label_period_candles * n_windows, label_period_candles)
     ]
-    y_pred_windows = [
+    y_pred = [
         y_pred[i : i + label_period_candles]
         for i in np.arange(0, label_period_candles * n_windows, label_period_candles)
     ]
-    y_test = [window for window in y_test_windows]
-    test_weights = np.concatenate([window for window in test_weights_windows])
-    y_pred = [window for window in y_pred_windows]
 
-    error = sklearn.metrics.root_mean_squared_error(
-        y_test, y_pred, sample_weight=test_weights
-    )
+    error = 0.0
+    for y_t, y_p, t_w in zip(y_test, y_pred, test_weights):
+        error += sklearn.metrics.root_mean_squared_error(y_t, y_p, sample_weight=t_w)
+    error /= n_windows
 
     return error
 
index 54a7ca3c31f297cd85aea6b021e595bdefaeda10..8feb2c821d85b5e354c8f583078d165367a839f4 100644 (file)
@@ -583,10 +583,6 @@ def period_objective(
         step=candles_step,
     )
     test_window = (test_window // label_period_candles) * label_period_candles
-    if test_window < min_label_period_candles:
-        raise optuna.TrialPruned(
-            f"Adjusted test window {test_window} is too small for minimum label period {min_label_period_candles}."
-        )
     X_test = X_test.iloc[-test_window:]
     y_test = y_test.iloc[-test_window:]
     test_weights = test_weights[-test_window:]
@@ -610,25 +606,23 @@ def period_objective(
     y_pred = model.predict(X_test)
 
     n_windows = len(y_test) // label_period_candles
-    y_test_windows = [
+    y_test = [
         y_test.iloc[i : i + label_period_candles].to_numpy()
         for i in np.arange(0, label_period_candles * n_windows, label_period_candles)
     ]
-    test_weights_windows = [
+    test_weights = [
         test_weights[i : i + label_period_candles]
         for i in np.arange(0, label_period_candles * n_windows, label_period_candles)
     ]
-    y_pred_windows = [
+    y_pred = [
         y_pred[i : i + label_period_candles]
         for i in np.arange(0, label_period_candles * n_windows, label_period_candles)
     ]
-    y_test = [window for window in y_test_windows]
-    test_weights = np.concatenate([window for window in test_weights_windows])
-    y_pred = [window for window in y_pred_windows]
 
-    error = sklearn.metrics.root_mean_squared_error(
-        y_test, y_pred, sample_weight=test_weights
-    )
+    error = 0.0
+    for y_t, y_p, t_w in zip(y_test, y_pred, test_weights):
+        error += sklearn.metrics.root_mean_squared_error(y_t, y_p, sample_weight=t_w)
+    error /= n_windows
 
     return error