fix(ocpp-server): add asyncio compatible timer implementation and use it
authorJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 27 Jun 2024 19:06:34 +0000 (21:06 +0200)
committerJérôme Benoit <jerome.benoit@piment-noir.org>
Thu, 27 Jun 2024 19:06:34 +0000 (21:06 +0200)
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
tests/ocpp-server/server.py
tests/ocpp-server/timer.py [new file with mode: 0644]

index 6fd83a1199749d2b330bf53c881124eaf54758bc..14c638f5b5e6e321110a36f941f1fa88c41fc275 100644 (file)
@@ -2,7 +2,8 @@ import argparse
 import asyncio
 import logging
 from datetime import datetime, timezone
-from threading import Timer
+from functools import partial
+from typing import Optional
 
 import ocpp.v201
 import websockets
@@ -18,24 +19,21 @@ from ocpp.v201.enums import (
 )
 from websockets import ConnectionClosed
 
+from timer import Timer
+
 # Setting up the logging configuration to display debug level messages.
 logging.basicConfig(level=logging.DEBUG)
 
 ChargePoints = set()
 
 
-class RepeatTimer(Timer):
-    """Class that inherits from the Timer class. It will run a
-    function at regular intervals."""
-
-    def run(self):
-        while not self.finished.wait(self.interval):
-            self.function(*self.args, **self.kwargs)
-
-
 # Define a ChargePoint class inheriting from the OCPP 2.0.1 ChargePoint class.
 class ChargePoint(ocpp.v201.ChargePoint):
-    _command_timer: RepeatTimer
+    _command_timer: Optional[Timer]
+
+    def __init__(self, connection):
+        super().__init__(connection.path.strip("/"), connection)
+        self._command_timer = None
 
     # Message handlers to receive OCPP messages.
     @on(Action.BootNotification)
@@ -98,15 +96,6 @@ class ChargePoint(ocpp.v201.ChargePoint):
         logging.info("Received %s", Action.MeterValues)
         return ocpp.v201.call_result.MeterValues()
 
-    @on(Action.GetBaseReport)
-    async def on_get_base_report(
-        self, request_id: int, report_base: ReportBaseType, **kwargs
-    ):
-        logging.info("Received %s", Action.GetBaseReport)
-        return ocpp.v201.call_result.GetBaseReport(
-            status=GenericDeviceModelStatusType.accepted
-        )
-
     # Request handlers to emit OCPP messages.
     async def _send_clear_cache(self):
         request = ocpp.v201.call.ClearCache()
@@ -138,20 +127,30 @@ class ChargePoint(ocpp.v201.ChargePoint):
             case _:
                 logging.info(f"Not supported command {command_name}")
 
-    async def send_command(self, command_name: Action, delay=None, period=None):
+    async def send_command(
+        self, command_name: Action, delay: Optional[float], period: Optional[float]
+    ):
         if not delay and not period:
-            raise ValueError("Either delay or period must be set")
+            raise ValueError("Either delay or period must be defined")
+        if delay and delay <= 0:
+            raise ValueError("Delay must be a positive number")
+        if period and period <= 0:
+            raise ValueError("Period must be a positive number")
         try:
-            if delay and delay > 0:
-                await asyncio.sleep(delay)
-                await self._send_command(command_name)
-            if period and period > 0 and not self._command_timer:
-                self._command_timer = RepeatTimer(
+            if delay and not self._command_timer:
+                self._command_timer = Timer(
+                    delay,
+                    False,
+                    self._send_command,
+                    [command_name],
+                )
+            if period and not self._command_timer:
+                self._command_timer = Timer(
                     period,
+                    True,
                     self._send_command,
                     [command_name],
                 )
-                self._command_timer.start()
         except ConnectionClosed:
             self.handle_connection_closed()
 
@@ -164,7 +163,12 @@ class ChargePoint(ocpp.v201.ChargePoint):
 
 
 # Function to handle new WebSocket connections.
-async def on_connect(websocket, path):
+async def on_connect(
+    websocket,
+    command_name: Optional[Action],
+    delay: Optional[float],
+    period: Optional[float],
+):
     """For every new charge point that connects, create a ChargePoint instance and start
     listening for messages."""
     try:
@@ -184,10 +188,12 @@ async def on_connect(websocket, path):
         )
         return await websocket.close()
 
-    charge_point_id = path.strip("/")
-    cp = ChargePoint(charge_point_id, websocket)
+    cp = ChargePoint(websocket)
+    if command_name:
+        await cp.send_command(command_name, delay, period)
 
     ChargePoints.add(cp)
+
     try:
         await cp.start()
     except ConnectionClosed:
@@ -196,29 +202,24 @@ async def on_connect(websocket, path):
 
 # Main function to start the WebSocket server.
 async def main():
-    # Define argument parser
-    parser = argparse.ArgumentParser(description="OCPP2 Charge Point Simulator")
-    parser.add_argument("--command", type=str, help="OCPP2 Command Name")
-    parser.add_argument("--delay", type=int, help="Delay in seconds")
-    parser.add_argument("--period", type=int, help="Period in seconds")
+    parser = argparse.ArgumentParser(description="OCPP2 Server")
+    parser.add_argument("-c", "--command", type=Action, help="OCPP2 Command Name")
+    parser.add_argument("-d", "--delay", type=float, help="Delay in seconds")
+    parser.add_argument("-p", "--period", type=float, help="Period in seconds")
+
+    args = parser.parse_args()
 
     # Create the WebSocket server and specify the handler for new connections.
     server = await websockets.serve(
-        on_connect,
+        partial(
+            on_connect, command_name=args.command, delay=args.delay, period=args.period
+        ),
         "127.0.0.1",  # Listen on loopback.
         9000,  # Port number.
         subprotocols=["ocpp2.0", "ocpp2.0.1"],  # Specify OCPP 2.0.1 subprotocols.
     )
     logging.info("WebSocket Server Started")
 
-    args = parser.parse_args()
-
-    if args.command:
-        for cp in ChargePoints:
-            await asyncio.create_task(
-                cp.send_command(cp, args.command, args.delay, args.period)
-            )
-
     # Wait for the server to close (runs indefinitely).
     await server.wait_closed()
 
diff --git a/tests/ocpp-server/timer.py b/tests/ocpp-server/timer.py
new file mode 100644 (file)
index 0000000..b267bef
--- /dev/null
@@ -0,0 +1,67 @@
+"""
+Timer for asyncio.
+"""
+
+import asyncio
+
+
+class Timer:
+    def __init__(
+        self,
+        timeout: float,
+        repeat: bool,
+        callback,
+        callback_args=(),
+        callback_kwargs=None,
+    ):
+        """An asynchronous Timer object.
+
+        Parameters
+        -----------
+        timeout: :class:`float`:
+        The duration for which the timer should last.
+
+        repeat: :class:`bool`:
+        Whether the timer should repeat.
+
+        callback: :class:`Coroutine` or `Method` or `Function`:
+        An `asyncio` coroutine or a regular method that will be called as soon as
+        the timer ends.
+
+        callback_args: Optional[:class:`tuple`]:
+        The args to be passed to the callback.
+
+        callback_kwargs: Optional[:class:`dict`]:
+        The kwargs to be passed to the callback.
+        """
+        self._timeout = timeout
+        self._repeat = repeat
+        self._callback = callback
+        self._task = asyncio.create_task(self._job())
+        self._callback_args = callback_args
+        if callback_kwargs is None:
+            callback_kwargs = {}
+        self._callback_kwargs = callback_kwargs
+
+    async def _job(self):
+        if self._repeat:
+            while self._task.cancelled() is False:
+                await asyncio.sleep(self._timeout)
+                await self._call_callback()
+        else:
+            await asyncio.sleep(self._timeout)
+            await self._call_callback()
+
+    async def _call_callback(self):
+        if asyncio.iscoroutine(self._callback) or asyncio.iscoroutinefunction(
+            self._callback
+        ):
+            await self._callback(*self._callback_args, **self._callback_kwargs)
+        else:
+            self._callback(*self._callback_args, **self._callback_kwargs)
+
+    def cancel(self):
+        """
+        Cancels the timer. The callback will not be called.
+        """
+        self._task.cancel()