Example code can be found on GitHub. All code on this post is licensed under MIT.
Mayhem Mandrill Recap #
The goal for this 7-part series is to build a mock chaos monkey-like service called “Mayhem Mandrill”. This is an event-driven service that consumes from a pub/sub, and initiates a mock restart of a host. We could get thousands of messages in seconds, so as we get a message, we shouldn’t block the handling of the next message we receive.
For a simpler starting point, we’re going to test asyncio
code that doesn’t have to deal with threading. Here’s the starting point of what we’re going to test:
import asyncio
import functools
import logging
import random
import signal
import string
import uuid
import attr
# Configure the root logger once at import time: INFO level and above, with
# each record prefixed by an HH:MM:SS timestamp plus milliseconds.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
)
@attr.s
class PubSubMessage:
    """A mock pub/sub message describing one host to restart.

    ``instance_name`` and ``message_id`` are the required initializer
    arguments. ``hostname`` is derived after initialization, and the
    remaining fields record how far processing of the message has got.
    Only ``instance_name`` is included in the generated ``repr``.
    """

    instance_name = attr.ib()
    message_id = attr.ib(repr=False)
    # Filled in by __attrs_post_init__, hence not an initializer argument.
    hostname = attr.ib(init=False, repr=False)
    # Progress flags set by restart_host(), save() and cleanup() respectively.
    restarted = attr.ib(default=False, repr=False)
    saved = attr.ib(default=False, repr=False)
    acked = attr.ib(default=False, repr=False)
    # Incremented by extend() each time the deadline is extended.
    extended_cnt = attr.ib(default=0, repr=False)

    def __attrs_post_init__(self):
        """Derive ``hostname`` from ``instance_name``."""
        self.hostname = f"{self.instance_name}.example.net"
class RestartFailed(Exception):
    """Raised when the simulated restart of a host fails."""
async def publish(queue):
    """Endlessly publish randomly named mock messages onto ``queue``.

    Every iteration builds a ``PubSubMessage`` with a fresh UUID and a
    random four-character host id, schedules ``queue.put`` as its own
    task instead of awaiting it, and then sleeps for a random sub-second
    interval. The coroutine never returns on its own.

    Args:
        queue: Queue-like object whose ``put`` returns an awaitable.
    """
    alphabet = string.ascii_lowercase + string.digits
    while True:
        message_id = str(uuid.uuid4())
        suffix = "".join(random.choices(alphabet, k=4))
        msg = PubSubMessage(
            message_id=message_id, instance_name=f"cattle-{suffix}"
        )
        # Fire and forget: the put runs concurrently with the next publish.
        asyncio.create_task(queue.put(msg))
        logging.debug(f"Published message {msg}")
        await asyncio.sleep(random.random())
async def restart_host(msg):
    """Simulate restarting the host named by ``msg.hostname``.

    Sleeps for a random sub-second interval, then fails when a random
    draw from 1-4 comes up 3; otherwise flags the message as restarted.

    Raises:
        RestartFailed: If the simulated restart fails.
    """
    await asyncio.sleep(random.random())
    restart_failed = random.randrange(1, 5) == 3
    if restart_failed:
        raise RestartFailed(f"Could not restart {msg.hostname}")
    msg.restarted = True
    logging.info(f"Restarted {msg.hostname}")
async def save(msg):
    """Simulate persisting ``msg`` to a database and flag it as saved."""
    # Stand-in for database I/O: wait a random fraction of a second.
    delay = random.random()
    await asyncio.sleep(delay)
    # if random.randrange(1, 5) == 3:
    #     raise Exception(f"Could not save {msg}")
    msg.saved = True
    logging.info(f"Saved {msg} into database")
async def cleanup(msg, event):
    """Acknowledge ``msg`` once ``event`` has been set.

    Waits on ``event``, sleeps for a random sub-second interval as a
    stand-in for the ack round trip, then flags the message as acked.

    Args:
        msg: Message object whose ``acked`` attribute is set to ``True``.
        event: Event that is set when handling of ``msg`` has finished.
    """
    await event.wait()
    ack_delay = random.random()
    await asyncio.sleep(ack_delay)
    msg.acked = True
    logging.info(f"Done. Acked {msg}")
async def extend(msg, event):
    """Keep extending the deadline of ``msg`` until ``event`` is set.

    Each round increments ``msg.extended_cnt``, logs the extension and
    then sleeps before checking ``event`` again.

    Args:
        msg: Message object carrying an integer ``extended_cnt``.
        event: Event whose being set ends the loop.
    """
    while True:
        if event.is_set():
            return
        msg.extended_cnt += 1
        logging.info(f"Extended deadline by 3 seconds for {msg}")
        # Sleep 2s between rounds, shorter than the 3s named in the log line.
        await asyncio.sleep(2)
def handle_results(results, msg):
    """Log every exception found among ``results``.

    Args:
        results: Iterable of return values and/or exception instances,
            as produced by ``asyncio.gather(..., return_exceptions=True)``.
        msg: The message the results belong to; its ``hostname`` is used
            in the restart-failure log line.
    """
    for outcome in results:
        if not isinstance(outcome, Exception):
            # Ordinary return value: nothing to report.
            continue
        if isinstance(outcome, RestartFailed):
            logging.error(f"Retrying for failure to restart: {msg.hostname}")
        else:
            logging.error(f"Handling general error: {outcome}")
async def handle_message(msg):
    """Process a single message: save it and restart its host.

    ``extend`` and ``cleanup`` are started as background tasks sharing
    one event. ``save`` and ``restart_host`` then run concurrently, their
    failures are logged through ``handle_results``, and finally the event
    is set so that ``extend`` stops and ``cleanup`` can proceed.
    """
    event = asyncio.Event()
    for background in (extend, cleanup):
        asyncio.create_task(background(msg, event))

    # return_exceptions=True: failures come back as values, not raises.
    results = await asyncio.gather(
        save(msg),
        restart_host(msg),
        return_exceptions=True,
    )
    handle_results(results, msg)
    event.set()
async def consume(queue):
    """Pull messages off ``queue`` forever, handling each in its own task.

    Args:
        queue: Queue-like object whose ``get`` returns an awaitable.
    """
    while True:
        pulled = await queue.get()
        logging.info(f"Pulled {pulled}")
        # Schedule rather than await, so the next get() is not held up.
        asyncio.create_task(handle_message(pulled))
def handle_exception(loop, context):
    """Log an error reported to the event loop and start a shutdown.

    ``main`` installs this via ``loop.set_exception_handler``.

    Args:
        loop: The event loop that reported the error.
        context: Mapping describing the error; ``"message"`` is read
            unconditionally and ``"exception"`` is preferred when present.
    """
    fallback = context["message"]
    error = context.get("exception", fallback)
    logging.error(f"Caught exception: {error}")
    logging.info("Shutting down...")
    asyncio.create_task(shutdown(loop))
async def shutdown(loop, signal=None):
    """Cancel every other outstanding task, then stop ``loop``.

    Args:
        loop: The event loop to stop once the cancelled tasks have
            finished.
        signal: Optional signal that triggered the shutdown; only its
            ``name`` is used, for logging. The parameter shadows the
            ``signal`` module inside this function but keeps its name
            because callers pass it by keyword.
    """
    if signal:
        logging.info(f"Received exit signal {signal.name}...")
    logging.info("Closing database connections")
    logging.info("Nacking outstanding messages")

    # Everything except the task running this coroutine gets cancelled.
    current = asyncio.current_task()
    tasks = [t for t in asyncio.all_tasks() if t is not current]
    for task in tasks:
        task.cancel()

    logging.info("Cancelling outstanding tasks")
    # return_exceptions=True keeps each task's CancelledError from
    # propagating out of gather and aborting the shutdown.
    await asyncio.gather(*tasks, return_exceptions=True)
    logging.info("Flushing metrics")
    loop.stop()
def main():
    """Wire up and run the Mayhem service until it is told to stop.

    Registers ``shutdown`` as the handler for SIGHUP, SIGTERM and SIGINT,
    installs ``handle_exception`` as the loop's exception handler, starts
    the ``publish`` and ``consume`` tasks on a shared queue and runs the
    loop until it is stopped. The loop is always closed on the way out.
    """
    loop = asyncio.get_event_loop()

    for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT):
        # Bind sig as a default argument so each handler keeps its own
        # signal instead of the loop variable's final value.
        loop.add_signal_handler(
            sig,
            lambda sig=sig: asyncio.create_task(shutdown(loop, signal=sig)),
        )
    loop.set_exception_handler(handle_exception)

    queue = asyncio.Queue()
    try:
        for worker in (publish, consume):
            loop.create_task(worker(queue))
        loop.run_forever()
    finally:
        loop.close()
        logging.info("Successfully shutdown the Mayhem service.")


if __name__ == "__main__":  # pragma: no cover
    main()
Simple Testing with pytest
#
We will be using pytest
since I prefer writing simple assert
statements for my tests.
We will start simple, and test the “happy path” of the save
function (i.e. not the code path that raises an exception):
async def save(msg):
    """Pretend to write ``msg`` to a database, then flag it as saved."""
    # unhelpful simulation of i/o work
    pause = random.random()
    await asyncio.sleep(pause)
    msg.saved = True
    logging.info(f"Saved {msg} into database")
Since save
is a coroutine function, we’ll need to run it on the loop:
import asyncio
import pytest
import mayhem
@pytest.fixture
def message():
    """Provide a ``PubSubMessage`` built from fixed test values."""
    return mayhem.PubSubMessage(
        instance_name="mayhem_test",
        message_id="1234",
    )
def test_save(message):
    """``save`` flags the message as saved (happy path)."""
    assert not message.saved  # sanity check

    # `save` is a coroutine function, so it has to be run on an event loop.
    save_coro = mayhem.save(message)
    asyncio.run(save_coro)

    assert message.saved
Running this via pytest
:
$ pytest -v test_mayhem_1.py
test_mayhem_1.py::test_save PASSED [100%]
Sweet! However if you’re not on 3.7+ yet, you’ll have to construct and deconstruct the loop yourself, rather than making use of asyncio.run
:
def test_save(message):
    """``save`` flags the message as saved, driving the loop by hand."""
    assert not message.saved  # sanity check

    # Use a loop owned by this test: closing the one returned by
    # get_event_loop() would leave a closed default loop behind for every
    # later test that asks for it.
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(mayhem.save(message))
    finally:
        # Close the loop even when `save` raises.
        loop.close()

    assert message.saved
This can get annoying, especially when you have many coroutine functions to test. Thankfully, there is a plugin for pytest
called pytest-asyncio
. This plugin allows you to define your tests themselves as coroutine functions, and manages the event loop for you:
@pytest.mark.asyncio
async def test_save(message):  # <-- now a coroutine!
    """``save`` flags the message as saved; the plugin supplies the loop."""
    saved_before = message.saved
    assert not saved_before  # sanity check

    await mayhem.save(message)

    assert message.saved
Much cleaner! Using pytest-asyncio
can get you pretty far.
Mocking Coroutines #
When writing unit tests, you’ll often need to mock out coroutine functions that are called within your tested function.
For instance, our save
coroutine function calls another coroutine function, asyncio.sleep
(a stand-in for a network I/O call to a database):
async def save(msg):
    """Pretend to write ``msg`` to a database, then flag it as saved."""
    # unhelpful simulation of i/o work
    await asyncio.sleep(random.random())  # <-- let's mock this out

    msg.saved = True
    logging.info(f"Saved {msg} into database")
You don’t actually want to wait for asyncio.sleep
to complete while running your tests, nor do you want an actual call to a database to happen.
Neither the unittest.mock
nor the pytest-mock
library supports asynchronous mocks out of the box (as of Python 3.7; unittest.mock.AsyncMock only arrived in Python 3.8), so we’ll have to work around this.
First, in making use of the pytest-mock
library, we’ll create a pytest
fixture that will return a function:
@pytest.fixture
def create_mock_coro(mocker, monkeypatch):
    """Return a factory that builds ``(mock, coroutine function)`` pairs.

    The coroutine function forwards its arguments to the mock and returns
    whatever the mock returns, so it can stand in for an async callable
    while the mock records the calls for assertions.
    """

    def _create_mock_coro_pair(to_patch=None):
        # `to_patch`, when given, is the target handed to
        # monkeypatch.setattr and replaced with the coroutine function.
        call_recorder = mocker.Mock()

        async def _async_stub(*args, **kwargs):
            return call_recorder(*args, **kwargs)

        if to_patch:
            monkeypatch.setattr(to_patch, _async_stub)
        return call_recorder, _async_stub

    return _create_mock_coro_pair
Then, we’ll create another pytest
fixture that will use the create_mock_coro
fixture to mock and patch asyncio.sleep
:
@pytest.fixture
def mock_sleep(create_mock_coro, monkeypatch):
    """Patch ``mayhem.asyncio.sleep`` and return the mock recording calls.

    The ``@pytest.fixture`` decorator is required: without it pytest
    cannot inject ``mock_sleep`` into tests that request it by name.
    """
    sleep_mock, _ = create_mock_coro(to_patch="mayhem.asyncio.sleep")
    return sleep_mock
Now let’s use the mock_sleep
fixture in our test_save
:
@pytest.mark.asyncio
async def test_save(mock_sleep, message):
    """``save`` flags the message and sleeps exactly once (mocked)."""
    assert not message.saved  # sanity check

    await mayhem.save(message)

    assert message.saved
    assert mock_sleep.call_count == 1
What we’ve done here is basically patched asyncio.sleep
in our mayhem
module with a coroutine function that returns a mocked object. Then, we assert that the mocked asyncio.sleep
object is called once when mayhem.save
is called. Because we now have a mock object instead of the actual coroutine, we can now do anything that’s supported with unittest.mock.Mock
objects, i.e. our_mocked_object.assert_called_once_with(...)
, our_mocked_object.return_value = "foo"
, etc.
Testing create_task
#
For testing coroutine functions that have calls to create_task
, we can’t simply use the create_mock_coro
fixture. For instance, let’s try to test our consume
coroutine function:
async def consume(queue):
    """Pull messages off ``queue`` forever, handling each in its own task."""
    while True:
        pulled = await queue.get()
        logging.info(f"Pulled {pulled}")
        # Scheduled, not awaited: the handler only runs once the loop
        # gets control back.
        asyncio.create_task(handle_message(pulled))
I have the following fixtures for the asyncio.Queue
:
@pytest.fixture
def mock_queue(mocker, monkeypatch):
    """Patch ``Queue`` on ``mayhem.asyncio`` with a mock class.

    Returns the mock's ``return_value``, i.e. the object that calling the
    patched ``Queue()`` produces.
    """
    queue_cls = mocker.Mock()
    monkeypatch.setattr(mayhem.asyncio, "Queue", queue_cls)
    return queue_cls.return_value
@pytest.fixture
def mock_get(mock_queue, create_mock_coro):
    """Give ``mock_queue`` an awaitable ``get`` and return its mock.

    Set ``side_effect`` / ``return_value`` on the returned mock to
    control what ``await queue.get()`` yields.
    """
    get_mock, get_coro = create_mock_coro()
    mock_queue.get = get_coro
    return get_mock
So let’s try to use create_mock_coro
to mock and match the call to handle_message
coroutine via create_task
.
Note: we’re setting mock_get.side_effect
with one “real” value, and one Exception
to make sure we’re not permanently stuck within the while True
loop that consume
has.
@pytest.mark.asyncio
async def test_consume(mock_get, mock_queue, message, create_mock_coro):
    """First attempt at testing ``consume`` -- this version fails.

    The task created via ``create_task`` has only been scheduled, not
    run, when the final assertion executes, so the mock has zero calls.
    """
    # One real message, then an exception to break out of `while True`.
    mock_get.side_effect = [message, Exception("break while loop")]
    mock_handle_message, _ = create_mock_coro("mayhem.handle_message")
    with pytest.raises(Exception, match="break while loop"):
        await mayhem.consume(mock_queue)
    # Fails here: the scheduled handle_message task has not run yet.
    mock_handle_message.assert_called_once_with(message)
When running this, we see that mock_handle_message
does not actually get called, like we’re expecting:
$ pytest -v test_mayhem_4.py
test_mayhem_4.py::test_consume FAILED [100%]
=========================================== FAILURES ============================================
_________________________________________ test_consume __________________________________________
mock_get = <Mock id='4477488824'>, mock_queue = <Mock name='mock()' id='4477488880'>
message = Message(instance_name='cattle-1234')
create_mock_coro = <function create_mock_coro.<locals>._create_mock_patch_coro at 0x10add9840>
@pytest.mark.asyncio
async def test_consume(mock_get, mock_queue, message, create_mock_coro):
mock_get.side_effect = [message, Exception("break while loop")]
mock_handle_message = create_mock_coro("mayhem.handle_message")
with pytest.raises(Exception, match="break while loop"):
await mayhem.consume(mock_queue)
> mock_handle_message.assert_called_once_with(message)
E AssertionError: Expected 'mock' to be called once. Called 0 times.
test_mayhem_4.py:230: AssertionError
------------------------------------- Captured stderr call --------------------------------------
15:30:37,721 INFO: Pulled Message(instance_name='cattle-1234')
============================== 1 failed, 1 passed in 0.10 seconds ===============================
This is because the tasks created by create_task are only scheduled and still pending at this point; we need to nudge them along. We do this by collecting all outstanding tasks (other than the test itself) and awaiting them explicitly:
@pytest.mark.asyncio
async def test_consume(mock_get, mock_queue, message, create_mock_coro):
    """``consume`` schedules ``handle_message`` for each pulled message."""
    # One real message, then an exception to break out of `while True`.
    mock_get.side_effect = [message, Exception("break while loop")]
    mock_handle_message, _ = create_mock_coro("mayhem.handle_message")

    with pytest.raises(Exception, match="break while loop"):
        await mayhem.consume(mock_queue)

    this_test = asyncio.current_task()
    ret_tasks = [t for t in asyncio.all_tasks() if t is not this_test]
    # should be 1 per side effect minus the Exception (i.e. messages consumed)
    assert len(ret_tasks) == 1
    mock_handle_message.assert_not_called()  # <-- sanity check

    # explicitly await tasks scheduled by `asyncio.create_task`
    await asyncio.gather(*ret_tasks)

    mock_handle_message.assert_called_once_with(message)
Now pytest
is happy:
$ pytest -v test_mayhem_5.py
test_mayhem_5.py::test_consume PASSED [100%]
Non-async Testing of the Event Loop #
I hear you wanting to get to 100% test coverage, which may seem difficult for our main
function. We’ll make use of pytest-asyncio
’s event_loop
fixture, with a slight modification.
First, we’ll create our own fixture by inheriting from pytest-asyncio
’s event_loop
fixture:
@pytest.fixture
def event_loop(event_loop, mocker):
    """Provide a fresh event loop whose ``close()`` is a mock.

    The new loop is installed as the current event loop. Its real
    ``close`` is stashed on ``_close`` and only called after the test, so
    code under test can call ``close()`` without actually closing it.
    """
    policy = asyncio.get_event_loop_policy()
    new_loop = policy.new_event_loop()
    asyncio.set_event_loop(new_loop)

    # Keep the real close() reachable, then swap in a mock for assertions.
    new_loop._close = new_loop.close
    new_loop.close = mocker.Mock()

    yield new_loop

    # Teardown: really close the loop now that the test is done.
    new_loop._close()
We’re essentially setting a different event loop that pytest-asyncio
will use when it injects it into the tested code. We want to update the testing event loop so we can override the close()
behavior, which gets called in our main
function. If we close the loop during the test, we’ll lose access to the signal handlers that we set up within the main
function. We can replace the close()
method with a mock object to still assert that it has been called.
So now, we’ll write a test_main
function that actually borders on an integration or functional test. We want to make sure – in addition to the expected calls to publish
and consume
– that shutdown
gets called when expected.
We can’t exactly mock out shutdown
with create_mock_coro
since it will patch it with just another coroutine and therefore run the mocked coroutine each time it receives a signal. Instead, we’ll mock out the coroutine that shutdown
calls (the asyncio.gather
call within it).
And finally, in order to see if the loop
actually responds to signals, we need to send a signal to it. We do this by starting a separate thread from which we’ll send a signal to the process itself.
# <-- snip -->
import os
import signal
import time
import threading
# <-- snip -->
def test_main(create_mock_coro, event_loop, mock_queue):
    """Run ``main`` end to end and stop it with a real ``SIGTERM``."""
    mock_consume, _ = create_mock_coro("mayhem.consume")
    mock_publish, _ = create_mock_coro("mayhem.publish")
    # mock out `asyncio.gather` that `shutdown` calls instead
    # of `shutdown` itself
    mock_asyncio_gather, _ = create_mock_coro("mayhem.asyncio.gather")

    def _send_signal():
        # allow the loop to start and work a little bit...
        time.sleep(0.1)
        # ...then send a signal
        os.kill(os.getpid(), signal.SIGTERM)

    # main() blocks in run_forever(), so the signal comes from a thread.
    signaller = threading.Thread(target=_send_signal, daemon=True)
    signaller.start()

    mayhem.main()

    assert signal.SIGTERM in event_loop._signal_handlers
    assert mayhem.handle_exception == event_loop.get_exception_handler()

    mock_asyncio_gather.assert_called_once_with(return_exceptions=True)
    mock_consume.assert_called_once_with(mock_queue)
    mock_publish.assert_called_once_with(mock_queue)

    # asserting the loop is stopped but not closed
    assert not event_loop.is_running()
    assert not event_loop.is_closed()
    event_loop.close.assert_called_once_with()
$ pytest -v test_mayhem_6.py
test_mayhem_6.py::test_main PASSED [100%]
We can further parametrize this, as well as test for behavior when the loop receives a signal other than SIGINT
, SIGTERM
, and SIGHUP
. This requires us to add another signal since we want to make sure that other signals do not invoke the defined shutdown behavior. We’ll make use of SIGUSR1
and add a different shutdown mock to the test event loop:
@pytest.mark.parametrize(
    "tested_signal", ("SIGINT", "SIGTERM", "SIGHUP", "SIGUSR1")
)
def test_main(tested_signal, create_mock_coro, event_loop, mock_queue, mocker):
    """``main`` runs ``shutdown`` for its three exit signals only.

    ``SIGUSR1`` gets a separate handler on the test loop, so that case
    can assert ``mayhem``'s own shutdown path was *not* taken.
    """
    # The parameter arrives as a name; resolve it to the signal constant.
    tested_signal = getattr(signal, tested_signal)
    mock_asyncio_gather, _ = create_mock_coro("mayhem.asyncio.gather")
    mock_consume, _ = create_mock_coro("mayhem.consume")
    mock_publish, _ = create_mock_coro("mayhem.publish")

    mock_shutdown = mocker.Mock()

    def _shutdown():
        mock_shutdown()
        event_loop.stop()

    event_loop.add_signal_handler(signal.SIGUSR1, _shutdown)

    def _send_signal():
        time.sleep(0.1)
        os.kill(os.getpid(), tested_signal)

    signaller = threading.Thread(target=_send_signal, daemon=True)
    signaller.start()

    mayhem.main()

    assert tested_signal in event_loop._signal_handlers
    assert mayhem.handle_exception == event_loop.get_exception_handler()

    mock_consume.assert_called_once_with(mock_queue)
    mock_publish.assert_called_once_with(mock_queue)

    if tested_signal is signal.SIGUSR1:
        mock_asyncio_gather.assert_not_called()
        mock_shutdown.assert_called_once_with()
    else:
        mock_asyncio_gather.assert_called_once_with(return_exceptions=True)
        mock_shutdown.assert_not_called()

    # asserting the loop is stopped but not closed
    assert not event_loop.is_running()
    assert not event_loop.is_closed()
    event_loop.close.assert_called_once_with()
Look at those happy tests:
pytest -v part-5/test_mayhem_7.py
test_mayhem_7.py::test_main[SIGINT] PASSED [ 25%]
test_mayhem_7.py::test_main[SIGTERM] PASSED [ 50%]
test_mayhem_7.py::test_main[SIGHUP] PASSED [ 75%]
test_mayhem_7.py::test_main[SIGUSR1] PASSED [100%]
To see what near-100% test coverage looks like for mayhem.py
, check out part-5/test_mayhem_full.py
.
Third-Party Libraries #
- The aforementioned pytest-asyncio has other helpful things too, like the event_loop, unused_tcp_port, and unused_tcp_port_factory fixtures, and the ability to create your own asynchronous fixtures.
- asynctest has a lot of helpful tooling, including coroutine mocks and exhausting callbacks so we don’t have to manually await tasks made by create_task. It does require the use of unittest, with your tests defined as asynctest.TestCase. I’m not much of a fan of the unittest style, so perhaps someday someone will create asynctest for pytest :).
- aiohttp has some really nice built-in test utilities supporting both pytest and unittest.
Recap #
Basically, by using pytest-asyncio
, testing asyncio
code isn’t too much different than non-asynchronous code. There is still the clunkiness of needing to manually mock coroutine functions and exhausting the event loop when testing code that uses create_task
(an open source contribution opportunity, maybe??).