diff mbox series

[07/20] python/aqmp: add runstate state machine to AsyncProtocol

Message ID 20210701041313.1696009-8-jsnow@redhat.com (mailing list archive)
State New, archived
Headers show
Series python: introduce Asynchronous QMP package | expand

Commit Message

John Snow July 1, 2021, 4:13 a.m. UTC
This serves a few purposes:

1. Protect interfaces when it's not safe to call them (via @require)

2. Add an interface by which an async client can determine if the state
has changed, for the purposes of connection management.

Signed-off-by: John Snow <jsnow@redhat.com>
---
 python/qemu/aqmp/__init__.py |   5 +-
 python/qemu/aqmp/protocol.py | 133 +++++++++++++++++++++++++++++++++--
 2 files changed, 133 insertions(+), 5 deletions(-)
diff mbox series

Patch

diff --git a/python/qemu/aqmp/__init__.py b/python/qemu/aqmp/__init__.py
index e003c898bd..5c44fabeea 100644
--- a/python/qemu/aqmp/__init__.py
+++ b/python/qemu/aqmp/__init__.py
@@ -22,11 +22,14 @@ 
 # the COPYING file in the top-level directory.
 
 from .error import AQMPError, MultiException
-from .protocol import ConnectError
+from .protocol import ConnectError, Runstate
 
 
 # The order of these fields impact the Sphinx documentation order.
 __all__ = (
+    # Classes
+    'Runstate',
+
     # Exceptions, most generic to most explicit
     'AQMPError',
     'ConnectError',
diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py
index beb7e12d9c..a99a191982 100644
--- a/python/qemu/aqmp/protocol.py
+++ b/python/qemu/aqmp/protocol.py
@@ -12,11 +12,10 @@ 
 
 import asyncio
 from asyncio import StreamReader, StreamWriter
+from enum import Enum
+from functools import wraps
 from ssl import SSLContext
-# import exceptions will be removed in a forthcoming commit.
-# The problem stems from pylint/flake8 believing that 'Any'
-# is unused because of its only use in a string-quoted type.
-from typing import (  # pylint: disable=unused-import # noqa
+from typing import (
     Any,
     Awaitable,
     Callable,
@@ -26,6 +25,7 @@ 
     Tuple,
     TypeVar,
     Union,
+    cast,
 )
 
 from .error import AQMPError, MultiException
@@ -45,6 +45,20 @@ 
 _FutureT = TypeVar('_FutureT', bound=Optional['asyncio.Future[Any]'])
 
 
+class Runstate(Enum):
+    """Protocol session runstate."""
+
+    #: Fully quiesced and disconnected.
+    IDLE = 0
+    #: In the process of connecting or establishing a session.
+    CONNECTING = 1
+    #: Fully connected and active session.
+    RUNNING = 2
+    #: In the process of disconnecting.
+    #: Runstate may be returned to `IDLE` by calling `disconnect()`.
+    DISCONNECTING = 3
+
+
 class ConnectError(AQMPError):
     """
     Raised when the initial connection process has failed.
@@ -66,6 +80,75 @@  def __str__(self) -> str:
         return f"{self.error_message}: {self.exc!s}"
 
 
+class StateError(AQMPError):
+    """
+    An API command (connect, execute, etc) was issued at an inappropriate time.
+
+    This error is raised when a command like
+    :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
+    time.
+
+    :param error_message: Human-readable string describing the state violation.
+    :param state: The actual `Runstate` seen at the time of the violation.
+    :param required: The `Runstate` required to process this command.
+
+    """
+    def __init__(self, error_message: str,
+                 state: Runstate, required: Runstate):
+        super().__init__(error_message)
+        self.error_message = error_message
+        self.state = state
+        self.required = required
+
+
+F = TypeVar('F', bound=Callable[..., Any])  # pylint: disable=invalid-name
+
+
+# Don't Panic.
+def require(required_state: Runstate) -> Callable[[F], F]:
+    """
+    Decorator: protect a method so it can only be run in a certain `Runstate`.
+
+    :param required_state: The `Runstate` required to invoke this method.
+    :raise StateError: When the required `Runstate` is not met.
+    """
+    def _decorator(func: F) -> F:
+        # _decorator is the decorator that is built by calling the
+        # require() decorator factory; e.g.:
+        #
+        # @require(Runstate.IDLE) def # foo(): ...
+        # will replace 'foo' with the result of '_decorator(foo)'.
+
+        @wraps(func)
+        def _wrapper(proto: 'AsyncProtocol[Any]',
+                     *args: Any, **kwargs: Any) -> Any:
+            # _wrapper is the function that gets executed prior to the
+            # decorated method.
+
+            if proto.runstate != required_state:
+                if proto.runstate == Runstate.CONNECTING:
+                    emsg = "Client is currently connecting."
+                elif proto.runstate == Runstate.DISCONNECTING:
+                    emsg = ("Client is disconnecting."
+                            " Call disconnect() to return to IDLE state.")
+                elif proto.runstate == Runstate.RUNNING:
+                    emsg = "Client is already connected and running."
+                elif proto.runstate == Runstate.IDLE:
+                    emsg = "Client is disconnected and idle."
+                else:
+                    assert False
+                raise StateError(emsg, proto.runstate, required_state)
+            # No StateError, so call the wrapped method.
+            return func(proto, *args, **kwargs)
+
+        # Return the decorated method;
+        # Transforming Func to Decorated[Func].
+        return cast(F, _wrapper)
+
+    # Return the decorator instance from the decorator factory. Phew!
+    return _decorator
+
+
 class AsyncProtocol(Generic[T]):
     """
     AsyncProtocol implements a generic async message-based protocol.
@@ -124,7 +207,18 @@  def __init__(self) -> None:
         #: exit.
         self._dc_task: Optional[asyncio.Future[None]] = None
 
+        self._runstate = Runstate.IDLE
+
+        #: An `asyncio.Event` that signals when `runstate` is changed.
+        self.runstate_changed: asyncio.Event = asyncio.Event()
+
+    @property
+    def runstate(self) -> Runstate:
+        """The current `Runstate` of the connection."""
+        return self._runstate
+
     @upper_half
+    @require(Runstate.IDLE)
     async def connect(self, address: Union[str, Tuple[str, int]],
                       ssl: Optional[SSLContext] = None) -> None:
         """
@@ -165,6 +259,21 @@  async def disconnect(self, force: bool = False) -> None:
     # Section: Session machinery
     # --------------------------
 
+    @upper_half
+    @bottom_half
+    def _set_state(self, state: Runstate) -> None:
+        """
+        Change the `Runstate` of the protocol connection.
+
+        Signals the `runstate_changed` event.
+        """
+        if state == self._runstate:
+            return
+
+        self._runstate = state
+        self.runstate_changed.set()
+        self.runstate_changed.clear()
+
     @upper_half
     async def _new_session(self,
                            address: Union[str, Tuple[str, int]],
@@ -189,6 +298,9 @@  async def _new_session(self,
             protocol-level failure occurs while establishing a new
             session, the wrapped error may also be an `AQMPError`.
         """
+        assert self.runstate == Runstate.IDLE
+        self._set_state(Runstate.CONNECTING)
+
         self._outgoing = asyncio.Queue()
 
         phase = "connection"
@@ -204,6 +316,8 @@  async def _new_session(self,
             emsg = f"Failed to establish {phase}"
             raise ConnectError(emsg, err) from err
 
+        assert self.runstate == Runstate.RUNNING
+
     @upper_half
     async def _do_connect(self, address: Union[str, Tuple[str, int]],
                           ssl: Optional[SSLContext] = None) -> None:
@@ -227,6 +341,8 @@  async def _begin_new_session(self) -> None:
         """
         After a connection is established, start the bottom half machinery.
         """
+        assert self.runstate == Runstate.CONNECTING
+
         reader_coro = self._bh_loop_forever(self._bh_recv_message)
         writer_coro = self._bh_loop_forever(self._bh_send_message)
 
@@ -239,6 +355,8 @@  async def _begin_new_session(self) -> None:
             return_exceptions=True,
         )
 
+        self._set_state(Runstate.RUNNING)
+
     @upper_half
     @bottom_half
     def _schedule_disconnect(self, force: bool = False) -> None:
@@ -276,6 +394,7 @@  def _results(self) -> None:
             Iterable Exception used to multiplex multiple exceptions in the
             event that multiple Tasks failed with non-cancellation reasons.
         """
+        assert self.runstate == Runstate.DISCONNECTING
         exceptions: List[BaseException] = []
 
         assert self._bh_tasks is None or self._bh_tasks.done()
@@ -340,6 +459,7 @@  def _paranoid_task_erase(task: _FutureT) -> Optional[_FutureT]:
             assert (task is None) or task.done()
             return None if (task and task.done()) else task
 
+        assert self.runstate == Runstate.DISCONNECTING
         self._dc_task = _paranoid_task_erase(self._dc_task)
         self._reader_task = _paranoid_task_erase(self._reader_task)
         self._writer_task = _paranoid_task_erase(self._writer_task)
@@ -348,6 +468,8 @@  def _paranoid_task_erase(task: _FutureT) -> Optional[_FutureT]:
         self._reader = None
         self._writer = None
 
+        self._set_state(Runstate.IDLE)
+
     # ----------------------------
     # Section: Bottom Half methods
     # ----------------------------
@@ -367,6 +489,9 @@  async def _bh_disconnect(self, force: bool = False) -> None:
             terminating execution. When `True`, terminate immediately.
 
         """
+        # Prohibit new calls to execute() et al.
+        self._set_state(Runstate.DISCONNECTING)
+
         await self._bh_stop_writer(force)
         await self._bh_stop_reader()