@@ -22,11 +22,14 @@
# the COPYING file in the top-level directory.
from .error import AQMPError, MultiException
-from .protocol import ConnectError
+from .protocol import ConnectError, Runstate
# The order of these fields impact the Sphinx documentation order.
__all__ = (
+ # Classes
+ 'Runstate',
+
# Exceptions, most generic to most explicit
'AQMPError',
'ConnectError',
@@ -12,11 +12,10 @@
import asyncio
from asyncio import StreamReader, StreamWriter
+from enum import Enum
+from functools import wraps
from ssl import SSLContext
-# import exceptions will be removed in a forthcoming commit.
-# The problem stems from pylint/flake8 believing that 'Any'
-# is unused because of its only use in a string-quoted type.
-from typing import ( # pylint: disable=unused-import # noqa
+from typing import (
Any,
Awaitable,
Callable,
@@ -26,6 +25,7 @@
Tuple,
TypeVar,
Union,
+ cast,
)
from .error import AQMPError, MultiException
@@ -45,6 +45,20 @@
_FutureT = TypeVar('_FutureT', bound=Optional['asyncio.Future[Any]'])
+class Runstate(Enum):
+ """Protocol session runstate."""
+
+ #: Fully quiesced and disconnected.
+ IDLE = 0
+ #: In the process of connecting or establishing a session.
+ CONNECTING = 1
+ #: Fully connected and active session.
+ RUNNING = 2
+ #: In the process of disconnecting.
+ #: Runstate may be returned to `IDLE` by calling `disconnect()`.
+ DISCONNECTING = 3
+
+
class ConnectError(AQMPError):
"""
Raised when the initial connection process has failed.
@@ -66,6 +80,75 @@ def __str__(self) -> str:
return f"{self.error_message}: {self.exc!s}"
+class StateError(AQMPError):
+ """
+ An API command (connect, execute, etc) was issued at an inappropriate time.
+
+ This error is raised when a command like
+ :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate
+ time.
+
+ :param error_message: Human-readable string describing the state violation.
+ :param state: The actual `Runstate` seen at the time of the violation.
+ :param required: The `Runstate` required to process this command.
+
+ """
+ def __init__(self, error_message: str,
+ state: Runstate, required: Runstate):
+ super().__init__(error_message)
+ self.error_message = error_message
+ self.state = state
+ self.required = required
+
+
+F = TypeVar('F', bound=Callable[..., Any]) # pylint: disable=invalid-name
+
+
+# Don't Panic.
+def require(required_state: Runstate) -> Callable[[F], F]:
+ """
+ Decorator: protect a method so it can only be run in a certain `Runstate`.
+
+ :param required_state: The `Runstate` required to invoke this method.
+ :raise StateError: When the required `Runstate` is not met.
+ """
+ def _decorator(func: F) -> F:
+ # _decorator is the decorator that is built by calling the
+ # require() decorator factory; e.g.:
+ #
+ # @require(Runstate.IDLE) def # foo(): ...
+ # will replace 'foo' with the result of '_decorator(foo)'.
+
+ @wraps(func)
+ def _wrapper(proto: 'AsyncProtocol[Any]',
+ *args: Any, **kwargs: Any) -> Any:
+ # _wrapper is the function that gets executed prior to the
+ # decorated method.
+
+ if proto.runstate != required_state:
+ if proto.runstate == Runstate.CONNECTING:
+ emsg = "Client is currently connecting."
+ elif proto.runstate == Runstate.DISCONNECTING:
+ emsg = ("Client is disconnecting."
+ " Call disconnect() to return to IDLE state.")
+ elif proto.runstate == Runstate.RUNNING:
+ emsg = "Client is already connected and running."
+ elif proto.runstate == Runstate.IDLE:
+ emsg = "Client is disconnected and idle."
+ else:
+ assert False
+ raise StateError(emsg, proto.runstate, required_state)
+ # No StateError, so call the wrapped method.
+ return func(proto, *args, **kwargs)
+
+ # Return the decorated method;
+ # Transforming Func to Decorated[Func].
+ return cast(F, _wrapper)
+
+ # Return the decorator instance from the decorator factory. Phew!
+ return _decorator
+
+
class AsyncProtocol(Generic[T]):
"""
AsyncProtocol implements a generic async message-based protocol.
@@ -124,7 +207,18 @@ def __init__(self) -> None:
#: exit.
self._dc_task: Optional[asyncio.Future[None]] = None
+ self._runstate = Runstate.IDLE
+
+ #: An `asyncio.Event` that signals when `runstate` is changed.
+ self.runstate_changed: asyncio.Event = asyncio.Event()
+
+ @property
+ def runstate(self) -> Runstate:
+ """The current `Runstate` of the connection."""
+ return self._runstate
+
@upper_half
+ @require(Runstate.IDLE)
async def connect(self, address: Union[str, Tuple[str, int]],
ssl: Optional[SSLContext] = None) -> None:
"""
@@ -165,6 +259,21 @@ async def disconnect(self, force: bool = False) -> None:
# Section: Session machinery
# --------------------------
+ @upper_half
+ @bottom_half
+ def _set_state(self, state: Runstate) -> None:
+ """
+ Change the `Runstate` of the protocol connection.
+
+ Signals the `runstate_changed` event.
+ """
+ if state == self._runstate:
+ return
+
+ self._runstate = state
+ self.runstate_changed.set()
+ self.runstate_changed.clear()
+
@upper_half
async def _new_session(self,
address: Union[str, Tuple[str, int]],
@@ -189,6 +298,9 @@ async def _new_session(self,
protocol-level failure occurs while establishing a new
session, the wrapped error may also be an `AQMPError`.
"""
+ assert self.runstate == Runstate.IDLE
+ self._set_state(Runstate.CONNECTING)
+
self._outgoing = asyncio.Queue()
phase = "connection"
@@ -204,6 +316,8 @@ async def _new_session(self,
emsg = f"Failed to establish {phase}"
raise ConnectError(emsg, err) from err
+ assert self.runstate == Runstate.RUNNING
+
@upper_half
async def _do_connect(self, address: Union[str, Tuple[str, int]],
ssl: Optional[SSLContext] = None) -> None:
@@ -227,6 +341,8 @@ async def _begin_new_session(self) -> None:
"""
After a connection is established, start the bottom half machinery.
"""
+ assert self.runstate == Runstate.CONNECTING
+
reader_coro = self._bh_loop_forever(self._bh_recv_message)
writer_coro = self._bh_loop_forever(self._bh_send_message)
@@ -239,6 +355,8 @@ async def _begin_new_session(self) -> None:
return_exceptions=True,
)
+ self._set_state(Runstate.RUNNING)
+
@upper_half
@bottom_half
def _schedule_disconnect(self, force: bool = False) -> None:
@@ -276,6 +394,7 @@ def _results(self) -> None:
Iterable Exception used to multiplex multiple exceptions in the
event that multiple Tasks failed with non-cancellation reasons.
"""
+ assert self.runstate == Runstate.DISCONNECTING
exceptions: List[BaseException] = []
assert self._bh_tasks is None or self._bh_tasks.done()
@@ -340,6 +459,7 @@ def _paranoid_task_erase(task: _FutureT) -> Optional[_FutureT]:
assert (task is None) or task.done()
return None if (task and task.done()) else task
+ assert self.runstate == Runstate.DISCONNECTING
self._dc_task = _paranoid_task_erase(self._dc_task)
self._reader_task = _paranoid_task_erase(self._reader_task)
self._writer_task = _paranoid_task_erase(self._writer_task)
@@ -348,6 +468,8 @@ def _paranoid_task_erase(task: _FutureT) -> Optional[_FutureT]:
self._reader = None
self._writer = None
+ self._set_state(Runstate.IDLE)
+
# ----------------------------
# Section: Bottom Half methods
# ----------------------------
@@ -367,6 +489,9 @@ async def _bh_disconnect(self, force: bool = False) -> None:
terminating execution. When `True`, terminate immediately.
"""
+ # Prohibit new calls to execute() et al.
+ self._set_state(Runstate.DISCONNECTING)
+
await self._bh_stop_writer(force)
await self._bh_stop_reader()
This serves a few purposes: 1. Protect interfaces when it's not safe to call them (via @require) 2. Add an interface by which an async client can determine if the state has changed, for the purposes of connection management. Signed-off-by: John Snow <jsnow@redhat.com> --- python/qemu/aqmp/__init__.py | 5 +- python/qemu/aqmp/protocol.py | 133 +++++++++++++++++++++++++++++++++-- 2 files changed, 133 insertions(+), 5 deletions(-)