import asyncio
import logging
import math
+import signal
+import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import StrEnum
DEFAULT_HEARTBEAT_INTERVAL = 60
DEFAULT_TOTAL_COST = 10.0
MAX_REQUEST_ID = 2**31 - 1
+SHUTDOWN_TIMEOUT = 30.0
SUBPROTOCOLS: list[websockets.Subprotocol] = [
websockets.Subprotocol("ocpp2.0"),
websockets.Subprotocol("ocpp2.0.1"),
auth_config.offline,
)
- server = await websockets.serve(
+ loop = asyncio.get_running_loop()
+ shutdown_count = 0
+ shutdown_event = asyncio.Event()
+
+ async with websockets.serve(
partial(
on_connect,
config=config,
args.host,
args.port,
subprotocols=SUBPROTOCOLS,
- )
- logger.info("WebSocket Server Started on %s:%d", args.host, args.port)
+ ) as server:
+ logger.info("WebSocket Server Started on %s:%d", args.host, args.port)
+
+ def _on_signal(sig: signal.Signals) -> None:
+ nonlocal shutdown_count
+ shutdown_count += 1
+ if shutdown_count == 1:
+ logger.info("Received %s, initiating graceful shutdown...", sig.name)
+ server.close()
+ shutdown_event.set()
+ else:
+ logger.warning("Received %s again, forcing exit", sig.name)
+ sys.exit(128 + sig.value)
+
+ for sig in (signal.SIGINT, signal.SIGTERM):
+ try:
+ loop.add_signal_handler(sig, _on_signal, sig)
+ except NotImplementedError:
+ # Windows: ProactorEventLoop doesn't support add_signal_handler.
+ # signal.signal() fires outside the event loop, so schedule
+ # _on_signal into the loop via call_soon_threadsafe.
+ def _signal_handler(
+ _signum: int,
+ _frame: object,
+ s: signal.Signals = sig,
+ ) -> None:
+ loop.call_soon_threadsafe(_on_signal, s)
+
+ signal.signal(sig, _signal_handler)
+
+ await shutdown_event.wait()
+
+ try:
+ async with asyncio.timeout(SHUTDOWN_TIMEOUT):
+ await server.wait_closed()
+ except TimeoutError:
+ logger.warning(
+ "Shutdown timed out after %.0fs"
+ " — connections may not have closed cleanly",
+ SHUTDOWN_TIMEOUT,
+ )
- await server.wait_closed()
+ logger.info("Server shutdown complete")
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
- asyncio.run(main())
+ try:
+ asyncio.run(main())
+ except KeyboardInterrupt:
+ pass
+ sys.exit(0)