]> Piment Noir Git Repositories - freqai-strategies.git/commitdiff
refactor(reforcexy): ensure envs are created with consistent data
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 21 Sep 2025 20:11:32 +0000 (22:11 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Sun, 21 Sep 2025 20:11:32 +0000 (22:11 +0200)
     snapshots

Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
ReforceXY/user_data/freqaimodels/ReforceXY.py

index 83525a740fec9a770971cf2301415eb0d793734c..e2410194a0f7a2e6e0364595181fb73b83aba795 100644 (file)
@@ -285,7 +285,13 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         logger.info("Populating environments: %s", self.n_envs)
         self.train_env, self.eval_env = self._get_train_and_eval_environments(
-            train_df, test_df, dk, prices_train, prices_test, seed, env_dict
+            dk,
+            train_df=train_df,
+            test_df=test_df,
+            prices_train=prices_train,
+            prices_test=prices_test,
+            seed=seed,
+            env_info=env_dict,
         )
 
     def get_model_params(self) -> Dict[str, Any]:
@@ -521,7 +527,7 @@ class ReforceXY(BaseReinforcementLearningModel):
 
         start_time = time.time()
         if self.hyperopt:
-            best_trial_params = self.study(train_df, test_df, total_timesteps, dk)
+            best_trial_params = self.study(dk, total_timesteps)
             if best_trial_params is None:
                 logger.error(
                     "Hyperopt failed. Using default configured model params instead"
@@ -713,11 +719,7 @@ class ReforceXY(BaseReinforcementLearningModel):
             return False
 
     def study(
-        self,
-        train_df: DataFrame,
-        test_df: DataFrame,
-        total_timesteps: int,
-        dk: FreqaiDataKitchen,
+        self, dk: FreqaiDataKitchen, total_timesteps: int
     ) -> Optional[Dict[str, Any]]:
         """
         Runs hyperparameter optimization using Optuna and returns the best hyperparameters found merged with the user defined parameters
@@ -760,9 +762,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         start_time = time.time()
         try:
             study.optimize(
-                lambda trial: self.objective(
-                    trial, train_df, test_df, total_timesteps, dk
-                ),
+                lambda trial: self.objective(trial, dk, total_timesteps),
                 n_trials=self.optuna_n_trials,
                 timeout=(
                     hours_to_seconds(self.optuna_timeout_hours)
@@ -884,16 +884,23 @@ class ReforceXY(BaseReinforcementLearningModel):
 
     def _get_train_and_eval_environments(
         self,
-        train_df: DataFrame,
-        test_df: DataFrame,
         dk: FreqaiDataKitchen,
+        train_df: Optional[DataFrame] = None,
+        test_df: Optional[DataFrame] = None,
         prices_train: Optional[DataFrame] = None,
         prices_test: Optional[DataFrame] = None,
         seed: Optional[int] = None,
         env_info: Optional[Dict[str, Any]] = None,
         trial: Optional[Trial] = None,
     ) -> Tuple[BaseEnvironment, BaseEnvironment]:
-        if prices_train is None or prices_test is None:
+        if (
+            train_df is None
+            or test_df is None
+            or prices_train is None
+            or prices_test is None
+        ):
+            train_df = dk.data_dictionary["train_features"]
+            test_df = dk.data_dictionary["test_features"]
             prices_train, prices_test = self.build_ohlc_price_dataframes(
                 dk.data_dictionary, dk.pair, dk
             )
@@ -951,12 +958,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         return train_env, eval_env
 
     def objective(
-        self,
-        trial: Trial,
-        train_df: DataFrame,
-        test_df: DataFrame,
-        total_timesteps: int,
-        dk: FreqaiDataKitchen,
+        self, trial: Trial, dk: FreqaiDataKitchen, total_timesteps: int
     ) -> float:
         """
         Defines a single trial for hyperparameter optimization using Optuna
@@ -1004,9 +1006,7 @@ class ReforceXY(BaseReinforcementLearningModel):
         else:
             tensorboard_log_path = None
 
-        train_env, eval_env = self._get_train_and_eval_environments(
-            train_df, test_df, dk, trial=trial
-        )
+        train_env, eval_env = self._get_train_and_eval_environments(dk, trial=trial)
 
         model = self.MODELCLASS(
             self.policy_type,