From f3b4d70f110134f0b32cccfc439e0876fceae38d Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Tue, 24 Mar 2026 12:07:03 +0100 Subject: [PATCH] refactor(ocpp-server): harmonize shutdown tests with project conventions Move fixture near other fixtures, replace _patch_main tuple return with contextmanager, parametrize SIGINT/SIGTERM, remove dead code helpers, merge two test classes into one. --- tests/ocpp-server/test_server.py | 178 ++++++++++++++----------------- 1 file changed, 78 insertions(+), 100 deletions(-) diff --git a/tests/ocpp-server/test_server.py b/tests/ocpp-server/test_server.py index 0ac53d1f..175e6f0d 100644 --- a/tests/ocpp-server/test_server.py +++ b/tests/ocpp-server/test_server.py @@ -139,6 +139,63 @@ def command_charge_point(mock_connection): return cp +@pytest.fixture +def main_mocks(): + """Provide mock loop, server, shutdown event, and signal capture.""" + mock_loop = MagicMock() + signal_handlers: dict[int, tuple] = {} + + def _capture_handler(sig, callback, *args): + signal_handlers[sig] = (callback, args) + + mock_loop.add_signal_handler = MagicMock(side_effect=_capture_handler) + + mock_server = AsyncMock() + mock_server.close = MagicMock() + mock_server.wait_closed = AsyncMock() + + mock_event = MagicMock() + mock_event.set = MagicMock() + + return mock_loop, mock_server, mock_event, signal_handlers + + +@contextlib.contextmanager +def _patch_main(mock_loop, mock_server, mock_event, extra_patches=None): + args = argparse.Namespace( + command=None, + delay=None, + period=None, + host="127.0.0.1", + port=9000, + boot_status=RegistrationStatusEnumType.accepted, + total_cost=10.0, + auth_mode="normal", + whitelist=["valid_token"], + blacklist=["blocked_token"], + offline=False, + ) + mock_serve_cm = AsyncMock() + mock_serve_cm.__aenter__ = AsyncMock(return_value=mock_server) + mock_serve_cm.__aexit__ = AsyncMock(return_value=False) + + patches = [ + patch( + "server.argparse.ArgumentParser.parse_known_args", + return_value=(MagicMock(command=args.command), []), + ), + patch("server.argparse.ArgumentParser.parse_args", return_value=args), + patch("server.websockets.serve", return_value=mock_serve_cm), + patch("server.asyncio.get_running_loop", return_value=mock_loop), + patch("server.asyncio.Event", return_value=mock_event), + *(extra_patches or []), + ] + with contextlib.ExitStack() as stack: + for p in patches: + stack.enter_context(p) + yield + + class TestCheckPositiveNumber: """Tests for the check_positive_number argument validator.""" @@ -1125,110 +1182,26 @@ class TestSendCommand: MockTimer.assert_not_called() -# --- Helpers for main() shutdown tests --- - -DEFAULT_MAIN_ARGS = { - "command": None, - "delay": None, - "period": None, - "host": "127.0.0.1", - "port": 9000, - "boot_status": RegistrationStatusEnumType.accepted, - "total_cost": 10.0, - "auth_mode": "normal", - "whitelist": ["valid_token"], - "blacklist": ["blocked_token"], - "offline": False, -} - - -def _mock_args(**overrides): - return argparse.Namespace(**{**DEFAULT_MAIN_ARGS, **overrides}) - - -@pytest.fixture -def main_mocks(): - """Provide mock loop, server, and shutdown event for main() tests.""" - mock_loop = MagicMock() - signal_handlers: dict[int, tuple] = {} - - def _capture_handler(sig, callback, *args): - signal_handlers[sig] = (callback, args) - - mock_loop.add_signal_handler = MagicMock(side_effect=_capture_handler) - - mock_server = AsyncMock() - mock_server.close = MagicMock() - mock_server.wait_closed = AsyncMock() - - mock_event = MagicMock() - mock_event.set = MagicMock() - - return mock_loop, mock_server, mock_event, signal_handlers - - -def _patch_main(mock_loop, mock_server, mock_event, args=None): - if args is None: - args = _mock_args() - - mock_serve_cm = AsyncMock() - mock_serve_cm.__aenter__ = AsyncMock(return_value=mock_server) - mock_serve_cm.__aexit__ = AsyncMock(return_value=False) - - return contextlib.ExitStack(), [ - patch( - "server.argparse.ArgumentParser.parse_known_args", - return_value=(MagicMock(command=args.command), []), - ), - patch("server.argparse.ArgumentParser.parse_args", return_value=args), - patch("server.websockets.serve", return_value=mock_serve_cm), - patch("server.asyncio.get_running_loop", return_value=mock_loop), - patch("server.asyncio.Event", return_value=mock_event), - ] - - -def _enter_patches(stack, patches): - for p in patches: - stack.enter_context(p) - return stack - - class TestMainGracefulShutdown: """Tests for the main() graceful shutdown logic.""" - async def test_first_signal_closes_server_and_sets_event(self, main_mocks): + @pytest.mark.parametrize("sig", [signal.SIGINT, signal.SIGTERM]) + async def test_first_signal_closes_server_and_sets_event(self, main_mocks, sig): mock_loop, mock_server, mock_event, signal_handlers = main_mocks - async def _fire_sigint(): - handler, args = signal_handlers[signal.SIGINT] + async def _fire_signal(): + handler, args = signal_handlers[sig] handler(*args) - mock_event.wait = AsyncMock(side_effect=_fire_sigint) + mock_event.wait = AsyncMock(side_effect=_fire_signal) - stack, patches = _patch_main(mock_loop, mock_server, mock_event) - with _enter_patches(stack, patches): + with _patch_main(mock_loop, mock_server, mock_event): await main() mock_server.close.assert_called_once() mock_event.set.assert_called_once() mock_server.wait_closed.assert_called_once() - async def test_sigterm_triggers_same_shutdown(self, main_mocks): - mock_loop, mock_server, mock_event, signal_handlers = main_mocks - - async def _fire_sigterm(): - handler, args = signal_handlers[signal.SIGTERM] - handler(*args) - - mock_event.wait = AsyncMock(side_effect=_fire_sigterm) - - stack, patches = _patch_main(mock_loop, mock_server, mock_event) - with _enter_patches(stack, patches): - await main() - - mock_server.close.assert_called_once() - mock_event.set.assert_called_once() - async def test_second_signal_forces_exit(self, main_mocks): mock_loop, mock_server, mock_event, signal_handlers = main_mocks @@ -1239,8 +1212,7 @@ class TestMainGracefulShutdown: mock_event.wait = AsyncMock(side_effect=_fire_twice) - stack, patches = _patch_main(mock_loop, mock_server, mock_event) - with _enter_patches(stack, patches): + with _patch_main(mock_loop, mock_server, mock_event): with pytest.raises(SystemExit) as exc_info: await main() assert exc_info.value.code == 128 + signal.SIGINT.value @@ -1256,8 +1228,7 @@ class TestMainGracefulShutdown: mock_event.wait = AsyncMock(side_effect=_fire_sigint) caplog.set_level(logging.WARNING) - stack, patches = _patch_main(mock_loop, mock_server, mock_event) - with _enter_patches(stack, patches): + with _patch_main(mock_loop, mock_server, mock_event): await main() assert "timed out" in caplog.text.lower() @@ -1268,9 +1239,13 @@ class TestMainGracefulShutdown: mock_event.wait = AsyncMock() mock_signal_fn = MagicMock() - stack, patches = _patch_main(mock_loop, mock_server, mock_event) - patches.append(patch("server.signal.signal", mock_signal_fn)) - with _enter_patches(stack, patches): + + with _patch_main( + mock_loop, + mock_server, + mock_event, + extra_patches=[patch("server.signal.signal", mock_signal_fn)], + ): await main() assert mock_signal_fn.call_count == 2 @@ -1288,9 +1263,12 @@ class TestMainGracefulShutdown: def _capture_signal(sig, handler): captured_handlers[sig] = handler - stack, patches = _patch_main(mock_loop, mock_server, mock_event) - patches.append(patch("server.signal.signal", side_effect=_capture_signal)) - with _enter_patches(stack, patches): + with _patch_main( + mock_loop, + mock_server, + mock_event, + extra_patches=[patch("server.signal.signal", side_effect=_capture_signal)], + ): await main() sigint_handler = captured_handlers[signal.SIGINT] -- 2.43.0