return cp
+@pytest.fixture
+def main_mocks():
+ """Provide mock loop, server, shutdown event, and signal capture."""
+ mock_loop = MagicMock()
+ signal_handlers: dict[int, tuple] = {}
+
+ def _capture_handler(sig, callback, *args):
+ signal_handlers[sig] = (callback, args)
+
+ mock_loop.add_signal_handler = MagicMock(side_effect=_capture_handler)
+
+ mock_server = AsyncMock()
+ mock_server.close = MagicMock()
+ mock_server.wait_closed = AsyncMock()
+
+ mock_event = MagicMock()
+ mock_event.set = MagicMock()
+
+ return mock_loop, mock_server, mock_event, signal_handlers
+
+
+@contextlib.contextmanager
+def _patch_main(mock_loop, mock_server, mock_event, extra_patches=None):
+ args = argparse.Namespace(
+ command=None,
+ delay=None,
+ period=None,
+ host="127.0.0.1",
+ port=9000,
+ boot_status=RegistrationStatusEnumType.accepted,
+ total_cost=10.0,
+ auth_mode="normal",
+ whitelist=["valid_token"],
+ blacklist=["blocked_token"],
+ offline=False,
+ )
+ mock_serve_cm = AsyncMock()
+ mock_serve_cm.__aenter__ = AsyncMock(return_value=mock_server)
+ mock_serve_cm.__aexit__ = AsyncMock(return_value=False)
+
+ patches = [
+ patch(
+ "server.argparse.ArgumentParser.parse_known_args",
+ return_value=(MagicMock(command=args.command), []),
+ ),
+ patch("server.argparse.ArgumentParser.parse_args", return_value=args),
+ patch("server.websockets.serve", return_value=mock_serve_cm),
+ patch("server.asyncio.get_running_loop", return_value=mock_loop),
+ patch("server.asyncio.Event", return_value=mock_event),
+ *(extra_patches or []),
+ ]
+ with contextlib.ExitStack() as stack:
+ for p in patches:
+ stack.enter_context(p)
+ yield
+
+
class TestCheckPositiveNumber:
"""Tests for the check_positive_number argument validator."""
MockTimer.assert_not_called()
-# --- Helpers for main() shutdown tests ---
-
-DEFAULT_MAIN_ARGS = {
- "command": None,
- "delay": None,
- "period": None,
- "host": "127.0.0.1",
- "port": 9000,
- "boot_status": RegistrationStatusEnumType.accepted,
- "total_cost": 10.0,
- "auth_mode": "normal",
- "whitelist": ["valid_token"],
- "blacklist": ["blocked_token"],
- "offline": False,
-}
-
-
-def _mock_args(**overrides):
- return argparse.Namespace(**{**DEFAULT_MAIN_ARGS, **overrides})
-
-
-@pytest.fixture
-def main_mocks():
- """Provide mock loop, server, and shutdown event for main() tests."""
- mock_loop = MagicMock()
- signal_handlers: dict[int, tuple] = {}
-
- def _capture_handler(sig, callback, *args):
- signal_handlers[sig] = (callback, args)
-
- mock_loop.add_signal_handler = MagicMock(side_effect=_capture_handler)
-
- mock_server = AsyncMock()
- mock_server.close = MagicMock()
- mock_server.wait_closed = AsyncMock()
-
- mock_event = MagicMock()
- mock_event.set = MagicMock()
-
- return mock_loop, mock_server, mock_event, signal_handlers
-
-
-def _patch_main(mock_loop, mock_server, mock_event, args=None):
- if args is None:
- args = _mock_args()
-
- mock_serve_cm = AsyncMock()
- mock_serve_cm.__aenter__ = AsyncMock(return_value=mock_server)
- mock_serve_cm.__aexit__ = AsyncMock(return_value=False)
-
- return contextlib.ExitStack(), [
- patch(
- "server.argparse.ArgumentParser.parse_known_args",
- return_value=(MagicMock(command=args.command), []),
- ),
- patch("server.argparse.ArgumentParser.parse_args", return_value=args),
- patch("server.websockets.serve", return_value=mock_serve_cm),
- patch("server.asyncio.get_running_loop", return_value=mock_loop),
- patch("server.asyncio.Event", return_value=mock_event),
- ]
-
-
-def _enter_patches(stack, patches):
- for p in patches:
- stack.enter_context(p)
- return stack
-
-
class TestMainGracefulShutdown:
"""Tests for the main() graceful shutdown logic."""
- async def test_first_signal_closes_server_and_sets_event(self, main_mocks):
+ @pytest.mark.parametrize("sig", [signal.SIGINT, signal.SIGTERM])
+ async def test_first_signal_closes_server_and_sets_event(self, main_mocks, sig):
mock_loop, mock_server, mock_event, signal_handlers = main_mocks
- async def _fire_sigint():
- handler, args = signal_handlers[signal.SIGINT]
+ async def _fire_signal():
+ handler, args = signal_handlers[sig]
handler(*args)
- mock_event.wait = AsyncMock(side_effect=_fire_sigint)
+ mock_event.wait = AsyncMock(side_effect=_fire_signal)
- stack, patches = _patch_main(mock_loop, mock_server, mock_event)
- with _enter_patches(stack, patches):
+ with _patch_main(mock_loop, mock_server, mock_event):
await main()
mock_server.close.assert_called_once()
mock_event.set.assert_called_once()
mock_server.wait_closed.assert_called_once()
- async def test_sigterm_triggers_same_shutdown(self, main_mocks):
- mock_loop, mock_server, mock_event, signal_handlers = main_mocks
-
- async def _fire_sigterm():
- handler, args = signal_handlers[signal.SIGTERM]
- handler(*args)
-
- mock_event.wait = AsyncMock(side_effect=_fire_sigterm)
-
- stack, patches = _patch_main(mock_loop, mock_server, mock_event)
- with _enter_patches(stack, patches):
- await main()
-
- mock_server.close.assert_called_once()
- mock_event.set.assert_called_once()
-
async def test_second_signal_forces_exit(self, main_mocks):
mock_loop, mock_server, mock_event, signal_handlers = main_mocks
mock_event.wait = AsyncMock(side_effect=_fire_twice)
- stack, patches = _patch_main(mock_loop, mock_server, mock_event)
- with _enter_patches(stack, patches):
+ with _patch_main(mock_loop, mock_server, mock_event):
with pytest.raises(SystemExit) as exc_info:
await main()
assert exc_info.value.code == 128 + signal.SIGINT.value
mock_event.wait = AsyncMock(side_effect=_fire_sigint)
caplog.set_level(logging.WARNING)
- stack, patches = _patch_main(mock_loop, mock_server, mock_event)
- with _enter_patches(stack, patches):
+ with _patch_main(mock_loop, mock_server, mock_event):
await main()
assert "timed out" in caplog.text.lower()
mock_event.wait = AsyncMock()
mock_signal_fn = MagicMock()
- stack, patches = _patch_main(mock_loop, mock_server, mock_event)
- patches.append(patch("server.signal.signal", mock_signal_fn))
- with _enter_patches(stack, patches):
+
+ with _patch_main(
+ mock_loop,
+ mock_server,
+ mock_event,
+ extra_patches=[patch("server.signal.signal", mock_signal_fn)],
+ ):
await main()
assert mock_signal_fn.call_count == 2
def _capture_signal(sig, handler):
captured_handlers[sig] = handler
- stack, patches = _patch_main(mock_loop, mock_server, mock_event)
- patches.append(patch("server.signal.signal", side_effect=_capture_signal))
- with _enter_patches(stack, patches):
+ with _patch_main(
+ mock_loop,
+ mock_server,
+ mock_event,
+ extra_patches=[patch("server.signal.signal", side_effect=_capture_signal)],
+ ):
await main()
sigint_handler = captured_handlers[signal.SIGINT]