]> Piment Noir Git Repositories - e-mobility-charging-stations-simulator.git/commitdiff
refactor(ocpp-server): introduce AuthMode, AuthConfig and ServerConfig typed dataclasses
authorJérôme Benoit <jerome.benoit@sap.com>
Fri, 13 Mar 2026 20:01:00 +0000 (21:01 +0100)
committerJérôme Benoit <jerome.benoit@sap.com>
Fri, 13 Mar 2026 20:01:00 +0000 (21:01 +0100)
tests/ocpp-server/server.py

index 95f15cea982adcd400763e5a3aa04e98277fc351..b6e435401301859b6f9d4ad7e37bb8018b1dcccd 100644 (file)
@@ -3,7 +3,9 @@
 import argparse
 import asyncio
 import logging
+from dataclasses import dataclass
 from datetime import datetime, timezone
+from enum import StrEnum
 from functools import partial
 from random import randint
 
@@ -56,6 +58,45 @@ DEFAULT_HEARTBEAT_INTERVAL = 60
 DEFAULT_TOTAL_COST = 10.0
 SUBPROTOCOLS = ["ocpp2.0", "ocpp2.0.1"]
 
+
+class AuthMode(StrEnum):
+    """Authorization modes for testing different authentication scenarios."""
+
+    normal = "normal"
+    whitelist = "whitelist"
+    blacklist = "blacklist"
+    rate_limit = "rate_limit"
+    offline = "offline"
+
+
+@dataclass(frozen=True)
+class AuthConfig:
+    """Authorization configuration for a charge point."""
+
+    mode: AuthMode
+    whitelist: tuple[str, ...]
+    blacklist: tuple[str, ...]
+    offline: bool
+    default_status: AuthorizationStatusEnumType
+
+    def __getitem__(self, key: str):
+        """Support dict-style access for backwards compatibility."""
+        value = getattr(self, key)
+        return list(value) if isinstance(value, tuple) else value
+
+
+@dataclass(frozen=True)
+class ServerConfig:
+    """Server-level configuration passed to each connection handler."""
+
+    command_name: Action | None
+    delay: float | None
+    period: float | None
+    auth_config: AuthConfig
+    boot_status: RegistrationStatusEnumType
+    total_cost: float
+
+
 ChargePoints: set["ChargePoint"] = set()
 
 
@@ -63,14 +104,14 @@ class ChargePoint(ocpp.v201.ChargePoint):
     """OCPP 2.0.1 charge point handler with configurable behavior for testing."""
 
     _command_timer: Timer | None
-    _auth_config: dict
+    _auth_config: AuthConfig
     _boot_status: RegistrationStatusEnumType
     _total_cost: float
 
     def __init__(
         self,
         connection,
-        auth_config: dict | None = None,
+        auth_config: AuthConfig | None = None,
         boot_status: RegistrationStatusEnumType = RegistrationStatusEnumType.accepted,
         total_cost: float = DEFAULT_TOTAL_COST,
     ):
@@ -78,34 +119,46 @@ class ChargePoint(ocpp.v201.ChargePoint):
         self._command_timer = None
         self._boot_status = boot_status
         self._total_cost = total_cost
-        self._auth_config = auth_config or {
-            "mode": "normal",
-            "whitelist": ["valid_token", "test_token", "authorized_user"],
-            "blacklist": ["blocked_token", "invalid_user"],
-            "offline": False,
-            "default_status": AuthorizationStatusEnumType.accepted,
-        }
+        if auth_config is None:
+            self._auth_config = AuthConfig(
+                mode=AuthMode.normal,
+                whitelist=("valid_token", "test_token", "authorized_user"),
+                blacklist=("blocked_token", "invalid_user"),
+                offline=False,
+                default_status=AuthorizationStatusEnumType.accepted,
+            )
+        elif isinstance(auth_config, dict):
+            self._auth_config = AuthConfig(
+                mode=AuthMode(auth_config.get("mode", "normal")),
+                whitelist=tuple(auth_config.get("whitelist", ())),
+                blacklist=tuple(auth_config.get("blacklist", ())),
+                offline=auth_config.get("offline", False),
+                default_status=auth_config.get(
+                    "default_status", AuthorizationStatusEnumType.accepted
+                ),
+            )
+        else:
+            self._auth_config = auth_config
 
     def _resolve_auth_status(self, token_id: str) -> AuthorizationStatusEnumType:
         """Resolve authorization status based on auth mode and token."""
-        mode = self._auth_config.get("mode", "normal")
-        if mode == "whitelist":
-            return (
-                AuthorizationStatusEnumType.accepted
-                if token_id in self._auth_config.get("whitelist", [])
-                else AuthorizationStatusEnumType.blocked
-            )
-        if mode == "blacklist":
-            return (
-                AuthorizationStatusEnumType.blocked
-                if token_id in self._auth_config.get("blacklist", [])
-                else AuthorizationStatusEnumType.accepted
-            )
-        if mode == "rate_limit":
-            return AuthorizationStatusEnumType.not_at_this_time
-        return self._auth_config.get(
-            "default_status", AuthorizationStatusEnumType.accepted
-        )
+        match self._auth_config.mode:
+            case AuthMode.whitelist:
+                return (
+                    AuthorizationStatusEnumType.accepted
+                    if token_id in self._auth_config.whitelist
+                    else AuthorizationStatusEnumType.blocked
+                )
+            case AuthMode.blacklist:
+                return (
+                    AuthorizationStatusEnumType.blocked
+                    if token_id in self._auth_config.blacklist
+                    else AuthorizationStatusEnumType.accepted
+                )
+            case AuthMode.rate_limit:
+                return AuthorizationStatusEnumType.not_at_this_time
+            case _:
+                return self._auth_config.default_status
 
     # --- Incoming message handlers (CS → CSMS) ---
 
@@ -138,7 +191,7 @@ class ChargePoint(ocpp.v201.ChargePoint):
             "Received %s for token: %s", Action.authorize, id_token.get("id_token")
         )
 
-        if self._auth_config.get("offline", False):
+        if self._auth_config.offline:
             logging.warning("Offline mode - simulating network failure")
             raise InternalError(description="Simulated network failure")
 
@@ -588,12 +641,7 @@ class ChargePoint(ocpp.v201.ChargePoint):
 
 async def on_connect(
     websocket,
-    command_name: Action | None,
-    delay: float | None,
-    period: float | None,
-    auth_config: dict | None,
-    boot_status: RegistrationStatusEnumType,
-    total_cost: float,
+    config: ServerConfig,
 ):
     """Handle new WebSocket connections from charge points."""
     try:
@@ -615,12 +663,12 @@ async def on_connect(
 
     cp = ChargePoint(
         websocket,
-        auth_config=auth_config,
-        boot_status=boot_status,
-        total_cost=total_cost,
+        auth_config=config.auth_config,
+        boot_status=config.boot_status,
+        total_cost=config.total_cost,
     )
-    if command_name:
-        await cp.send_command(command_name, delay, period)
+    if config.command_name:
+        await cp.send_command(config.command_name, config.delay, config.period)
 
     ChargePoints.add(cp)
 
@@ -718,27 +766,33 @@ async def main():
 
     args = parser.parse_args()
 
-    auth_config = {
-        "mode": args.auth_mode,
-        "whitelist": args.whitelist,
-        "blacklist": args.blacklist,
-        "offline": args.offline,
-        "default_status": AuthorizationStatusEnumType.accepted,
-    }
+    auth_config = AuthConfig(
+        mode=AuthMode(args.auth_mode),
+        whitelist=tuple(args.whitelist),
+        blacklist=tuple(args.blacklist),
+        offline=args.offline,
+        default_status=AuthorizationStatusEnumType.accepted,
+    )
+
+    config = ServerConfig(
+        command_name=args.command,
+        delay=args.delay,
+        period=args.period,
+        auth_config=auth_config,
+        boot_status=args.boot_status,
+        total_cost=args.total_cost,
+    )
 
     logging.info(
-        "Auth configuration: mode=%s, offline=%s", args.auth_mode, args.offline
+        "Auth configuration: mode=%s, offline=%s",
+        auth_config.mode,
+        auth_config.offline,
     )
 
     server = await websockets.serve(
         partial(
             on_connect,
-            command_name=args.command,
-            delay=args.delay,
-            period=args.period,
-            auth_config=auth_config,
-            boot_status=args.boot_status,
-            total_cost=args.total_cost,
+            config=config,
         ),
         args.host,
         args.port,