Make AWS Lambda handler testable

jaymecd

Nikolai Zujev

Posted on May 31, 2020

To unit test or not to unit test, that is the question many developers and engineers ask themselves when dealing with AWS Lambda functions.

For a one-time hack, the code is straightforward and testing is overhead. However, code meant to run in production daily and at scale, especially Lambda functions in a complex domain, must be testable and tested.

This is the first article in my AWS Lambda packaging series, in which I'd like to show how to make an AWS Lambda handler loosely coupled and more testable.

Note: I'm not a native English speaker, so I beg your pardon in advance.

Start simple, they say

Let's take a look at a simple Lambda function. As an abstract example, it adds a CostCenter tag to the resource ARN from an incoming request and returns all of the resource's tags.

Initial layout of the files used in this example:

$ tree python

python
├── src
│   ├── __init__.py
│   └── main.py
└── test
    ├── __init__.py
    └── test_main.py

The code is written in Python, but the approach is portable to other AWS Lambda runtimes, such as Node.js and Go, and can be adapted for other cloud providers as well.

# src/main.py
import boto3
import os

cost_center = os.environ["COST_CENTER"]
client = boto3.client('lambda')

def handler(event: dict, context: object) -> dict:
    assert cost_center, "expecting cost center ID"

    client.tag_resource(
        Resource=event["arn"],
        Tags={"CostCenter": cost_center},
    )

    response = client.list_tags(
        Resource=event["arn"],
    )

    return response["Tags"]

At first glance, there is nothing wrong with this code. It's concise and does what it was designed to do, except for one issue: testability.

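First, cost_center and client are global residents, created as soon as the module is imported. Second, the domain logic is tangled up with I/O. Third, the Lambda function will always fail during a cold start if the COST_CENTER environment variable is missing, and every invocation will fail if it is set but empty.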
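To see exactly when each of those checks fires, here is a small simulation. It is a sketch under assumptions: MODULE_SOURCE and cold_start are hypothetical names, boto3 is left out so it runs without AWS credentials, and executing the module body with exec stands in for Lambda importing the handler module during a cold start.

```python
import os
import types
from typing import Dict

# Hypothetical module body mirroring the original src/main.py, minus boto3.
MODULE_SOURCE = '''
import os

cost_center = os.environ["COST_CENTER"]  # runs once, when the module is imported

def handler(event, context):
    assert cost_center, "expecting cost center ID"  # runs on every invocation
    return {"CostCenter": cost_center}
'''


def cold_start(env: Dict[str, str]) -> types.ModuleType:
    """Simulate a cold start by executing the module body under `env`."""
    saved = dict(os.environ)
    os.environ.clear()
    os.environ.update(env)
    try:
        module = types.ModuleType("simulated_main")
        exec(MODULE_SOURCE, module.__dict__)
        return module
    finally:
        # Restore the real environment, whether or not the import failed.
        os.environ.clear()
        os.environ.update(saved)
```

With an empty environment, cold_start({}) raises KeyError: 'COST_CENTER' before any event is handled. With {"COST_CENTER": ""} the import succeeds, and it is each call to handler that raises AssertionError.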

How do I make it safer? Write a unit test. It might look like this:

# test/test_main.py
import sys
from unittest import mock
import pytest

@mock.patch("boto3.client", autospec=True)
@mock.patch.dict("os.environ", {"COST_CENTER": "a1b2c3"})
def test__handler__success(client_fn):
    try:
        # clean up modules cache
        del sys.modules["src.main"]
    except KeyError:
        pass

    client = client_fn.return_value

    import src.main

    client.list_tags.return_value = {"Tags": {"CostCenter": "a1b2c3", "Stage": "dev"}}
    event = {"arn": "arn:aws:lambda:eu-west-1:222222222:function:tag-me"}
    context = object()

    result = src.main.handler(event, context)

    client_fn.assert_called_once_with("lambda")

    assert client.list_tags.return_value["Tags"] == result

    client.tag_resource.assert_called_once_with(
        Resource=event["arn"], Tags={"CostCenter": "a1b2c3"},
    )

    client.list_tags.assert_called_once_with(Resource=event["arn"],)

@mock.patch("boto3.client", autospec=True)
@mock.patch.dict("os.environ", clear=True)
def test__handler__missing_envvar(client_fn):
    try:
        # clean up modules cache
        del sys.modules["src.main"]
    except KeyError:
        pass

    client = client_fn.return_value

    with pytest.raises(KeyError) as excinfo:
        import src.main

    assert "'COST_CENTER'" == str(excinfo.value)

@mock.patch("boto3.client", autospec=True)
@mock.patch.dict("os.environ", {"COST_CENTER": "a1b2c3"})
def test__handler__invalid_payload(client_fn):
    try:
        # clean up modules cache
        del sys.modules["src.main"]
    except KeyError:
        pass

    client = client_fn.return_value

    import src.main

    client.list_tags.return_value = {"Tags": {"CostCenter": "a1b2c3", "Stage": "dev"}}
    event = {"invalid": "payload"}
    context = object()

    with pytest.raises(KeyError) as excinfo:
        src.main.handler(event, context)

    assert "'arn'" == str(excinfo.value)

    client_fn.assert_called_once_with("lambda")

    client.tag_resource.assert_not_called()
    client.list_tags.assert_not_called()

Have you noticed the repetitive pattern in the tests' setup? This is accidental complexity: we are setting up things that are not relevant to the test itself.
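Part of that repetition can be folded into a helper. The sketch below is an assumption of mine, not something from pytest or the standard library, and fresh_import is a hypothetical name. It only hides the boilerplate: each test still has to patch boto3 and the environment before importing, because the module does its work at import time.

```python
import importlib
import sys
import types


def fresh_import(module_name: str) -> types.ModuleType:
    """Drop any cached copy of `module_name` and import it again.

    sys.modules.pop with a default never raises, which replaces the
    repeated try / del / except KeyError blocks.
    """
    sys.modules.pop(module_name, None)
    return importlib.import_module(module_name)
```

Inside the existing tests, main = fresh_import("src.main") would replace both the try/except block and the import line; in the missing-variable test the call would go inside pytest.raises(KeyError).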

To improve testability, the architecture has to improve as well. Once the domain grows and the function starts to interact with other AWS services, test complexity will grow drastically. And that's what you want to avoid.

Code is more like guidelines

There are dozens of code examples out there, and many people believe these examples show the only way to solve the problem, but they don't.

As a big fan of the "Pirates of the Caribbean" franchise, I can't help but recall the following quote:

The code is more what you'd call "guidelines" than actual rules.

-- Captain Hector Barbossa

This means we can push the boundaries and make this code more testable and modular by making it loosely coupled. And we should not be afraid to do that.

To achieve that, it's important to move the domain logic, and the wiring of its dependencies, out of the handler module into a separate file:

# src/services.py
from typing import Callable
import boto3

def bootstrap(cost_center: str) -> Callable[[dict, object], dict]:
    session = boto3.Session()

    return handler_factory(session, cost_center)

def handler_factory(session: boto3.Session, cost_center: str) -> Callable[[dict, object], dict]:
    assert cost_center, "expecting cost center ID"

    lambda_client = session.client('lambda')

    def handler(event: dict, context: object) -> dict:
        # validate the payload before any AWS call is made
        assert "arn" in event, "expecting 'arn' key"

        lambda_client.tag_resource(
            Resource=event["arn"],
            Tags={"CostCenter": cost_center},
        )

        response = lambda_client.list_tags(
            Resource=event["arn"],
        )

        return response["Tags"]

    return handler

We've introduced two major changes:

  • Wrapped the handler in a parameterised factory function that receives its dependencies and encapsulates how the handler is constructed.
  • Added bootstrap, an essential part of our Lambda function, which creates the real dependencies (a boto3.Session) and brings the puzzle pieces together.

As a result, the handler file is now lean:

# src/main.py
import os
from .services import bootstrap

handler = bootstrap(
    cost_center=os.getenv("COST_CENTER"),
)

Isn't it beautiful? Hell yeah! It's just pure initialisation, nothing else.
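Because main.py is nothing but initialisation, the split also lines up with Lambda's execution model: the module, and therefore bootstrap and handler_factory, runs once when an execution environment is initialised (the cold start), while the returned handler runs on every invocation served by that environment. The sketch below uses a simplified, boto3-free copy of handler_factory plus hypothetical in-memory stand-ins (InMemoryLambdaClient, CountingSession) to show that the client is created once and then reused.

```python
from typing import Any, Callable, Dict


# Simplified, boto3-free copy of handler_factory from src/services.py.
def handler_factory(session: Any, cost_center: str) -> Callable[[dict, object], dict]:
    assert cost_center, "expecting cost center ID"
    # Runs once per cold start: the client is created here and captured by the closure.
    lambda_client = session.client("lambda")

    def handler(event: dict, context: object) -> dict:
        # Runs on every invocation and reuses the captured client.
        assert "arn" in event, "expecting 'arn' key"
        lambda_client.tag_resource(Resource=event["arn"], Tags={"CostCenter": cost_center})
        return lambda_client.list_tags(Resource=event["arn"])["Tags"]

    return handler


class InMemoryLambdaClient:
    """Hypothetical stand-in for the boto3 Lambda client, storing tags in a dict."""

    def __init__(self) -> None:
        self.tags: Dict[str, Dict[str, str]] = {}

    def tag_resource(self, Resource: str, Tags: Dict[str, str]) -> None:
        self.tags.setdefault(Resource, {}).update(Tags)

    def list_tags(self, Resource: str) -> Dict[str, Dict[str, str]]:
        return {"Tags": dict(self.tags.get(Resource, {}))}


class CountingSession:
    """Hypothetical stand-in for boto3.Session that counts client() calls."""

    def __init__(self) -> None:
        self.clients_created = 0
        self.lambda_client = InMemoryLambdaClient()

    def client(self, service_name: str) -> InMemoryLambdaClient:
        self.clients_created += 1
        return self.lambda_client


session = CountingSession()
handler = handler_factory(session, "a1b2c3")  # what bootstrap does at cold start
arn = "arn:aws:lambda:eu-west-1:222222222:function:tag-me"
first = handler({"arn": arn}, None)   # first invocation
second = handler({"arn": arn}, None)  # warm invocation in the same environment
print(session.clients_created)  # 1
```

AWS's Lambda best-practices guidance recommends initialising SDK clients outside the handler so warm invocations can reuse them; with the factory, that reuse becomes explicit instead of a side effect of module globals.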

Since the source code has evolved, the tests need to evolve as well.

The handler test file has shrunk. Here we don't care what the handler does or what it returns, so there is no need to mock and assert dependencies outside of the test's scope.

# test/test_main.py
import sys
from unittest import mock

@mock.patch("src.services.bootstrap", autospec=True)
@mock.patch.dict("os.environ", {"COST_CENTER": "a1b2c3"})
def test__handler__success(bootstrap):
    try:
        # clean up modules cache
        del sys.modules["src.main"]
    except KeyError:
        pass

    bootstrap.return_value = lambda event, context: event

    from src.main import handler

    assert handler == bootstrap.return_value

    bootstrap.assert_called_once_with(cost_center="a1b2c3")

@mock.patch("src.services.bootstrap", autospec=True)
@mock.patch.dict("os.environ", clear=True)
def test__handler__missing_envvar(bootstrap):
    try:
        # clean up modules cache
        del sys.modules["src.main"]
    except KeyError:
        pass

    bootstrap.return_value = lambda event, context: event

    from src.main import handler

    assert handler == bootstrap.return_value

    bootstrap.assert_called_once_with(cost_center=None)

Et voilà, the Lambda function handler file is tested. Wait, where is the test__handler__invalid_payload test? It's gone: payload validation is no longer main.py's concern, so that test now lives next to handler_factory.

And now we can test bootstrap and handler_factory safely, each in isolation:

# test/test_services.py
from unittest import mock
import pytest

from src import services

@mock.patch("src.services.handler_factory", autospec=True)
@mock.patch("src.services.boto3.Session", autospec=True)
def test__bootstrap__success(Session, handler_factory):
    session = Session.return_value
    handler_factory.return_value = lambda event, context: event

    handler = services.bootstrap(cost_center="a1b2c3")

    assert handler == handler_factory.return_value

    handler_factory.assert_called_once_with(session, "a1b2c3")
    Session.assert_called_once()

@mock.patch("src.services.handler_factory", autospec=True)
@mock.patch("src.services.boto3.Session", autospec=True)
def test__bootstrap__error(Session, handler_factory):
    session = Session.return_value
    handler_factory.side_effect = RuntimeError("error happened")

    with pytest.raises(RuntimeError) as excinfo:
        services.bootstrap(cost_center="a1b2c3")

    assert "error happened" == str(excinfo.value)

    handler_factory.assert_called_once_with(session, "a1b2c3")
    Session.assert_called_once()

@mock.patch("src.services.boto3.Session", autospec=True)
def test__handler_factory__success(Session):
    session = Session.return_value

    lambda_client = session.client.return_value
    lambda_client.list_tags.return_value = {
        "Tags": {"Stage": "dev", "CostCenter": "a1b2c3"}
    }

    handler = services.handler_factory(session, "a1b2c3")

    assert callable(handler)

    event = {"arn": "arn:aws:lambda:eu-west-1:222222222:function:tag-me"}
    context = object()

    result = handler(event, context)

    assert lambda_client.list_tags.return_value["Tags"] == result

    session.client.assert_called_once_with("lambda")

    lambda_client.tag_resource.assert_called_once_with(
        Resource=event["arn"], Tags={"CostCenter": "a1b2c3"},
    )

    lambda_client.list_tags.assert_called_once_with(Resource=event["arn"])

@mock.patch("src.services.boto3.Session", autospec=True)
def test__handler_factory__empty_param(Session):
    session = Session.return_value

    with pytest.raises(AssertionError) as excinfo:
        services.handler_factory(session, "")

    assert "expecting cost center ID" == str(excinfo.value)

@mock.patch("src.services.boto3.Session", autospec=True)
def test__handler_factory__invalid_payload(Session):
    session = Session.return_value

    lambda_client = session.client.return_value

    handler = services.handler_factory(session, "a1b2c3")

    assert callable(handler)

    event = {"invalid": "payload"}
    context = object()

    with pytest.raises(AssertionError) as excinfo:
        handler(event, context)

    assert "expecting 'arn' key" == str(excinfo.value)

    lambda_client.tag_resource.assert_not_called()
    lambda_client.list_tags.assert_not_called()

Each test now has a clear focus. Goal achieved!
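One further simplification is possible. Because handler_factory receives the session as an argument, the @mock.patch("src.services.boto3.Session") decorators in its tests are only used to obtain a session instance; nothing actually needs patching. The sketch below, which uses a boto3-free copy of handler_factory so it runs on its own, passes a plain MagicMock instead. The test name is hypothetical.

```python
from typing import Any, Callable
from unittest import mock


# Boto3-free copy of handler_factory: the session is just "something with a client() method".
def handler_factory(session: Any, cost_center: str) -> Callable[[dict, object], dict]:
    assert cost_center, "expecting cost center ID"
    lambda_client = session.client("lambda")

    def handler(event: dict, context: object) -> dict:
        assert "arn" in event, "expecting 'arn' key"
        lambda_client.tag_resource(Resource=event["arn"], Tags={"CostCenter": cost_center})
        return lambda_client.list_tags(Resource=event["arn"])["Tags"]

    return handler


def test__handler_factory__success_without_patching():
    # A plain MagicMock stands in for boto3.Session; nothing is patched globally.
    session = mock.MagicMock()
    lambda_client = session.client.return_value
    lambda_client.list_tags.return_value = {"Tags": {"CostCenter": "a1b2c3"}}
    arn = "arn:aws:lambda:eu-west-1:222222222:function:tag-me"

    handler = handler_factory(session, "a1b2c3")
    result = handler({"arn": arn}, object())

    assert result == {"CostCenter": "a1b2c3"}
    session.client.assert_called_once_with("lambda")
    lambda_client.tag_resource.assert_called_once_with(
        Resource=arn, Tags={"CostCenter": "a1b2c3"},
    )
    lambda_client.list_tags.assert_called_once_with(Resource=arn)
```

The trade-off: a bare MagicMock accepts any attribute and any call signature, so a typo in a method name would go unnoticed. Where boto3 is installed, mock.create_autospec(boto3.Session, instance=True) should restore the signature checking that autospec=True provides today, still without global patching.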

The complete working code can be found in the jaymecd/package-aws-lambda repository on GitHub.

Stay tuned for the next article in the AWS Lambda packaging series, in which I'll share how to keep a Lambda function warm even if it fails during a cold start.


UPD: 2020-06-01

Based on the valuable feedback I received on the first day after publication:

  • improved the readability of the post
  • simplified the error-raising tests
  • linked the GitHub repository

UPD: 2020-06-04

  • added a Node.js implementation on GitHub

/creds: cover image by Joshua Sortino

