import argparse
import asyncio
import logging
+from dataclasses import dataclass
from datetime import datetime, timezone
+from enum import StrEnum
from functools import partial
from random import randint
DEFAULT_TOTAL_COST = 10.0
SUBPROTOCOLS = ["ocpp2.0", "ocpp2.0.1"]
+
+class AuthMode(StrEnum):
+ """Authorization modes for testing different authentication scenarios."""
+
+ normal = "normal"
+ whitelist = "whitelist"
+ blacklist = "blacklist"
+ rate_limit = "rate_limit"
+ offline = "offline"
+
+
+@dataclass(frozen=True)
+class AuthConfig:
+ """Authorization configuration for a charge point."""
+
+ mode: AuthMode
+ whitelist: tuple[str, ...]
+ blacklist: tuple[str, ...]
+ offline: bool
+ default_status: AuthorizationStatusEnumType
+
+ def __getitem__(self, key: str):
+ """Support dict-style access for backwards compatibility."""
+ value = getattr(self, key)
+ return list(value) if isinstance(value, tuple) else value
+
+
+@dataclass(frozen=True)
+class ServerConfig:
+ """Server-level configuration passed to each connection handler."""
+
+ command_name: Action | None
+ delay: float | None
+ period: float | None
+ auth_config: AuthConfig
+ boot_status: RegistrationStatusEnumType
+ total_cost: float
+
+
ChargePoints: set["ChargePoint"] = set()
"""OCPP 2.0.1 charge point handler with configurable behavior for testing."""
_command_timer: Timer | None
- _auth_config: dict
+ _auth_config: AuthConfig
_boot_status: RegistrationStatusEnumType
_total_cost: float
def __init__(
self,
connection,
- auth_config: dict | None = None,
+ auth_config: AuthConfig | None = None,
boot_status: RegistrationStatusEnumType = RegistrationStatusEnumType.accepted,
total_cost: float = DEFAULT_TOTAL_COST,
):
self._command_timer = None
self._boot_status = boot_status
self._total_cost = total_cost
- self._auth_config = auth_config or {
- "mode": "normal",
- "whitelist": ["valid_token", "test_token", "authorized_user"],
- "blacklist": ["blocked_token", "invalid_user"],
- "offline": False,
- "default_status": AuthorizationStatusEnumType.accepted,
- }
+ if auth_config is None:
+ self._auth_config = AuthConfig(
+ mode=AuthMode.normal,
+ whitelist=("valid_token", "test_token", "authorized_user"),
+ blacklist=("blocked_token", "invalid_user"),
+ offline=False,
+ default_status=AuthorizationStatusEnumType.accepted,
+ )
+ elif isinstance(auth_config, dict):
+ self._auth_config = AuthConfig(
+ mode=AuthMode(auth_config.get("mode", "normal")),
+ whitelist=tuple(auth_config.get("whitelist", ())),
+ blacklist=tuple(auth_config.get("blacklist", ())),
+ offline=auth_config.get("offline", False),
+ default_status=auth_config.get(
+ "default_status", AuthorizationStatusEnumType.accepted
+ ),
+ )
+ else:
+ self._auth_config = auth_config
def _resolve_auth_status(self, token_id: str) -> AuthorizationStatusEnumType:
"""Resolve authorization status based on auth mode and token."""
- mode = self._auth_config.get("mode", "normal")
- if mode == "whitelist":
- return (
- AuthorizationStatusEnumType.accepted
- if token_id in self._auth_config.get("whitelist", [])
- else AuthorizationStatusEnumType.blocked
- )
- if mode == "blacklist":
- return (
- AuthorizationStatusEnumType.blocked
- if token_id in self._auth_config.get("blacklist", [])
- else AuthorizationStatusEnumType.accepted
- )
- if mode == "rate_limit":
- return AuthorizationStatusEnumType.not_at_this_time
- return self._auth_config.get(
- "default_status", AuthorizationStatusEnumType.accepted
- )
+ match self._auth_config.mode:
+ case AuthMode.whitelist:
+ return (
+ AuthorizationStatusEnumType.accepted
+ if token_id in self._auth_config.whitelist
+ else AuthorizationStatusEnumType.blocked
+ )
+ case AuthMode.blacklist:
+ return (
+ AuthorizationStatusEnumType.blocked
+ if token_id in self._auth_config.blacklist
+ else AuthorizationStatusEnumType.accepted
+ )
+ case AuthMode.rate_limit:
+ return AuthorizationStatusEnumType.not_at_this_time
+ case _:
+ return self._auth_config.default_status
# --- Incoming message handlers (CS → CSMS) ---
"Received %s for token: %s", Action.authorize, id_token.get("id_token")
)
- if self._auth_config.get("offline", False):
+ if self._auth_config.offline:
logging.warning("Offline mode - simulating network failure")
raise InternalError(description="Simulated network failure")
async def on_connect(
websocket,
- command_name: Action | None,
- delay: float | None,
- period: float | None,
- auth_config: dict | None,
- boot_status: RegistrationStatusEnumType,
- total_cost: float,
+ config: ServerConfig,
):
"""Handle new WebSocket connections from charge points."""
try:
cp = ChargePoint(
websocket,
- auth_config=auth_config,
- boot_status=boot_status,
- total_cost=total_cost,
+ auth_config=config.auth_config,
+ boot_status=config.boot_status,
+ total_cost=config.total_cost,
)
- if command_name:
- await cp.send_command(command_name, delay, period)
+ if config.command_name:
+ await cp.send_command(config.command_name, config.delay, config.period)
ChargePoints.add(cp)
args = parser.parse_args()
- auth_config = {
- "mode": args.auth_mode,
- "whitelist": args.whitelist,
- "blacklist": args.blacklist,
- "offline": args.offline,
- "default_status": AuthorizationStatusEnumType.accepted,
- }
+ auth_config = AuthConfig(
+ mode=AuthMode(args.auth_mode),
+ whitelist=tuple(args.whitelist),
+ blacklist=tuple(args.blacklist),
+ offline=args.offline,
+ default_status=AuthorizationStatusEnumType.accepted,
+ )
+
+ config = ServerConfig(
+ command_name=args.command,
+ delay=args.delay,
+ period=args.period,
+ auth_config=auth_config,
+ boot_status=args.boot_status,
+ total_cost=args.total_cost,
+ )
logging.info(
- "Auth configuration: mode=%s, offline=%s", args.auth_mode, args.offline
+ "Auth configuration: mode=%s, offline=%s",
+ auth_config.mode,
+ auth_config.offline,
)
server = await websockets.serve(
partial(
on_connect,
- command_name=args.command,
- delay=args.delay,
- period=args.period,
- auth_config=auth_config,
- boot_status=args.boot_status,
- total_cost=args.total_cost,
+ config=config,
),
args.host,
args.port,