| import asyncio | 
 | from contextlib import contextmanager | 
 | import os | 
 | import socket | 
 | from tempfile import TemporaryDirectory | 
 |  | 
 | import avocado | 
 |  | 
 | from qemu.qmp import ConnectError, Runstate | 
 | from qemu.qmp.protocol import AsyncProtocol, StateError | 
 | from qemu.qmp.util import asyncio_run, create_task | 
 |  | 
 |  | 
 | class NullProtocol(AsyncProtocol[None]): | 
 |     """ | 
 |     NullProtocol is a test mockup of an AsyncProtocol implementation. | 
 |  | 
 |     It adds a fake_session instance variable that enables a code path | 
 |     that bypasses the actual connection logic, but still allows the | 
 |     reader/writers to start. | 
 |  | 
 |     Because the message type is defined as None, an asyncio.Event named | 
 |     'trigger_input' is created that prohibits the reader from | 
 |     incessantly being able to yield None; this event can be poked to | 
 |     simulate an incoming message. | 
 |  | 
 |     For testing symmetry with do_recv, an interface is added to "send" a | 
 |     Null message. | 
 |  | 
 |     For testing purposes, a "simulate_disconnection" method is also | 
 |     added which allows us to trigger a bottom half disconnect without | 
 |     injecting any real errors into the reader/writer loops; in essence | 
 |     it performs exactly half of what disconnect() normally does. | 
 |     """ | 
 |     def __init__(self, name=None): | 
 |         self.fake_session = False | 
 |         self.trigger_input: asyncio.Event | 
 |         super().__init__(name) | 
 |  | 
 |     async def _establish_session(self): | 
 |         self.trigger_input = asyncio.Event() | 
 |         await super()._establish_session() | 
 |  | 
 |     async def _do_start_server(self, address, ssl=None): | 
 |         if self.fake_session: | 
 |             self._accepted = asyncio.Event() | 
 |             self._set_state(Runstate.CONNECTING) | 
 |             await asyncio.sleep(0) | 
 |         else: | 
 |             await super()._do_start_server(address, ssl) | 
 |  | 
 |     async def _do_accept(self): | 
 |         if self.fake_session: | 
 |             self._accepted = None | 
 |         else: | 
 |             await super()._do_accept() | 
 |  | 
 |     async def _do_connect(self, address, ssl=None): | 
 |         if self.fake_session: | 
 |             self._set_state(Runstate.CONNECTING) | 
 |             await asyncio.sleep(0) | 
 |         else: | 
 |             await super()._do_connect(address, ssl) | 
 |  | 
 |     async def _do_recv(self) -> None: | 
 |         await self.trigger_input.wait() | 
 |         self.trigger_input.clear() | 
 |  | 
 |     def _do_send(self, msg: None) -> None: | 
 |         pass | 
 |  | 
 |     async def send_msg(self) -> None: | 
 |         await self._outgoing.put(None) | 
 |  | 
 |     async def simulate_disconnect(self) -> None: | 
 |         """ | 
 |         Simulates a bottom-half disconnect. | 
 |  | 
 |         This method schedules a disconnection but does not wait for it | 
 |         to complete. This is used to put the loop into the DISCONNECTING | 
 |         state without fully quiescing it back to IDLE. This is normally | 
 |         something you cannot coax AsyncProtocol to do on purpose, but it | 
 |         will be similar to what happens with an unhandled Exception in | 
 |         the reader/writer. | 
 |  | 
 |         Under normal circumstances, the library design requires you to | 
 |         await on disconnect(), which awaits the disconnect task and | 
 |         returns bottom half errors as a pre-condition to allowing the | 
 |         loop to return back to IDLE. | 
 |         """ | 
 |         self._schedule_disconnect() | 
 |  | 
 |  | 
 | class LineProtocol(AsyncProtocol[str]): | 
 |     def __init__(self, name=None): | 
 |         super().__init__(name) | 
 |         self.rx_history = [] | 
 |  | 
 |     async def _do_recv(self) -> str: | 
 |         raw = await self._readline() | 
 |         msg = raw.decode() | 
 |         self.rx_history.append(msg) | 
 |         return msg | 
 |  | 
 |     def _do_send(self, msg: str) -> None: | 
 |         assert self._writer is not None | 
 |         self._writer.write(msg.encode() + b'\n') | 
 |  | 
 |     async def send_msg(self, msg: str) -> None: | 
 |         await self._outgoing.put(msg) | 
 |  | 
 |  | 
 | def run_as_task(coro, allow_cancellation=False): | 
 |     """ | 
 |     Run a given coroutine as a task. | 
 |  | 
 |     Optionally, wrap it in a try..except block that allows this | 
 |     coroutine to be canceled gracefully. | 
 |     """ | 
 |     async def _runner(): | 
 |         try: | 
 |             await coro | 
 |         except asyncio.CancelledError: | 
 |             if allow_cancellation: | 
 |                 return | 
 |             raise | 
 |     return create_task(_runner()) | 
 |  | 
 |  | 
 | @contextmanager | 
 | def jammed_socket(): | 
 |     """ | 
 |     Opens up a random unused TCP port on localhost, then jams it. | 
 |     """ | 
 |     socks = [] | 
 |  | 
 |     try: | 
 |         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
 |         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | 
 |         sock.bind(('127.0.0.1', 0)) | 
 |         sock.listen(1) | 
 |         address = sock.getsockname() | 
 |  | 
 |         socks.append(sock) | 
 |  | 
 |         # I don't *fully* understand why, but it takes *two* un-accepted | 
 |         # connections to start jamming the socket. | 
 |         for _ in range(2): | 
 |             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
 |             sock.connect(address) | 
 |             socks.append(sock) | 
 |  | 
 |         yield address | 
 |  | 
 |     finally: | 
 |         for sock in socks: | 
 |             sock.close() | 
 |  | 
 |  | 
 | class Smoke(avocado.Test): | 
 |  | 
 |     def setUp(self): | 
 |         self.proto = NullProtocol() | 
 |  | 
 |     def test__repr__(self): | 
 |         self.assertEqual( | 
 |             repr(self.proto), | 
 |             "<NullProtocol runstate=IDLE>" | 
 |         ) | 
 |  | 
 |     def testRunstate(self): | 
 |         self.assertEqual( | 
 |             self.proto.runstate, | 
 |             Runstate.IDLE | 
 |         ) | 
 |  | 
 |     def testDefaultName(self): | 
 |         self.assertEqual( | 
 |             self.proto.name, | 
 |             None | 
 |         ) | 
 |  | 
 |     def testLogger(self): | 
 |         self.assertEqual( | 
 |             self.proto.logger.name, | 
 |             'qemu.qmp.protocol' | 
 |         ) | 
 |  | 
 |     def testName(self): | 
 |         self.proto = NullProtocol('Steve') | 
 |  | 
 |         self.assertEqual( | 
 |             self.proto.name, | 
 |             'Steve' | 
 |         ) | 
 |  | 
 |         self.assertEqual( | 
 |             self.proto.logger.name, | 
 |             'qemu.qmp.protocol.Steve' | 
 |         ) | 
 |  | 
 |         self.assertEqual( | 
 |             repr(self.proto), | 
 |             "<NullProtocol name='Steve' runstate=IDLE>" | 
 |         ) | 
 |  | 
 |  | 
 | class TestBase(avocado.Test): | 
 |  | 
 |     def setUp(self): | 
 |         self.proto = NullProtocol(type(self).__name__) | 
 |         self.assertEqual(self.proto.runstate, Runstate.IDLE) | 
 |         self.runstate_watcher = None | 
 |  | 
 |     def tearDown(self): | 
 |         self.assertEqual(self.proto.runstate, Runstate.IDLE) | 
 |  | 
 |     async def _asyncSetUp(self): | 
 |         pass | 
 |  | 
 |     async def _asyncTearDown(self): | 
 |         if self.runstate_watcher: | 
 |             await self.runstate_watcher | 
 |  | 
 |     @staticmethod | 
 |     def async_test(async_test_method): | 
 |         """ | 
 |         Decorator; adds SetUp and TearDown to async tests. | 
 |         """ | 
 |         async def _wrapper(self, *args, **kwargs): | 
 |             loop = asyncio.get_event_loop() | 
 |             loop.set_debug(True) | 
 |  | 
 |             await self._asyncSetUp() | 
 |             await async_test_method(self, *args, **kwargs) | 
 |             await self._asyncTearDown() | 
 |  | 
 |         return _wrapper | 
 |  | 
 |     # Definitions | 
 |  | 
 |     # The states we expect a "bad" connect/accept attempt to transition through | 
 |     BAD_CONNECTION_STATES = ( | 
 |         Runstate.CONNECTING, | 
 |         Runstate.DISCONNECTING, | 
 |         Runstate.IDLE, | 
 |     ) | 
 |  | 
 |     # The states we expect a "good" session to transition through | 
 |     GOOD_CONNECTION_STATES = ( | 
 |         Runstate.CONNECTING, | 
 |         Runstate.RUNNING, | 
 |         Runstate.DISCONNECTING, | 
 |         Runstate.IDLE, | 
 |     ) | 
 |  | 
 |     # Helpers | 
 |  | 
 |     async def _watch_runstates(self, *states): | 
 |         """ | 
 |         This launches a task alongside (most) tests below to confirm that | 
 |         the sequence of runstate changes that occur is exactly as | 
 |         anticipated. | 
 |         """ | 
 |         async def _watcher(): | 
 |             for state in states: | 
 |                 new_state = await self.proto.runstate_changed() | 
 |                 self.assertEqual( | 
 |                     new_state, | 
 |                     state, | 
 |                     msg=f"Expected state '{state.name}'", | 
 |                 ) | 
 |  | 
 |         self.runstate_watcher = create_task(_watcher()) | 
 |         # Kick the loop and force the task to block on the event. | 
 |         await asyncio.sleep(0) | 
 |  | 
 |  | 
 | class State(TestBase): | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testSuperfluousDisconnect(self): | 
 |         """ | 
 |         Test calling disconnect() while already disconnected. | 
 |         """ | 
 |         await self._watch_runstates( | 
 |             Runstate.DISCONNECTING, | 
 |             Runstate.IDLE, | 
 |         ) | 
 |         await self.proto.disconnect() | 
 |  | 
 |  | 
 | class Connect(TestBase): | 
 |     """ | 
 |     Tests primarily related to calling Connect(). | 
 |     """ | 
 |     async def _bad_connection(self, family: str): | 
 |         assert family in ('INET', 'UNIX') | 
 |  | 
 |         if family == 'INET': | 
 |             await self.proto.connect(('127.0.0.1', 0)) | 
 |         elif family == 'UNIX': | 
 |             await self.proto.connect('/dev/null') | 
 |  | 
 |     async def _hanging_connection(self): | 
 |         with jammed_socket() as addr: | 
 |             await self.proto.connect(addr) | 
 |  | 
 |     async def _bad_connection_test(self, family: str): | 
 |         await self._watch_runstates(*self.BAD_CONNECTION_STATES) | 
 |  | 
 |         with self.assertRaises(ConnectError) as context: | 
 |             await self._bad_connection(family) | 
 |  | 
 |         self.assertIsInstance(context.exception.exc, OSError) | 
 |         self.assertEqual( | 
 |             context.exception.error_message, | 
 |             "Failed to establish connection" | 
 |         ) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testBadINET(self): | 
 |         """ | 
 |         Test an immediately rejected call to an IP target. | 
 |         """ | 
 |         await self._bad_connection_test('INET') | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testBadUNIX(self): | 
 |         """ | 
 |         Test an immediately rejected call to a UNIX socket target. | 
 |         """ | 
 |         await self._bad_connection_test('UNIX') | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testCancellation(self): | 
 |         """ | 
 |         Test what happens when a connection attempt is aborted. | 
 |         """ | 
 |         # Note that accept() cannot be cancelled outright, as it isn't a task. | 
 |         # However, we can wrap it in a task and cancel *that*. | 
 |         await self._watch_runstates(*self.BAD_CONNECTION_STATES) | 
 |         task = run_as_task(self._hanging_connection(), allow_cancellation=True) | 
 |  | 
 |         state = await self.proto.runstate_changed() | 
 |         self.assertEqual(state, Runstate.CONNECTING) | 
 |  | 
 |         # This is insider baseball, but the connection attempt has | 
 |         # yielded *just* before the actual connection attempt, so kick | 
 |         # the loop to make sure it's truly wedged. | 
 |         await asyncio.sleep(0) | 
 |  | 
 |         task.cancel() | 
 |         await task | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testTimeout(self): | 
 |         """ | 
 |         Test what happens when a connection attempt times out. | 
 |         """ | 
 |         await self._watch_runstates(*self.BAD_CONNECTION_STATES) | 
 |         task = run_as_task(self._hanging_connection()) | 
 |  | 
 |         # More insider baseball: to improve the speed of this test while | 
 |         # guaranteeing that the connection even gets a chance to start, | 
 |         # verify that the connection hangs *first*, then await the | 
 |         # result of the task with a nearly-zero timeout. | 
 |  | 
 |         state = await self.proto.runstate_changed() | 
 |         self.assertEqual(state, Runstate.CONNECTING) | 
 |         await asyncio.sleep(0) | 
 |  | 
 |         with self.assertRaises(asyncio.TimeoutError): | 
 |             await asyncio.wait_for(task, timeout=0) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testRequire(self): | 
 |         """ | 
 |         Test what happens when a connection attempt is made while CONNECTING. | 
 |         """ | 
 |         await self._watch_runstates(*self.BAD_CONNECTION_STATES) | 
 |         task = run_as_task(self._hanging_connection(), allow_cancellation=True) | 
 |  | 
 |         state = await self.proto.runstate_changed() | 
 |         self.assertEqual(state, Runstate.CONNECTING) | 
 |  | 
 |         with self.assertRaises(StateError) as context: | 
 |             await self._bad_connection('UNIX') | 
 |  | 
 |         self.assertEqual( | 
 |             context.exception.error_message, | 
 |             "NullProtocol is currently connecting." | 
 |         ) | 
 |         self.assertEqual(context.exception.state, Runstate.CONNECTING) | 
 |         self.assertEqual(context.exception.required, Runstate.IDLE) | 
 |  | 
 |         task.cancel() | 
 |         await task | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testImplicitRunstateInit(self): | 
 |         """ | 
 |         Test what happens if we do not wait on the runstate event until | 
 |         AFTER a connection is made, i.e., connect()/accept() themselves | 
 |         initialize the runstate event. All of the above tests force the | 
 |         initialization by waiting on the runstate *first*. | 
 |         """ | 
 |         task = run_as_task(self._hanging_connection(), allow_cancellation=True) | 
 |  | 
 |         # Kick the loop to coerce the state change | 
 |         await asyncio.sleep(0) | 
 |         assert self.proto.runstate == Runstate.CONNECTING | 
 |  | 
 |         # We already missed the transition to CONNECTING | 
 |         await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE) | 
 |  | 
 |         task.cancel() | 
 |         await task | 
 |  | 
 |  | 
 | class Accept(Connect): | 
 |     """ | 
 |     All of the same tests as Connect, but using the accept() interface. | 
 |     """ | 
 |     async def _bad_connection(self, family: str): | 
 |         assert family in ('INET', 'UNIX') | 
 |  | 
 |         if family == 'INET': | 
 |             await self.proto.start_server_and_accept(('example.com', 1)) | 
 |         elif family == 'UNIX': | 
 |             await self.proto.start_server_and_accept('/dev/null') | 
 |  | 
 |     async def _hanging_connection(self): | 
 |         with TemporaryDirectory(suffix='.qmp') as tmpdir: | 
 |             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock") | 
 |             await self.proto.start_server_and_accept(sock) | 
 |  | 
 |  | 
 | class FakeSession(TestBase): | 
 |  | 
 |     def setUp(self): | 
 |         super().setUp() | 
 |         self.proto.fake_session = True | 
 |  | 
 |     async def _asyncSetUp(self): | 
 |         await super()._asyncSetUp() | 
 |         await self._watch_runstates(*self.GOOD_CONNECTION_STATES) | 
 |  | 
 |     async def _asyncTearDown(self): | 
 |         await self.proto.disconnect() | 
 |         await super()._asyncTearDown() | 
 |  | 
 |     #### | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testFakeConnect(self): | 
 |  | 
 |         """Test the full state lifecycle (via connect) with a no-op session.""" | 
 |         await self.proto.connect('/not/a/real/path') | 
 |         self.assertEqual(self.proto.runstate, Runstate.RUNNING) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testFakeAccept(self): | 
 |         """Test the full state lifecycle (via accept) with a no-op session.""" | 
 |         await self.proto.start_server_and_accept('/not/a/real/path') | 
 |         self.assertEqual(self.proto.runstate, Runstate.RUNNING) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testFakeRecv(self): | 
 |         """Test receiving a fake/null message.""" | 
 |         await self.proto.start_server_and_accept('/not/a/real/path') | 
 |  | 
 |         logname = self.proto.logger.name | 
 |         with self.assertLogs(logname, level='DEBUG') as context: | 
 |             self.proto.trigger_input.set() | 
 |             self.proto.trigger_input.clear() | 
 |             await asyncio.sleep(0)  # Kick reader. | 
 |  | 
 |         self.assertEqual( | 
 |             context.output, | 
 |             [f"DEBUG:{logname}:<-- None"], | 
 |         ) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testFakeSend(self): | 
 |         """Test sending a fake/null message.""" | 
 |         await self.proto.start_server_and_accept('/not/a/real/path') | 
 |  | 
 |         logname = self.proto.logger.name | 
 |         with self.assertLogs(logname, level='DEBUG') as context: | 
 |             # Cheat: Send a Null message to nobody. | 
 |             await self.proto.send_msg() | 
 |             # Kick writer; awaiting on a queue.put isn't sufficient to yield. | 
 |             await asyncio.sleep(0) | 
 |  | 
 |         self.assertEqual( | 
 |             context.output, | 
 |             [f"DEBUG:{logname}:--> None"], | 
 |         ) | 
 |  | 
 |     async def _prod_session_api( | 
 |             self, | 
 |             current_state: Runstate, | 
 |             error_message: str, | 
 |             accept: bool = True | 
 |     ): | 
 |         with self.assertRaises(StateError) as context: | 
 |             if accept: | 
 |                 await self.proto.start_server_and_accept('/not/a/real/path') | 
 |             else: | 
 |                 await self.proto.connect('/not/a/real/path') | 
 |  | 
 |         self.assertEqual(context.exception.error_message, error_message) | 
 |         self.assertEqual(context.exception.state, current_state) | 
 |         self.assertEqual(context.exception.required, Runstate.IDLE) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testAcceptRequireRunning(self): | 
 |         """Test that accept() cannot be called when Runstate=RUNNING""" | 
 |         await self.proto.start_server_and_accept('/not/a/real/path') | 
 |  | 
 |         await self._prod_session_api( | 
 |             Runstate.RUNNING, | 
 |             "NullProtocol is already connected and running.", | 
 |             accept=True, | 
 |         ) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testConnectRequireRunning(self): | 
 |         """Test that connect() cannot be called when Runstate=RUNNING""" | 
 |         await self.proto.start_server_and_accept('/not/a/real/path') | 
 |  | 
 |         await self._prod_session_api( | 
 |             Runstate.RUNNING, | 
 |             "NullProtocol is already connected and running.", | 
 |             accept=False, | 
 |         ) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testAcceptRequireDisconnecting(self): | 
 |         """Test that accept() cannot be called when Runstate=DISCONNECTING""" | 
 |         await self.proto.start_server_and_accept('/not/a/real/path') | 
 |  | 
 |         # Cheat: force a disconnect. | 
 |         await self.proto.simulate_disconnect() | 
 |  | 
 |         await self._prod_session_api( | 
 |             Runstate.DISCONNECTING, | 
 |             ("NullProtocol is disconnecting." | 
 |              " Call disconnect() to return to IDLE state."), | 
 |             accept=True, | 
 |         ) | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testConnectRequireDisconnecting(self): | 
 |         """Test that connect() cannot be called when Runstate=DISCONNECTING""" | 
 |         await self.proto.start_server_and_accept('/not/a/real/path') | 
 |  | 
 |         # Cheat: force a disconnect. | 
 |         await self.proto.simulate_disconnect() | 
 |  | 
 |         await self._prod_session_api( | 
 |             Runstate.DISCONNECTING, | 
 |             ("NullProtocol is disconnecting." | 
 |              " Call disconnect() to return to IDLE state."), | 
 |             accept=False, | 
 |         ) | 
 |  | 
 |  | 
 | class SimpleSession(TestBase): | 
 |  | 
 |     def setUp(self): | 
 |         super().setUp() | 
 |         self.server = LineProtocol(type(self).__name__ + '-server') | 
 |  | 
 |     async def _asyncSetUp(self): | 
 |         await super()._asyncSetUp() | 
 |         await self._watch_runstates(*self.GOOD_CONNECTION_STATES) | 
 |  | 
 |     async def _asyncTearDown(self): | 
 |         await self.proto.disconnect() | 
 |         try: | 
 |             await self.server.disconnect() | 
 |         except EOFError: | 
 |             pass | 
 |         await super()._asyncTearDown() | 
 |  | 
 |     @TestBase.async_test | 
 |     async def testSmoke(self): | 
 |         with TemporaryDirectory(suffix='.qmp') as tmpdir: | 
 |             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock") | 
 |             server_task = create_task(self.server.start_server_and_accept(sock)) | 
 |  | 
 |             # give the server a chance to start listening [...] | 
 |             await asyncio.sleep(0) | 
 |             await self.proto.connect(sock) |