|  | import asyncio | 
|  | from contextlib import contextmanager | 
|  | import os | 
|  | import socket | 
|  | from tempfile import TemporaryDirectory | 
|  |  | 
|  | import avocado | 
|  |  | 
|  | from qemu.qmp import ConnectError, Runstate | 
|  | from qemu.qmp.protocol import AsyncProtocol, StateError | 
|  | from qemu.qmp.util import asyncio_run, create_task | 
|  |  | 
|  |  | 
class NullProtocol(AsyncProtocol[None]):
    """
    A mock AsyncProtocol implementation, used purely for testing.

    Setting the ``fake_session`` instance variable enables a code path
    that skips the actual connection logic while still allowing the
    reader/writers to start.

    The message type is None, so nothing would otherwise stop the reader
    from yielding None incessantly. To prevent that, the reader blocks
    on an asyncio.Event named ``trigger_input``; poking that event
    simulates one incoming message.

    For testing symmetry with ``_do_recv``, ``send_msg`` is provided to
    "send" a Null message.

    For testing purposes, ``simulate_disconnect`` is also provided. It
    triggers a bottom half disconnect without injecting any real errors
    into the reader/writer loops; in essence it performs exactly half
    of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # Both attributes are declared before delegating to the parent
        # constructor; trigger_input only receives a value once a
        # session is being established.
        self.fake_session = False
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # A fresh event per session: _do_recv blocks on it.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if not self.fake_session:
            await super()._do_start_server(address, ssl)
            return

        # Fake path: pretend a server was started without touching sockets.
        self._accepted = asyncio.Event()
        self._set_state(Runstate.CONNECTING)
        await asyncio.sleep(0)

    async def _do_accept(self):
        if not self.fake_session:
            await super()._do_accept()
            return

        # Fake path: nothing to wait for, just drop the event.
        self._accepted = None

    async def _do_connect(self, address, ssl=None):
        if not self.fake_session:
            await super()._do_connect(address, ssl)
            return

        # Fake path: only perform the state transition and yield once.
        self._set_state(Runstate.CONNECTING)
        await asyncio.sleep(0)

    async def _do_recv(self) -> None:
        # Block until a test pokes trigger_input, then re-arm it so the
        # next "message" requires another poke.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        """Discard the outgoing Null message; there is nobody to send to."""

    async def send_msg(self) -> None:
        """Queue a single Null message on the outgoing queue."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulate a bottom-half disconnect.

        A disconnection is scheduled, but this method does not wait for
        it to complete. The point is to put the loop into the
        DISCONNECTING state without fully quiescing it back to IDLE.
        AsyncProtocol normally cannot be coaxed into doing this on
        purpose, but it is similar to what happens when an unhandled
        Exception occurs in the reader/writer.

        Under normal circumstances, the library design requires callers
        to await disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()
|  |  | 
|  |  | 
class LineProtocol(AsyncProtocol[str]):
    """
    A test AsyncProtocol whose messages are newline-terminated strings.

    Every received message is also recorded in ``rx_history``.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Decoded messages, in the order they were received.
        self.rx_history = []

    async def _do_recv(self) -> str:
        """Read one line, decode it, record it and return it."""
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        """Write ``msg`` to the writer, encoded and followed by a newline."""
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Queue ``msg`` on the outgoing queue."""
        await self._outgoing.put(msg)
|  |  | 
|  |  | 
def run_as_task(coro, allow_cancellation=False):
    """
    Wrap ``coro`` in a task and return that task.

    When ``allow_cancellation`` is true, cancelling the task ends it
    quietly; otherwise the CancelledError propagates to whoever awaits
    the task.
    """
    async def _runner():
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise
            # Cancellation was anticipated; finish quietly.

    return create_task(_runner())
|  |  | 
|  |  | 
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    Yields the ``(host, port)`` address of the listening socket. Every
    socket created here is closed when the context exits, including
    when setup fails part-way through.
    """
    socks = []

    try:
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Register each socket for cleanup *before* calling anything on
        # it that may raise, so that a failing bind()/listen()/connect()
        # cannot leak the descriptor.
        socks.append(listener)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # It takes *two* un-accepted connections to start jamming the
        # socket.
        # NOTE(review): presumably the kernel admits one more pending
        # connection than the listen() backlog of 1 -- confirm.
        for _ in range(2):
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            socks.append(client)
            client.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
|  |  | 
|  |  | 
class Smoke(avocado.Test):
    """Basic checks on a freshly constructed NullProtocol."""

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        """An unnamed protocol's repr shows only its runstate."""
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        """A new protocol starts out IDLE."""
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        """Without an explicit name, the name is None."""
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        """An unnamed protocol uses the module-level logger name."""
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol')

    def testName(self):
        """A named protocol exposes the name via name, logger and repr."""
        self.proto = NullProtocol('Steve')

        self.assertEqual(self.proto.name, 'Steve')
        self.assertEqual(
            self.proto.logger.name,
            'qemu.qmp.protocol.Steve'
        )
        self.assertEqual(
            repr(self.proto),
            "<NullProtocol name='Steve' runstate=IDLE>"
        )
|  |  | 
|  |  | 
class TestBase(avocado.Test):
    """
    Common scaffolding for the async protocol tests.

    Each test gets a NullProtocol named after the test class, and the
    protocol is asserted to be IDLE both before and after the test.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        """Async setup hook; does nothing by default."""

    async def _asyncTearDown(self):
        """Wait for the runstate watcher task, if one was launched."""
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; surrounds an async test with the async SetUp/TearDown.
        """
        async def _wrapper(self, *args, **kwargs):
            asyncio.get_event_loop().set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # Runstate sequence expected from a "bad" connect/accept attempt
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Runstate sequence expected from a "good" session
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        Launch a task that runs alongside a test and confirms that the
        runstate changes it observes match ``states``, in order.

        The task is stored in ``runstate_watcher``.
        """
        async def _watcher():
            for expected in states:
                observed = await self.proto.runstate_changed()
                self.assertEqual(
                    observed,
                    expected,
                    msg=f"Expected state '{expected.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Yield once so the new task runs up to its first wait.
        await asyncio.sleep(0)
|  |  | 
|  |  | 
class State(TestBase):
    """Tests of state handling that need no connection at all."""

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Call disconnect() on a protocol that is already disconnected.
        """
        expected_states = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected_states)
        await self.proto.disconnect()
|  |  | 
|  |  | 
class Connect(TestBase):
    """
    Tests primarily related to calling connect().

    The connection helpers ``_bad_connection`` and
    ``_hanging_connection`` are the only parts that touch the
    connection API directly, so a subclass may override them to run the
    same tests against a different entry point.
    """
    async def _bad_connection(self, family: str):
        """
        Attempt a connection to a target that is expected to be rejected.

        :param family: Either 'INET' or 'UNIX'.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Attempt a connection to a jammed socket."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """
        Check that a bad connection raises ConnectError wrapping an
        OSError, while passing through BAD_CONNECTION_STATES.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that the connection call cannot be cancelled outright, as
        # it isn't a task. However, we can wrap it in a task and cancel
        # *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the test framework's assertion, like every other check in
        # this module: it reports both values on failure and is not
        # stripped when Python runs with -O.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
|  |  | 
|  |  | 
class Accept(Connect):
    """
    Re-runs every Connect test through the accept() interface instead.
    """
    async def _bad_connection(self, family: str):
        """Try to start a server on an address given by ``family``."""
        assert family in ('INET', 'UNIX')

        if family == 'UNIX':
            await self.proto.start_server_and_accept('/dev/null')
        elif family == 'INET':
            await self.proto.start_server_and_accept(('example.com', 1))

    async def _hanging_connection(self):
        """Serve on a socket in a temporary directory and await a peer."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock_name = f"{type(self.proto).__name__}.sock"
            sock = os.path.join(tmpdir, sock_name)
            await self.proto.start_server_and_accept(sock)
|  |  | 
|  |  | 
class FakeSession(TestBase):
    """
    Tests that run with ``fake_session`` enabled on the protocol.

    Every test is expected to pass through GOOD_CONNECTION_STATES; the
    async teardown performs the final disconnect().
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Run the full state lifecycle through connect(), no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Run the full state lifecycle through accept, no-op session."""
        await self.proto.start_server_and_accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Receive one fake/null message and check the debug log."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        log_name = self.proto.logger.name
        with self.assertLogs(log_name, level='DEBUG') as captured:
            self.proto.trigger_input.set()
            self.proto.trigger_input.clear()
            await asyncio.sleep(0)  # Kick reader.

        expected = [f"DEBUG:{log_name}:<-- None"]
        self.assertEqual(captured.output, expected)

    @TestBase.async_test
    async def testFakeSend(self):
        """Send one fake/null message and check the debug log."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        log_name = self.proto.logger.name
        with self.assertLogs(log_name, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        expected = [f"DEBUG:{log_name}:--> None"]
        self.assertEqual(captured.output, expected)

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Call accept or connect and check that it raises StateError.

        :param current_state: State the error should report as current.
        :param error_message: Exact message the error should carry.
        :param accept: Use start_server_and_accept() when true,
                       connect() otherwise.
        """
        if accept:
            attempt = self.proto.start_server_and_accept
        else:
            attempt = self.proto.connect

        with self.assertRaises(StateError) as context:
            await attempt('/not/a/real/path')

        error = context.exception
        self.assertEqual(error.error_message, error_message)
        self.assertEqual(error.state, current_state)
        self.assertEqual(error.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """accept() must be refused while Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """connect() must be refused while Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """accept() must be refused while Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """connect() must be refused while Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )
|  |  | 
|  |  | 
class SimpleSession(TestBase):
    """
    Tests that pair the client protocol with a LineProtocol server.
    """

    def setUp(self):
        super().setUp()
        server_name = type(self).__name__ + '-server'
        self.server = LineProtocol(server_name)

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            # Deliberately ignored when shutting the server side down.
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock_name = f"{type(self.proto).__name__}.sock"
            sock = os.path.join(tmpdir, sock_name)
            server_task = create_task(self.server.start_server_and_accept(sock))

            # Yield once so the server gets a chance to start listening.
            await asyncio.sleep(0)
            await self.proto.connect(sock)