From 619a3c7e020d5dc737b7b487f4708b8c29fdc511 Mon Sep 17 00:00:00 2001 From: =?utf8?q?J=C3=A9r=C3=B4me=20Benoit?= Date: Tue, 24 Mar 2026 12:03:17 +0100 Subject: [PATCH] test(ocpp-server): add graceful shutdown tests covering signal handling MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Cover SIGINT/SIGTERM handlers, double-signal force quit, Windows fallback via call_soon_threadsafe, and shutdown timeout warning. Coverage: 80% → 90%. --- tests/ocpp-server/test_server.py | 177 +++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/tests/ocpp-server/test_server.py b/tests/ocpp-server/test_server.py index d53f8b65..0ac53d1f 100644 --- a/tests/ocpp-server/test_server.py +++ b/tests/ocpp-server/test_server.py @@ -1,7 +1,9 @@ """Tests for the OCPP 2.0.1 mock server.""" import argparse +import contextlib import logging +import signal from typing import ClassVar from unittest.mock import AsyncMock, MagicMock, patch @@ -50,6 +52,7 @@ from server import ( ServerConfig, _random_request_id, check_positive_number, + main, on_connect, ) @@ -1120,3 +1123,177 @@ class TestSendCommand: with patch("server.Timer") as MockTimer: await charge_point.send_command(Action.clear_cache, delay=1.0, period=None) MockTimer.assert_not_called() + + +# --- Helpers for main() shutdown tests --- + +DEFAULT_MAIN_ARGS = { + "command": None, + "delay": None, + "period": None, + "host": "127.0.0.1", + "port": 9000, + "boot_status": RegistrationStatusEnumType.accepted, + "total_cost": 10.0, + "auth_mode": "normal", + "whitelist": ["valid_token"], + "blacklist": ["blocked_token"], + "offline": False, +} + + +def _mock_args(**overrides): + return argparse.Namespace(**{**DEFAULT_MAIN_ARGS, **overrides}) + + +@pytest.fixture +def main_mocks(): + """Provide mock loop, server, and shutdown event for main() tests.""" + mock_loop = MagicMock() + signal_handlers: dict[int, tuple] = {} + + def _capture_handler(sig, callback, *args): + signal_handlers[sig] = (callback, args) + + mock_loop.add_signal_handler = MagicMock(side_effect=_capture_handler) + + mock_server = AsyncMock() + mock_server.close = MagicMock() + mock_server.wait_closed = AsyncMock() + + mock_event = MagicMock() + mock_event.set = MagicMock() + + return mock_loop, mock_server, mock_event, signal_handlers + + +def _patch_main(mock_loop, mock_server, mock_event, args=None): + if args is None: + args = _mock_args() + + mock_serve_cm = AsyncMock() + mock_serve_cm.__aenter__ = AsyncMock(return_value=mock_server) + mock_serve_cm.__aexit__ = AsyncMock(return_value=False) + + return contextlib.ExitStack(), [ + patch( + "server.argparse.ArgumentParser.parse_known_args", + return_value=(MagicMock(command=args.command), []), + ), + patch("server.argparse.ArgumentParser.parse_args", return_value=args), + patch("server.websockets.serve", return_value=mock_serve_cm), + patch("server.asyncio.get_running_loop", return_value=mock_loop), + patch("server.asyncio.Event", return_value=mock_event), + ] + + +def _enter_patches(stack, patches): + for p in patches: + stack.enter_context(p) + return stack + + +class TestMainGracefulShutdown: + """Tests for the main() graceful shutdown logic.""" + + async def test_first_signal_closes_server_and_sets_event(self, main_mocks): + mock_loop, mock_server, mock_event, signal_handlers = main_mocks + + async def _fire_sigint(): + handler, args = signal_handlers[signal.SIGINT] + handler(*args) + + mock_event.wait = AsyncMock(side_effect=_fire_sigint) + + stack, patches = _patch_main(mock_loop, mock_server, mock_event) + with _enter_patches(stack, patches): + await main() + + mock_server.close.assert_called_once() + mock_event.set.assert_called_once() + mock_server.wait_closed.assert_called_once() + + async def test_sigterm_triggers_same_shutdown(self, main_mocks): + mock_loop, mock_server, mock_event, signal_handlers = main_mocks + + async def _fire_sigterm(): + handler, args = signal_handlers[signal.SIGTERM] + handler(*args) + + mock_event.wait = AsyncMock(side_effect=_fire_sigterm) + + stack, patches = _patch_main(mock_loop, mock_server, mock_event) + with _enter_patches(stack, patches): + await main() + + mock_server.close.assert_called_once() + mock_event.set.assert_called_once() + + async def test_second_signal_forces_exit(self, main_mocks): + mock_loop, mock_server, mock_event, signal_handlers = main_mocks + + async def _fire_twice(): + handler, args = signal_handlers[signal.SIGINT] + handler(*args) + handler(*args) + + mock_event.wait = AsyncMock(side_effect=_fire_twice) + + stack, patches = _patch_main(mock_loop, mock_server, mock_event) + with _enter_patches(stack, patches): + with pytest.raises(SystemExit) as exc_info: + await main() + assert exc_info.value.code == 128 + signal.SIGINT.value + + async def test_shutdown_timeout_logs_warning(self, main_mocks, caplog): + mock_loop, mock_server, mock_event, signal_handlers = main_mocks + mock_server.wait_closed = AsyncMock(side_effect=TimeoutError) + + async def _fire_sigint(): + handler, args = signal_handlers[signal.SIGINT] + handler(*args) + + mock_event.wait = AsyncMock(side_effect=_fire_sigint) + caplog.set_level(logging.WARNING) + + stack, patches = _patch_main(mock_loop, mock_server, mock_event) + with _enter_patches(stack, patches): + await main() + + assert "timed out" in caplog.text.lower() + + async def test_windows_fallback_registers_signal_handlers(self, main_mocks): + mock_loop, mock_server, mock_event, _ = main_mocks + mock_loop.add_signal_handler = MagicMock(side_effect=NotImplementedError) + mock_event.wait = AsyncMock() + + mock_signal_fn = MagicMock() + stack, patches = _patch_main(mock_loop, mock_server, mock_event) + patches.append(patch("server.signal.signal", mock_signal_fn)) + with _enter_patches(stack, patches): + await main() + + assert mock_signal_fn.call_count == 2 + registered_signals = {call.args[0] for call in mock_signal_fn.call_args_list} + assert registered_signals == {signal.SIGINT, signal.SIGTERM} + + async def test_windows_handler_schedules_via_call_soon_threadsafe(self, main_mocks): + mock_loop, mock_server, mock_event, _ = main_mocks + mock_loop.add_signal_handler = MagicMock(side_effect=NotImplementedError) + mock_loop.call_soon_threadsafe = MagicMock() + mock_event.wait = AsyncMock() + + captured_handlers: dict[int, object] = {} + + def _capture_signal(sig, handler): + captured_handlers[sig] = handler + + stack, patches = _patch_main(mock_loop, mock_server, mock_event) + patches.append(patch("server.signal.signal", side_effect=_capture_signal)) + with _enter_patches(stack, patches): + await main() + + sigint_handler = captured_handlers[signal.SIGINT] + assert callable(sigint_handler) + sigint_handler(signal.SIGINT.value, None) + mock_loop.call_soon_threadsafe.assert_called_once() -- 2.43.0