Testing FastAPI with async database session

whchi

whchi

Posted on July 23, 2023

Testing FastAPI with async database session

Get started

FastAPI uses Python's asyncio module to improve its I/O performance.

According to the official documentation, path operation functions and dependencies (`Depends`) are always handled asynchronously. If you declare them with `async def`, they run as coroutines on the event loop. If you declare them with plain `def`, they run in a thread pool.

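When you declare a function with `async def`, you MUST `await` its I/O calls and use awaitable, non-blocking APIs. Otherwise a blocking call stalls the event loop, and requests end up being processed sequentially.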

This behavior is slightly different from JavaScript's async/await, which could be the subject of a separate discussion.

The code

Suppose your application looks like this:

# Create the db connection: one async engine shared by the whole application.
# echo=True makes the engine log every SQL statement it emits.
engine = create_async_engine(
    get_db_settings().async_connection_string,
    echo=True,
)

# Session factory: every call returns a new AsyncSession bound to `engine`.
#   * autoflush=False        -> pending changes are not flushed automatically
#                               before each query.
#   * expire_on_commit=False -> loaded ORM attributes stay readable after
#                               commit() instead of being expired.
async_session_global = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autocommit=False,
    autoflush=False,
    expire_on_commit=False,
)


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one AsyncSession per request (FastAPI dependency with ``yield``).

    The session never commits on its own. Endpoints that write must call
    ``await session.commit()`` explicitly. When the request finishes, leaving
    the ``async with`` block closes the session, which discards any work that
    was not committed.

    Yields:
        AsyncSession: a session created by ``async_session_global``.

    Raises:
        Exception: whatever the endpoint raised, re-raised after rollback.
    """
    # A plain session context (rather than ``.begin()`` plus a manual close()
    # in ``finally``) makes the commit policy explicit. Closing inside a
    # ``begin()`` block rolls back before that block could ever commit.
    async with async_session_global() as session:
        try:
            yield session
        except Exception:
            # Undo the failed request's work, then let FastAPI handle the
            # error. Catch only Exception so that cancellation and generator
            # shutdown are left to the context manager's own cleanup.
            await session.rollback()
            raise

# define the FastAPI application
from fastapi import HTTPException

app = FastAPI()

router = APIRouter()


@router.get('/api/async-examples/{id}')
async def get_example(id: int, db: AsyncSession = Depends(get_async_session)):
    """Return the Example row whose primary key is ``id``.

    Raises:
        HTTPException: 404 when no row has that id.
    """
    # ``await`` must apply to execute() itself. In ``await db.execute(...).all()``
    # the ``.all()`` would be called on the un-awaited coroutine.
    result = await db.execute(select(Example).filter_by(id=id))
    example = result.scalar_one_or_none()
    if example is None:
        raise HTTPException(status_code=404, detail='Example not found')
    return example


@router.put('/api/async-examples/{id}')
async def put_example(id: int, db: AsyncSession = Depends(get_async_session)):
    """Set name/age of Example ``id`` to fixed values and return the row.

    Raises:
        HTTPException: 404 when no row has that id.
    """
    # Update.where() takes SQL expressions, not keyword arguments.
    await db.execute(
        update(Example).where(Example.id == id).values(name='testtest', age=123)
    )
    # Commit explicitly instead of relying on the session dependency to do it.
    await db.commit()
    # Re-read the row. (refresh() needs an ORM instance, not the mapped class.)
    result = await db.execute(select(Example).filter_by(id=id))
    example = result.scalar_one_or_none()
    if example is None:
        raise HTTPException(status_code=404, detail='Example not found')
    return example


app.include_router(router)
Enter fullscreen mode Exit fullscreen mode

First, we need some fixtures for our tests. Here, I'll use asyncpg as the async database driver.

# conftest.py
import asyncio
from typing import AsyncGenerator

import pytest
import pytest_asyncio  # provides pytest.mark.asyncio and async fixtures
from httpx import ASGITransport, AsyncClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel

# Project-specific imports -- TODO: adjust the module paths to your layout.
from main import app
from models import Example

# Engine for the test database; the fixtures below refer to it as async_engine.
async_engine = create_async_engine(
    url='postgresql+asyncpg://...',
    echo=True,
)


# Create every table once per test session and drop them all when it ends.
# Async fixtures use pytest_asyncio.fixture so they also work in strict mode.
@pytest_asyncio.fixture(scope='session')
async def async_db_engine():
    async with async_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)

    yield async_engine

    async with async_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.drop_all)
    # Release pooled connections while the event loop is still running.
    await async_engine.dispose()


# Give each test its own session and truncate all tables afterwards, so
# data committed by one test cannot leak into the next.
@pytest_asyncio.fixture(scope='function')
async def async_db(async_db_engine) -> AsyncGenerator[AsyncSession, None]:
    async_session = sessionmaker(
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
        bind=async_db_engine,
        class_=AsyncSession,
    )

    async with async_session() as session:
        await session.begin()

        yield session

        # Discard anything the test left uncommitted...
        await session.rollback()

        # ...then wipe what the test (or its fixtures) did commit. Raw SQL
        # must be wrapped in text(); one statement covers every table.
        table_names = ', '.join(
            f'"{table.name}"' for table in SQLModel.metadata.sorted_tables
        )
        if table_names:
            await session.execute(text(f'TRUNCATE {table_names} CASCADE;'))
            await session.commit()


# In-process HTTP client for the real application: requests go straight to
# the ASGI app, with no network. The `async with` closes the client afterwards.
@pytest_asyncio.fixture(scope='session')
async def async_client() -> AsyncGenerator[AsyncClient, None]:
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url='http://localhost'
    ) as client:
        yield client


# One event loop for the whole test session, shared by the session-scoped
# async fixtures above. NOTE(review): overriding `event_loop` is the
# pytest-asyncio < 0.23 convention; newer versions use loop_scope instead.
@pytest.fixture(scope='session')
def event_loop():
    policy = asyncio.get_event_loop_policy()
    loop = policy.new_event_loop()
    yield loop
    loop.close()


# Insert and commit one Example row for tests that need existing data.
@pytest_asyncio.fixture
async def async_example_orm(async_db: AsyncSession) -> Example:
    example = Example(name='test', age=18, nick_name='my_nick')
    async_db.add(example)
    await async_db.commit()
    await async_db.refresh(example)
    return example


Enter fullscreen mode Exit fullscreen mode

Then, write the tests:

# test_what_ever_you_want.py
import pytest
from fastapi import status
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

# TODO: adjust the module path to wherever Example is defined.
from models import Example

# mark every test in this module with `asyncio`
pytestmark = pytest.mark.asyncio


async def test_get_example(async_client: AsyncClient, async_db: AsyncSession,
                           async_example_orm: Example) -> None:
    """GET on an existing Example succeeds and the row is still in the DB."""
    response = await async_client.get(f'/api/async-examples/{async_example_orm.id}')

    assert response.status_code == status.HTTP_200_OK
    result = await async_db.execute(
        select(Example).filter_by(id=async_example_orm.id)
    )
    assert result.scalar_one().id == async_example_orm.id


async def test_update_example(async_client: AsyncClient, async_db: AsyncSession,
                              async_example_orm: Example) -> None:
    """PUT succeeds and the stored name matches the name the API returned."""
    payload = {'name': 'updated_name', 'age': 20}

    response = await async_client.put(f'/api/async-examples/{async_example_orm.id}',
                                      json=payload)
    assert response.status_code == status.HTTP_200_OK

    # Reload the row: the endpoint committed through its own session.
    await async_db.refresh(async_example_orm)
    result = await async_db.execute(
        select(Example).filter_by(id=async_example_orm.id)
    )
    # The endpoint returns the Example object itself, so its fields are at the
    # top level of the JSON body (there is no 'data' wrapper).
    assert result.scalar_one().name == response.json()['name']
Enter fullscreen mode Exit fullscreen mode

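The key pieces here are the `async_db` and `event_loop` fixtures. Also make sure your application's database session dependency does not commit globally (for example, automatically at the end of every request). Commit explicitly in the endpoints that need it.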

💖 💪 🙅 🚩
whchi
whchi

Posted on July 23, 2023

Join Our Newsletter. No Spam, Only the good stuff.

Sign up to receive the latest update from our blog.

Related