From: Jérôme Benoit Date: Sun, 15 Feb 2026 12:31:58 +0000 (+0100) Subject: feat(ReforceXY): add gpu_memory_fraction tunable X-Git-Url: https://git.piment-noir.org/?a=commitdiff_plain;h=939f1343e3f217d622aeca17f250762aee0400ed;p=freqai-strategies.git feat(ReforceXY): add gpu_memory_fraction tunable --- diff --git a/ReforceXY/user_data/freqaimodels/ReforceXY.py b/ReforceXY/user_data/freqaimodels/ReforceXY.py index baeed3b..b62f707 100644 --- a/ReforceXY/user_data/freqaimodels/ReforceXY.py +++ b/ReforceXY/user_data/freqaimodels/ReforceXY.py @@ -109,6 +109,10 @@ class ReforceXY(BaseReinforcementLearningModel): ... "freqai": { ... + "model_training_parameters": { + "device": "auto", // PyTorch device (auto|cpu|cuda|cuda:0) + "gpu_memory_fraction": null, // GPU VRAM fraction limit per process (0.0, 1.0], null disables + }, "rl_config": { ... "n_envs": 1, // Number of DummyVecEnv or SubProcVecEnv training environments @@ -328,6 +332,35 @@ class ReforceXY(BaseReinforcementLearningModel): ], ] = {} self.unset_unsupported() + self._configure_gpu_memory() + + def _configure_gpu_memory(self) -> None: + """ + Configure GPU memory fraction limit from model_training_parameters. + Called after config validation, before any CUDA operations. + """ + gpu_memory_fraction: Optional[float] = self.model_training_parameters.get( + "gpu_memory_fraction" + ) + if gpu_memory_fraction is None: + return + if not th.cuda.is_available(): + logger.warning( + "Config [global]: gpu_memory_fraction=%.2f ignored; CUDA not available", + gpu_memory_fraction, + ) + return + if not 0.0 < gpu_memory_fraction <= 1.0: + logger.warning( + "Config [global]: gpu_memory_fraction=%.2f invalid; must be in (0.0, 1.0]", + gpu_memory_fraction, + ) + return + th.cuda.set_per_process_memory_fraction(gpu_memory_fraction, device=0) + logger.info( + "Config [global]: gpu_memory_fraction=%.2f applied", + gpu_memory_fraction, + ) @staticmethod def _normalize_position(position: Any) -> Positions: @@ -356,7 +389,7 @@ class ReforceXY(BaseReinforcementLearningModel): ) if cache_key in ReforceXY._action_masks_cache: logger.debug( - "ActionMask: cache hit for can_short=%s position=%s", + "ActionMask [global]: cache hit for can_short=%s position=%s", can_short, position.name, ) @@ -376,7 +409,7 @@ class ReforceXY(BaseReinforcementLearningModel): ReforceXY._action_masks_cache[cache_key] = action_masks logger.debug( - "ActionMask: cache miss for can_short=%s position=%s, computed=%s", + "ActionMask [global]: cache miss for can_short=%s position=%s; computed=%s", can_short, position.name, action_masks,