
Testing asyncio Code

Asyncio: We Did It Wrong - This article is part of a series.
Part 6: This Article

Example code can be found on GitHub. All code in this post is licensed under MIT.


Mayhem Mandrill Recap

The goal for this 7-part series is to build a mock Chaos Monkey-like service called “Mayhem Mandrill”. This is an event-driven service that consumes messages from a pub/sub and initiates a mock restart of a host. We could get thousands of messages in seconds, so as we get a message, we shouldn’t block the handling of the next message we receive.

For a simpler starting point, we’re going to test asyncio code that doesn’t have to deal with threading. Here’s the code we’re going to test:

import asyncio
import functools
import logging
import random
import signal
import string
import uuid

import attr


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
)


@attr.s
class PubSubMessage:
    instance_name = attr.ib()
    message_id    = attr.ib(repr=False)
    hostname      = attr.ib(repr=False, init=False)
    restarted     = attr.ib(repr=False, default=False)
    saved         = attr.ib(repr=False, default=False)
    acked         = attr.ib(repr=False, default=False)
    extended_cnt  = attr.ib(repr=False, default=0)

    def __attrs_post_init__(self):
        self.hostname = f"{self.instance_name}.example.net"


class RestartFailed(Exception):
    pass


async def publish(queue):
    choices = string.ascii_lowercase + string.digits

    while True:
        msg_id = str(uuid.uuid4())
        host_id = "".join(random.choices(choices, k=4))
        instance_name = f"cattle-{host_id}"
        msg = PubSubMessage(message_id=msg_id, instance_name=instance_name)
        asyncio.create_task(queue.put(msg))
        logging.debug(f"Published message {msg}")
        await asyncio.sleep(random.random())


async def restart_host(msg):
    await asyncio.sleep(random.random())
    if random.randrange(1, 5) == 3:
        raise RestartFailed(f"Could not restart {msg.hostname}")
    msg.restarted = True
    logging.info(f"Restarted {msg.hostname}")


async def save(msg):
    await asyncio.sleep(random.random())
    # if random.randrange(1, 5) == 3:
    #     raise Exception(f"Could not save {msg}")
    msg.saved = True
    logging.info(f"Saved {msg} into database")


async def cleanup(msg, event):
    await event.wait()
    await asyncio.sleep(random.random())
    msg.acked = True
    logging.info(f"Done. Acked {msg}")


async def extend(msg, event):
    while not event.is_set():
        msg.extended_cnt += 1
        logging.info(f"Extended deadline by 3 seconds for {msg}")
        await asyncio.sleep(2)


def handle_results(results, msg):
    for result in results:
        if isinstance(result, RestartFailed):
            logging.error(f"Retrying for failure to restart: {msg.hostname}")
        elif isinstance(result, Exception):
            logging.error(f"Handling general error: {result}")


async def handle_message(msg):
    event = asyncio.Event()

    asyncio.create_task(extend(msg, event))
    asyncio.create_task(cleanup(msg, event))

    results = await asyncio.gather(
        save(msg), restart_host(msg), return_exceptions=True
    )
    handle_results(results, msg)
    event.set()


async def consume(queue):
    while True:
        msg = await queue.get()
        logging.info(f"Pulled {msg}")
        asyncio.create_task(handle_message(msg))


def handle_exception(loop, context):
    msg = context.get("exception", context["message"])
    logging.error(f"Caught exception: {msg}")
    logging.info("Shutting down...")
    asyncio.create_task(shutdown(loop))


async def shutdown(loop, signal=None):
    if signal:
        logging.info(f"Received exit signal {signal.name}...")
    logging.info("Closing database connections")
    logging.info("Nacking outstanding messages")
    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]

    for task in tasks:
        task.cancel()

    logging.info("Cancelling outstanding tasks")
    await asyncio.gather(*tasks, return_exceptions=True)
    logging.info("Flushing metrics")
    loop.stop()


def main():
    loop = asyncio.get_event_loop()
    signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
    for s in signals:
        loop.add_signal_handler(
            s, lambda s=s: asyncio.create_task(shutdown(loop, signal=s))
        )
    loop.set_exception_handler(handle_exception)
    queue = asyncio.Queue()

    try:
        loop.create_task(publish(queue))
        loop.create_task(consume(queue))
        loop.run_forever()
    finally:
        loop.close()
        logging.info("Successfully shutdown the Mayhem service.")



if __name__ == "__main__":  # pragma: no cover
    main()

Simple Testing with pytest

We will be using pytest since I prefer writing simple assert statements for my tests.

We’ll start simple and test the “happy path” of the save function (i.e. not the code path that raises an exception):

async def save(msg):
    # unhelpful simulation of i/o work
    await asyncio.sleep(random.random()) 
    msg.saved = True
    logging.info(f"Saved {msg} into database")

Since save is a coroutine function, we’ll need to run it on an event loop:

import asyncio

import pytest

import mayhem


@pytest.fixture
def message():
    return mayhem.PubSubMessage(message_id="1234", instance_name="mayhem_test")


def test_save(message):
    assert not message.saved  # sanity check
    asyncio.run(mayhem.save(message))
    assert message.saved

Running this via pytest:

$ pytest -v test_mayhem_1.py
test_mayhem_1.py::test_save PASSED                                                           [100%]

Sweet! However, if you’re not on Python 3.7+ yet, you’ll have to create and close the loop yourself rather than use asyncio.run:

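def test_save(message):
    assert not message.saved  # sanity check
    # asyncio.run doesn't exist before 3.7, so manage a fresh loop by hand
    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(mayhem.save(message))
    finally:
        loop.close()  # close it even if save() raises
    assert message.saved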

This can get annoying, especially when you have many coroutine functions to test. Thankfully, there’s a pytest plugin called pytest-asyncio. It lets you define the tests themselves as coroutine functions, and it manages the event loop for you:

@pytest.mark.asyncio
async def test_save(message):  # <-- now a coroutine!
    assert not message.saved  # sanity check
    await mayhem.save(message)
    assert message.saved

Much cleaner! Using pytest-asyncio can get you pretty far.
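To get a feel for the chore the plugin takes off your hands, here’s a minimal sketch of a hypothetical run_async_test decorator that turns a coroutine test into a plain function running in a fresh event loop. It’s a simplified illustration, not how pytest-asyncio is actually implemented, and the Message class and save coroutine are stand-ins for mayhem’s.

```python
import asyncio
import functools


def run_async_test(coro_func):
    """Hypothetical decorator: run a coroutine test function in a fresh event loop.

    A simplified illustration only; pytest-asyncio's real machinery differs.
    """

    @functools.wraps(coro_func)
    def wrapper(*args, **kwargs):
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro_func(*args, **kwargs))
        finally:
            loop.close()  # always release the loop, even if the test fails

    return wrapper


class Message:
    """Stand-in for mayhem.PubSubMessage (only the field we need)."""

    def __init__(self):
        self.saved = False


async def save(msg):
    # Same shape as mayhem.save: simulated I/O, then mark the message saved.
    await asyncio.sleep(0.01)
    msg.saved = True


@run_async_test
async def check_save(msg):
    assert not msg.saved  # sanity check
    await save(msg)
    assert msg.saved


message = Message()
check_save(message)  # called like a plain function, no await needed
```

Every coroutine test would still need the decorator, and the sketch ignores fixtures entirely, which is where a plugin earns its keep.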

Mocking Coroutines

When writing unit tests, you’ll often need to mock out coroutine functions that are called within your tested function.

For instance, our save coroutine function calls another coroutine function, asyncio.sleep (a stand-in for a network I/O call to a database):

async def save(msg):
    # unhelpful simulation of i/o work
    await asyncio.sleep(random.random())  # <-- let's mock this out
    msg.saved = True
    logging.info(f"Saved {msg} into database")

You don’t actually want to wait for asyncio.sleep to complete while running your tests, nor do you want an actual call to a database to happen.

In Python 3.7, neither unittest.mock nor pytest-mock supports asynchronous mocks (unittest.mock.AsyncMock only arrived in Python 3.8), so we’ll have to work around this.

First, using the pytest-mock library, we’ll create a pytest fixture that returns a function:

@pytest.fixture
def create_mock_coro(mocker, monkeypatch):
    """Create a mock-coro pair.

    The coro can be used to patch an async method while the mock can
    be used to assert calls to the mocked out method.
    """

    def _create_mock_coro_pair(to_patch=None):
        mock = mocker.Mock()

        async def _coro(*args, **kwargs):
            return mock(*args, **kwargs)

        if to_patch:
            monkeypatch.setattr(to_patch, _coro)

        return mock, _coro

    return _create_mock_coro_pair

Then, we’ll create another pytest fixture that will use the create_mock_coro fixture to mock and patch asyncio.sleep:

@pytest.fixture
def mock_sleep(create_mock_coro):
    mock, _ = create_mock_coro(to_patch="mayhem.asyncio.sleep")
    return mock

Now let’s use the mock_sleep fixture in our test_save:

@pytest.mark.asyncio
async def test_save(mock_sleep, message):
    assert not message.saved  # sanity check
    await mayhem.save(message)
    assert message.saved
    assert 1 == mock_sleep.call_count

What we’ve done here is patch asyncio.sleep with a coroutine function that calls a mock object and returns its result. (Since mayhem.asyncio is the asyncio module itself, the patch applies to every caller of asyncio.sleep for the duration of the test; monkeypatch undoes it afterwards.) Then, we assert that the mock standing in for asyncio.sleep is called once when mayhem.save is awaited. Because we now have a mock object recording the calls, we can do anything that’s supported with unittest.mock.Mock objects, e.g. our_mocked_object.assert_called_once_with(...), our_mocked_object.return_value = "foo", etc.
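If you’d like to try the mock/coroutine pairing without pytest fixtures, here’s a minimal standard-library sketch. The Message class and save coroutine are stand-ins for mayhem’s, and mock.patch plays the role of monkeypatch. For comparison, the second check uses unittest.mock.AsyncMock, which assumes Python 3.8+ and makes the hand-rolled pair unnecessary there.

```python
import asyncio
import random
from unittest import mock


class Message:
    """Stand-in for mayhem.PubSubMessage (only the field we need)."""

    def __init__(self):
        self.saved = False


async def save(msg):
    # Same shape as mayhem.save: simulated I/O, then mark the message saved.
    await asyncio.sleep(random.random())
    msg.saved = True


def create_mock_coro_pair():
    """Return (mock, coro): awaiting coro(...) calls mock(...) and returns its result."""
    m = mock.Mock()

    async def _coro(*args, **kwargs):
        return m(*args, **kwargs)

    return m, _coro


async def check_save_with_pair():
    msg = Message()
    mock_sleep, coro_sleep = create_mock_coro_pair()
    # Patch the attribute on the shared asyncio module only while save() runs.
    with mock.patch("asyncio.sleep", new=coro_sleep):
        await save(msg)
    return msg.saved, mock_sleep


async def check_save_with_asyncmock():
    msg = Message()
    # Python 3.8+: calling an AsyncMock returns an awaitable, and it records
    # awaits (await_count, await_args) alongside the usual call tracking.
    with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as mock_sleep:
        await save(msg)
    return msg.saved, mock_sleep


saved_1, pair_mock = asyncio.run(check_save_with_pair())
saved_2, async_mock = asyncio.run(check_save_with_asyncmock())
```

In both cases the mock sees the random delay save passed in, so assertions like assert_called_once_with still work; the pair just does by hand what AsyncMock provides out of the box.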

Testing create_task

For testing coroutine functions that call asyncio.create_task, we can’t simply use the create_mock_coro fixture. For instance, let’s try to test our consume coroutine function:

async def consume(queue):
    while True:
        msg = await queue.get()
        logging.info(f"Pulled {msg}")
        asyncio.create_task(handle_message(msg))

I have the following fixtures for asyncio.Queue:

@pytest.fixture
def mock_queue(mocker, monkeypatch):
    queue = mocker.Mock()
    monkeypatch.setattr(mayhem.asyncio, "Queue", queue)
    return queue.return_value


@pytest.fixture
def mock_get(mock_queue, create_mock_coro):
    mock_get, coro_get = create_mock_coro()
    mock_queue.get = coro_get
    return mock_get

So let’s try to use create_mock_coro to mock out the handle_message coroutine function and assert that consume schedules it via create_task.

Note: we set mock_get.side_effect to one “real” value followed by an Exception, so that we don’t get stuck forever in consume’s while True loop.

@pytest.mark.asyncio
async def test_consume(mock_get, mock_queue, message, create_mock_coro):
    mock_get.side_effect = [message, Exception("break while loop")]
    mock_handle_message, _ = create_mock_coro("mayhem.handle_message")

    with pytest.raises(Exception, match="break while loop"):
        await mayhem.consume(mock_queue)

    mock_handle_message.assert_called_once_with(message)

When we run this, we see that mock_handle_message does not actually get called, contrary to what we expected:

$ pytest -v test_mayhem_4.py
test_mayhem_4.py::test_consume FAILED                                                        [100%]

=========================================== FAILURES ============================================
_________________________________________ test_consume __________________________________________

mock_get = <Mock id='4477488824'>, mock_queue = <Mock name='mock()' id='4477488880'>
message = Message(instance_name='cattle-1234')
create_mock_coro = <function create_mock_coro.<locals>._create_mock_patch_coro at 0x10add9840>

    @pytest.mark.asyncio
    async def test_consume(mock_get, mock_queue, message, create_mock_coro):
        mock_get.side_effect = [message, Exception("break while loop")]
        mock_handle_message = create_mock_coro("mayhem.handle_message")

        with pytest.raises(Exception, match="break while loop"):
            await mayhem.consume(mock_queue)

>       mock_handle_message.assert_called_once_with(message)
E       AssertionError: Expected 'mock' to be called once. Called 0 times.

test_mayhem_4.py:230: AssertionError
------------------------------------- Captured stderr call --------------------------------------
15:30:37,721 INFO: Pulled Message(instance_name='cattle-1234')
============================== 1 failed, 1 passed in 0.10 seconds ===============================

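This is because the tasks created by create_task have only been scheduled and are still pending at this point: the mocked get coroutine returns without ever suspending, so consume never yields control back to the event loop before the exception ends it. We need to nudge those tasks along. We do this by collecting all pending tasks (other than the test itself) and awaiting them explicitly: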

@pytest.mark.asyncio
async def test_consume(mock_get, mock_queue, message, create_mock_coro):
    mock_get.side_effect = [message, Exception("break while loop")]
    mock_handle_message, _ = create_mock_coro("mayhem.handle_message")

    with pytest.raises(Exception, match="break while loop"):
        await mayhem.consume(mock_queue)

    ret_tasks = [
        t for t in asyncio.all_tasks() if t is not asyncio.current_task()
    ]
    # should be 1 per side effect minus the Exception (i.e. messages consumed)
    assert 1 == len(ret_tasks)
    mock_handle_message.assert_not_called()  # <-- sanity check

    # explicitly await tasks scheduled by `asyncio.create_task`
    await asyncio.gather(*ret_tasks)

    mock_handle_message.assert_called_once_with(message)

Now pytest is happy:

$ pytest -v test_mayhem_5.py
test_mayhem_5.py::test_consume PASSED                                                       [100%]
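Here’s the same idea stripped down to the standard library, as a sketch under a couple of assumptions: this consume takes a get coroutine function and a handler instead of a queue and the module-level handle_message (so nothing needs patching), and drain_pending_tasks is a hypothetical helper that keeps awaiting other pending tasks until none are left, in case those tasks schedule more.

```python
import asyncio
from unittest import mock


async def consume(get, handler):
    """Simplified stand-in for mayhem.consume (get/handler are injected)."""
    while True:
        msg = await get()
        asyncio.create_task(handler(msg))


async def drain_pending_tasks():
    """Hypothetical helper: await every other pending task, repeating in case
    those tasks schedule more tasks. Never returns if a task loops forever."""
    current = asyncio.current_task()
    while True:
        pending = [t for t in asyncio.all_tasks() if t is not current]
        if not pending:
            return
        await asyncio.gather(*pending)


async def check_consume():
    handler_mock = mock.Mock()

    async def handler(msg):
        handler_mock(msg)

    # One "real" message, then an exception to escape the while True loop.
    get_mock = mock.Mock(side_effect=["msg-1", RuntimeError("break while loop")])

    async def get():
        return get_mock()  # returns (or raises) without ever suspending

    try:
        await consume(get, handler)
    except RuntimeError:
        pass

    # consume never yielded to the loop, so the handler task hasn't run yet.
    calls_before = handler_mock.call_count
    await drain_pending_tasks()
    return calls_before, handler_mock


calls_before, handler_mock = asyncio.run(check_consume())
```

Because drain_pending_tasks never returns if a task loops forever (like extend, which keeps looping until its event is set), it only suits code paths where every scheduled task eventually finishes. A single await asyncio.sleep(0) would also give the scheduled task a chance to start, but it only advances each ready task to its next await.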

Non-async Testing of the Event Loop

I hear you: you want 100% test coverage, which may seem difficult to reach for our main function. We’ll make use of pytest-asyncio’s event_loop fixture, with a slight modification.

First, we’ll create our own fixture that overrides pytest-asyncio’s event_loop fixture (it reuses the name and requests the original):

@pytest.fixture
def event_loop(event_loop, mocker):
    new_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(new_loop)
    new_loop._close = new_loop.close
    new_loop.close = mocker.Mock()

    yield new_loop

    new_loop._close()

We’re essentially creating a new event loop and setting it as the current one, so that the asyncio.get_event_loop() call in main returns our test loop. We also want to override its close() method, which main calls in its finally block: if we close the loop during the test, we’ll lose access to the signal handlers that we set up within main. Replacing close() with a mock object still lets us assert that it has been called, while the fixture calls the real close() (saved as _close) on teardown.
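The close-swapping trick works because the assignment creates an instance attribute, which shadows the class’s close method for anyone holding that loop. Here’s a minimal standalone sketch of it outside pytest; real_close plays the role of the fixture’s _close, and the single loop.close() call stands in for what main does.

```python
import asyncio
from unittest import mock

loop = asyncio.new_event_loop()

# Keep a handle on the real close() for teardown...
real_close = loop.close
# ...and shadow it with a mock on this instance, so code under test that
# calls loop.close() (like main()'s finally block) leaves the loop open.
loop.close = mock.Mock()

loop.close()  # what main() would do at the end
closed_during_test = loop.is_closed()
close_calls = loop.close.call_count

# The loop is still usable afterwards, e.g. to run more work.
result = loop.run_until_complete(asyncio.sleep(0, result="still usable"))

real_close()  # teardown, like the fixture's new_loop._close()
```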

So now, we’ll write a test_main function that actually borders on an integration or functional test. We want to make sure, in addition to the expected calls to publish and consume, that shutdown gets called when expected.

We can’t simply mock out shutdown with create_mock_coro, since that would replace it with just another coroutine, one that records the call but never calls loop.stop(), so loop.run_forever() in main would never return. Instead, we’ll mock out the coroutine that shutdown calls (asyncio.gather) and let the real shutdown stop the loop.

And finally, to check whether the loop actually responds to signals, we need to send it one. Since main blocks inside loop.run_forever(), we do this from a separate thread, which sends a signal to the process itself.

# <-- snip -->
import os
import signal
import time
import threading

# <-- snip -->
def test_main(create_mock_coro, event_loop, mock_queue):
    mock_consume, _ = create_mock_coro("mayhem.consume")
    mock_publish, _ = create_mock_coro("mayhem.publish")
    # mock out `asyncio.gather` that `shutdown` calls instead
    # of `shutdown` itself
    mock_asyncio_gather, _ = create_mock_coro("mayhem.asyncio.gather")

    def _send_signal():
        # allow the loop to start and work a little bit...
        time.sleep(0.1)
        # ...then send a signal
        os.kill(os.getpid(), signal.SIGTERM)

    thread = threading.Thread(target=_send_signal, daemon=True)
    thread.start()

    mayhem.main()

    assert signal.SIGTERM in event_loop._signal_handlers
    assert mayhem.handle_exception == event_loop.get_exception_handler()

    mock_asyncio_gather.assert_called_once_with(return_exceptions=True)
    mock_consume.assert_called_once_with(mock_queue)
    mock_publish.assert_called_once_with(mock_queue)

    # asserting the loop is stopped but not closed
    assert not event_loop.is_running()
    assert not event_loop.is_closed()
    event_loop.close.assert_called_once_with()
$ pytest -v test_mayhem_6.py
test_mayhem_6.py::test_main PASSED                                                          [100%]
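Stripped of the mayhem-specific mocks, the mechanics test_main relies on look roughly like this: a signal sent from a helper thread triggers a shutdown coroutine, asyncio.gather is swapped for a mock-backed coroutine, and the real loop.stop() call is what lets run_forever() return. This is a Unix-only sketch (it uses add_signal_handler and SIGUSR1), and shutdown and run_until_signal are simplified, hypothetical stand-ins for mayhem’s shutdown and main.

```python
import asyncio
import os
import signal
import threading
import time
from unittest import mock


async def shutdown(loop):
    """Miniature of mayhem.shutdown: cancel other tasks, await them, stop the loop."""
    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)
    loop.stop()


def run_until_signal(sig):
    """Run a loop until `sig` arrives; return the mock that stood in for gather."""
    loop = asyncio.new_event_loop()
    loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown(loop)))

    gather_mock = mock.Mock()

    async def fake_gather(*args, **kwargs):
        return gather_mock(*args, **kwargs)

    def _send_signal():
        time.sleep(0.1)  # give run_forever() a moment to start
        os.kill(os.getpid(), sig)

    thread = threading.Thread(target=_send_signal, daemon=True)
    thread.start()
    try:
        # shutdown() looks up asyncio.gather at call time, so it gets the fake.
        with mock.patch("asyncio.gather", new=fake_gather):
            loop.run_forever()  # returns once shutdown() calls loop.stop()
    finally:
        thread.join()
        loop.remove_signal_handler(sig)
        loop.close()
    return gather_mock


gather_mock = run_until_signal(signal.SIGUSR1)
```

With no other tasks running when the signal arrives, the fake gather is called with only return_exceptions=True, which mirrors the assert_called_once_with(return_exceptions=True) check in test_main.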

We can further parametrize this, and also test the behavior when the loop receives a signal other than SIGINT, SIGTERM, and SIGHUP. This requires us to register a handler for another signal, since we want to make sure that other signals do not invoke the defined shutdown behavior (and without a handler, a signal like SIGUSR1 would terminate the test process by default). We’ll use SIGUSR1 and register a separate mock shutdown handler on the test event loop:

@pytest.mark.parametrize(
    "tested_signal", ("SIGINT", "SIGTERM", "SIGHUP", "SIGUSR1")
)
def test_main(tested_signal, create_mock_coro, event_loop, mock_queue, mocker):
    tested_signal = getattr(signal, tested_signal)
    mock_asyncio_gather, _ = create_mock_coro("mayhem.asyncio.gather")
    mock_consume, _ = create_mock_coro("mayhem.consume")
    mock_publish, _ = create_mock_coro("mayhem.publish")

    mock_shutdown = mocker.Mock()
    def _shutdown():
        mock_shutdown()
        event_loop.stop()

    event_loop.add_signal_handler(signal.SIGUSR1, _shutdown)

    def _send_signal():
        time.sleep(0.1)
        os.kill(os.getpid(), tested_signal)

    thread = threading.Thread(target=_send_signal, daemon=True)
    thread.start()

    mayhem.main()

    assert tested_signal in event_loop._signal_handlers
    assert mayhem.handle_exception == event_loop.get_exception_handler()

    mock_consume.assert_called_once_with(mock_queue)
    mock_publish.assert_called_once_with(mock_queue)

    if tested_signal is not signal.SIGUSR1:
        mock_asyncio_gather.assert_called_once_with(return_exceptions=True)
        mock_shutdown.assert_not_called()
    else:
        mock_asyncio_gather.assert_not_called()
        mock_shutdown.assert_called_once_with()

    # asserting the loop is stopped but not closed
    assert not event_loop.is_running()
    assert not event_loop.is_closed()
    event_loop.close.assert_called_once_with()

Look at those happy tests:

$ pytest -v part-5/test_mayhem_7.py

test_mayhem_7.py::test_main[SIGINT] PASSED                                           [ 25%]
test_mayhem_7.py::test_main[SIGTERM] PASSED                                          [ 50%]
test_mayhem_7.py::test_main[SIGHUP] PASSED                                           [ 75%]
test_mayhem_7.py::test_main[SIGUSR1] PASSED                                          [100%]

To see what near-100% test coverage looks like for mayhem.py, check out part-5/test_mayhem_full.py.

Third-Party Libraries

  • The aforementioned pytest-asyncio has other helpful features too, like the event_loop, unused_tcp_port, and unused_tcp_port_factory fixtures, and the ability to create your own asynchronous fixtures.
  • asynctest has a lot of helpful tooling, including coroutine mocks and exhausting callbacks so we don’t have to manually await tasks made by create_task. It does require using unittest and defining your tests as asynctest.TestCase subclasses. I’m not much of a fan of the unittest style, so perhaps someday someone will create asynctest for pytest :).
  • aiohttp has some really nice built-in test utilities supporting both pytest and unittest.

Recap

Basically, by using pytest-asyncio, testing asyncio code isn’t much different from testing synchronous code. There’s still some clunkiness in having to mock coroutine functions by hand and exhaust the event loop when testing code that uses create_task (an open source contribution opportunity, maybe??).
