FastAPI Role-Based Access Control With JWT


Hirusha Fernando

Posted on March 17, 2024


FastAPI is a modern web framework for building APIs with Python 3.8 or higher, known for being fast and efficient compared to other Python frameworks. Whichever framework you use to build APIs, handling authentication and authorization correctly is crucial. In this article, we'll implement role-based access control (RBAC) using JWTs in FastAPI.

Prerequisites

  • Python programming knowledge

  • Basic knowledge about FastAPI

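Before you start, install these Python packages. The code uses the str | None type syntax, so you also need Python 3.10 or higher.

  • fastapi

  • pydantic

  • uvicorn[standard]

  • passlib[bcrypt]

  • python-jose[cryptography]

  • python-multipart (FastAPI needs it to parse the login form data)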

Setting Up The Environment
Let’s create two API endpoints in the main.py file.

from fastapi import FastAPI

app = FastAPI()

@app.get("/hello")
def hello_func():
  return "Hello World"

@app.get("/data")
def get_data():
  return {"data": "This is important data"} 

Next, create a User model and a Token model in models.py.

from pydantic import BaseModel

class User(BaseModel):
  username: str | None = None
  email: str | None = None
  role: str | None = None
  disabled: bool | None = None
  hashed_password: str | None = None

class Token(BaseModel):
  access_token: str | None = None
  refresh_token: str | None = None

For this tutorial, I will store dummy users in a Python dictionary, keyed by username, in data.py. I will also create a list to store issued refresh tokens. In a real application, you would use a database such as PostgreSQL or MongoDB instead.

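from passlib.context import CryptContext

# Only used to generate demo hashes; "secret" is a placeholder password for both users
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Keyed by username, because get_user() looks users up with db[username]
fake_users_db = {
  "johndoe": {
     "username": "johndoe",
     "email": "john@example.com",
     "role": "admin",
     "hashed_password": pwd_context.hash("secret"),
     "disabled": False
  },
  "alice": {
     "username": "alice",
     "email": "alice@example.com",
     "role": "user",
     "hashed_password": pwd_context.hash("secret"),
     "disabled": False
  }
}

refresh_tokens = []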

How This Works
To set up authentication for our API, we'll follow these steps. First, users log in by sending their username and password in a POST request. The backend checks the credentials and, if they are valid, generates two tokens: a short-lived access token and a longer-lived refresh token, and sends both back to the user. To access protected parts of the API, users include the access token in the Authorization header of each request. When the access token expires, users can get a new one by sending their refresh token to the backend, without logging in again. Keeping access tokens short-lived limits the damage if one is leaked.
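To make the timing concrete, here is a small, framework-free sketch of the decision a client makes, assuming the lifetimes this tutorial uses later (20 minutes for access tokens, 120 minutes for refresh tokens) and that both tokens are issued together. The function next_action and its return strings are hypothetical, for illustration only.

```python
from datetime import datetime, timedelta, timezone

# Lifetimes matching the constants used later in this tutorial (assumed here)
ACCESS_TOKEN_LIFETIME = timedelta(minutes=20)
REFRESH_TOKEN_LIFETIME = timedelta(minutes=120)


def next_action(issued_at: datetime, now: datetime) -> str:
    """Hypothetical helper: decide what a client should do with its tokens.

    Assumes the access and refresh tokens were issued at the same moment.
    """
    age = now - issued_at
    if age < ACCESS_TOKEN_LIFETIME:
        return "send access token"  # still valid: call the API directly
    if age < REFRESH_TOKEN_LIFETIME:
        return "refresh"            # access token expired: use the refresh token
    return "log in again"           # both expired: send credentials again


issued = datetime(2024, 3, 17, 12, 0, tzinfo=timezone.utc)
print(next_action(issued, issued + timedelta(minutes=5)))    # send access token
print(next_action(issued, issued + timedelta(minutes=30)))   # refresh
print(next_action(issued, issued + timedelta(minutes=150)))  # log in again
```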

Role-Based Access Control (RBAC)
FastAPI provides several tools for handling security. Here we use OAuth2 with the password flow (the FastAPI security documentation covers it in more detail), implemented with the OAuth2PasswordBearer class. We also use passlib's CryptContext to hash and verify passwords.

Let’s create auth.py. First, create instances of the above classes.

#auth.py
from fastapi.security import OAuth2PasswordBearer 
from passlib.context import CryptContext

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

The tokenUrl parameter is the URL the client uses to send the username and password in order to get a token. We haven't created this endpoint yet; we will create it later.

Now create a function that fetches a user's details from the database and another that authenticates a user by checking the password.

#auth.py
from passlib.context import CryptContext

from models import User

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def get_user(db, username: str):
    # Return the user as a User model, or None if the username is unknown
    if username in db:
        user = db[username]
        return User(**user)
    return None


def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    # Compare the submitted password with the stored bcrypt hash
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user
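pwd_context.verify() works by re-hashing the submitted password with the salt stored inside the hash and comparing the results. To show that idea without extra dependencies, here is a sketch that swaps bcrypt for PBKDF2-HMAC-SHA256 from Python's standard library. hash_password and verify_password are hypothetical helpers and the iteration count is an illustrative assumption; the tutorial code should keep using passlib with bcrypt.

```python
import hashlib
import hmac
import os

ITERATIONS = 100_000  # illustrative work factor; tune it for production


def hash_password(password: str) -> str:
    """Return 'salt$digest' (both hex-encoded) using a fresh random salt."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"


def verify_password(password: str, stored: str) -> bool:
    """Re-hash the candidate with the stored salt and compare in constant time."""
    salt_hex, digest_hex = stored.split("$")
    candidate = hashlib.pbkdf2_hmac(
        "sha256", password.encode(), bytes.fromhex(salt_hex), ITERATIONS
    )
    return hmac.compare_digest(candidate, bytes.fromhex(digest_hex))


stored = hash_password("secret")
print(verify_password("secret", stored))  # True
print(verify_password("wrong", stored))   # False
```

Because every hash uses a new random salt, hashing the same password twice gives different strings, which is why verification needs the stored value rather than a freshly computed hash.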

Now let's handle the JWTs. To do that, define a few settings and a function that creates a JWT.

#auth.py
from jose import JWTError, jwt
from datetime import datetime, timedelta, timezone
from data import refresh_tokens


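# Example key only: generate your own (e.g. with openssl rand -hex 32) and keep it out of source control
SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"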
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 120


def create_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

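We pass the data to encode and the token lifetime to this function. It adds an exp (expiration) claim, set 15 minutes ahead if no lifetime is given, and returns the signed JWT as a string.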
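To see what jwt.encode() and jwt.decode() do with HS256, here is a stdlib-only sketch: the JSON header and payload are base64url-encoded and joined by a dot, and an HMAC-SHA256 signature over both is appended. Decoding recomputes the signature and checks the exp claim. This is a simplified illustration (no audience, leeway, or algorithm negotiation) and not a replacement for python-jose; sign_hs256 and verify_hs256 are hypothetical names.

```python
import base64
import hashlib
import hmac
import json
from datetime import datetime, timedelta, timezone


def b64url(data: bytes) -> str:
    # JWTs use URL-safe base64 without '=' padding
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def b64url_decode(text: str) -> bytes:
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def sign_hs256(claims: dict, key: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = (
        b64url(json.dumps(header, separators=(",", ":")).encode())
        + "."
        + b64url(json.dumps(claims, separators=(",", ":")).encode())
    )
    signature = hmac.new(key.encode(), signing_input.encode(), hashlib.sha256).digest()
    return signing_input + "." + b64url(signature)


def verify_hs256(token: str, key: str) -> dict:
    """Check the signature and the exp claim, then return the claims."""
    header_b64, payload_b64, signature_b64 = token.split(".")
    signing_input = f"{header_b64}.{payload_b64}".encode()
    expected = hmac.new(key.encode(), signing_input, hashlib.sha256).digest()
    if not hmac.compare_digest(expected, b64url_decode(signature_b64)):
        raise ValueError("bad signature")
    claims = json.loads(b64url_decode(payload_b64))
    if "exp" in claims and claims["exp"] <= datetime.now(timezone.utc).timestamp():
        raise ValueError("token expired")
    return claims


# python-jose converts a datetime exp to a Unix timestamp; here we do it by hand
expire = datetime.now(timezone.utc) + timedelta(minutes=20)
token = sign_hs256({"sub": "johndoe", "role": "admin", "exp": int(expire.timestamp())}, "demo-key")
print(verify_hs256(token, "demo-key")["sub"])  # johndoe
```

Note that the payload is only encoded, not encrypted: anyone holding the token can read the sub and role claims, but cannot change them without invalidating the signature.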

Next, we'll write a function that returns the currently logged-in user. It takes the token as input, decodes it to extract the username (the sub claim), and looks the user up in the database. If the token is invalid or the user does not exist, it raises a 401 exception, so only requests carrying a valid token for an existing user get through.

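#auth.py
from typing import Annotated
from jose import JWTError, jwt
from fastapi import Depends, HTTPException, status

from data import fake_users_db
from models import User
# oauth2_scheme and get_user are defined in the earlier auth.py snippets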

SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"

async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)]
):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

The second function, get_current_active_user(), adds one more check: if the user is disabled, it raises an exception. Depends() declares a dependency: get_current_active_user() depends on get_current_user(), so for each request FastAPI runs get_current_user() first and passes the user it returns in as current_user. You can see this order if you step through the code in a debugger.
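Conceptually, this Depends() chain behaves like nested function calls: the innermost dependency runs first and its result is fed into the next. The plain-Python sketch below mimics that order for the chain token → get_current_user → get_current_active_user. It is only an analogy, since FastAPI's real resolver also handles async functions, request parsing, and per-request caching of dependency results. The token table, the disabled user "bob", and the exception classes are hypothetical stand-ins.

```python
from dataclasses import dataclass


class Unauthorized(Exception):
    """Stand-in for HTTPException(status_code=401)."""


class InactiveUser(Exception):
    """Stand-in for HTTPException(status_code=400, detail="Inactive user")."""


@dataclass
class User:
    username: str
    role: str
    disabled: bool = False


# Hypothetical lookups standing in for JWT decoding and the database
TOKENS = {"token-john": "johndoe", "token-bob": "bob"}
USERS = {
    "johndoe": User("johndoe", "admin"),
    "bob": User("bob", "user", disabled=True),
}

call_order = []


def get_current_user(token: str) -> User:
    call_order.append("get_current_user")
    username = TOKENS.get(token)
    if username is None or username not in USERS:
        raise Unauthorized("Could not validate credentials")
    return USERS[username]


def get_current_active_user(current_user: User) -> User:
    call_order.append("get_current_active_user")
    if current_user.disabled:
        raise InactiveUser("Inactive user")
    return current_user


# Roughly what Depends(get_current_active_user) amounts to for one request
user = get_current_active_user(get_current_user("token-john"))
print(user.username, call_order)
# johndoe ['get_current_user', 'get_current_active_user']
```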

Now we'll add a RoleChecker class to validate user roles. Each instance stores a list of allowed roles; when FastAPI calls the instance as a dependency, it returns True if the current user's role is in that list and raises an exception otherwise. This way, each endpoint can declare which roles may use it.

#auth.py
from typing import Annotated
from fastapi import Depends, HTTPException, status


class RoleChecker:
  def __init__(self, allowed_roles):
    self.allowed_roles = allowed_roles

  def __call__(self, user: Annotated[User, Depends(get_current_active_user)]):
    if user.role in self.allowed_roles:
      return True
    raise HTTPException(
      status_code=status.HTTP_401_UNAUTHORIZED,
      detail="You don't have enough permissions")
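The key idea in RoleChecker is that an instance of a class with a __call__ method is itself callable, so RoleChecker(allowed_roles=["admin"]) produces a dependency configured for one endpoint. The stdlib sketch below shows the same pattern outside FastAPI; the User dataclass and the Forbidden exception are assumptions standing in for the real model and HTTPException. Unlike the tutorial code, which uses 401, this sketch uses 403 Forbidden, the conventional status for a user who is authenticated but lacks permission.

```python
from dataclasses import dataclass


@dataclass
class User:
    username: str
    role: str


class Forbidden(Exception):
    """Stand-in for HTTPException(status_code=403)."""

    status_code = 403


class RoleChecker:
    def __init__(self, allowed_roles: list[str]):
        # Configured once, when the endpoint is declared
        self.allowed_roles = allowed_roles

    def __call__(self, user: User) -> bool:
        # Called on every request with the already authenticated user
        if user.role in self.allowed_roles:
            return True
        raise Forbidden("You don't have enough permissions")


admin_only = RoleChecker(allowed_roles=["admin"])
print(admin_only(User("johndoe", "admin")))  # True
try:
    admin_only(User("alice", "user"))
except Forbidden as exc:
    print(exc.status_code, exc)  # 403 You don't have enough permissions
```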

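We need one more function, to validate refresh tokens. When the access token expires, the client sends its refresh token to get a new access token. The function accepts the token only if it is in the refresh_tokens list and decodes correctly, and it returns the user together with the token.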

#auth.py
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from data import fake_users_db, refresh_tokens
# get_user is defined in the earlier auth.py snippet


SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


async def validate_refresh_token(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    try:
        # Only accept refresh tokens that were issued and not yet rotated out
        if token in refresh_tokens:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            role: str = payload.get("role")
            if username is None or role is None:
                raise credentials_exception
        else:
            raise credentials_exception

    except JWTError:
        raise credentials_exception

    user = get_user(fake_users_db, username=username)

    if user is None:
        raise credentials_exception

    return user, token
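Keeping a store of issued refresh tokens and replacing each one when it is used is often called refresh token rotation: once a refresh token has been exchanged, it is removed, so replaying it fails. The sketch below isolates that bookkeeping using a Python set instead of the tutorial's list. RefreshTokenStore and its methods are hypothetical names, and secrets.token_urlsafe() stands in for create_token().

```python
import secrets


class RefreshTokenStore:
    """Hypothetical in-memory store mirroring the refresh_tokens list."""

    def __init__(self) -> None:
        self._tokens: set[str] = set()

    def issue(self) -> str:
        token = secrets.token_urlsafe(32)  # stand-in for create_token(...)
        self._tokens.add(token)
        return token

    def rotate(self, old_token: str) -> str:
        """Exchange a valid refresh token for a new one; reject unknown or reused tokens."""
        if old_token not in self._tokens:
            raise PermissionError("Could not validate credentials")
        self._tokens.remove(old_token)
        return self.issue()


store = RefreshTokenStore()
first = store.issue()
second = store.rotate(first)  # OK: first is consumed
try:
    store.rotate(first)       # replaying the old token fails
except PermissionError as exc:
    print(exc)                # Could not validate credentials
```

An in-memory list or set is not shared between worker processes, which is another reason to keep refresh tokens in a database. Also, because access and refresh tokens in this tutorial carry the same claims and are signed with the same key, get_current_user() would accept a refresh token too; adding a claim such as "type" and checking it when decoding is one way to tell them apart.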

The final auth.py file looks like this.

from datetime import datetime, timedelta, timezone
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext

from data import fake_users_db, refresh_tokens
from models import User

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Example key only: generate your own (e.g. with openssl rand -hex 32) and keep it out of source control
SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"


def get_user(db, username: str):
    # Return the user as a User model, or None if the username is unknown
    if username in db:
        return User(**db[username])
    return None


def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    # Compare the submitted password with the stored bcrypt hash
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user


def create_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt


async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user


async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user


async def validate_refresh_token(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    try:
        # Only accept refresh tokens that were issued and not yet rotated out
        if token in refresh_tokens:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            role: str = payload.get("role")
            if username is None or role is None:
                raise credentials_exception
        else:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user, token


class RoleChecker:
    def __init__(self, allowed_roles):
        self.allowed_roles = allowed_roles

    def __call__(self, user: Annotated[User, Depends(get_current_active_user)]):
        if user.role in self.allowed_roles:
            return True
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="You don't have enough permissions")

We now have the authentication and authorization parts. Before adding them to our API endpoints, we need two more endpoints: one for logging in and one for refreshing tokens. Let's go back to main.py.

from datetime import timedelta
from typing import Annotated

from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordRequestForm

from auth import create_token, authenticate_user, RoleChecker, get_current_active_user, validate_refresh_token
from data import fake_users_db, refresh_tokens
from models import User, Token

app = FastAPI()

ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 120

@app.get("/hello")
def hello_func():
  return "Hello World"

@app.get("/data")
def get_data():
  return {"data": "This is important data"} 

@app.post("/token")
async def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> Token:
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(status_code=400, detail="Incorrect username or password")

    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)

    access_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=access_token_expires)
    refresh_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=refresh_token_expires)
    refresh_tokens.append(refresh_token)
    return Token(access_token=access_token, refresh_token=refresh_token)

@app.post("/refresh")
async def refresh_access_token(token_data: Annotated[tuple[User, str], Depends(validate_refresh_token)]) -> Token:
    user, token = token_data
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)

    access_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=access_token_expires)
    refresh_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=refresh_token_expires)

    # Invalidate the used refresh token and store the new one
    refresh_tokens.remove(token)
    refresh_tokens.append(refresh_token)
    return Token(access_token=access_token, refresh_token=refresh_token)
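OAuth2PasswordRequestForm reads the credentials from form data (application/x-www-form-urlencoded), not from a JSON body. The sketch below uses only the standard library to build such a login request for /token without sending it. The base URL http://localhost:8000 (uvicorn's default host and port) and the password "secret" are placeholder assumptions.

```python
from urllib.parse import urlencode
from urllib.request import Request

BASE_URL = "http://localhost:8000"  # assumed: a local uvicorn server on its default port

# OAuth2 password flow: the form fields are named "username" and "password"
body = urlencode({"username": "johndoe", "password": "secret"}).encode()
login_request = Request(
    f"{BASE_URL}/token",
    data=body,
    headers={"Content-Type": "application/x-www-form-urlencoded"},
    method="POST",
)

print(login_request.get_method())  # POST
print(body.decode())               # username=johndoe&password=secret
# urllib.request.urlopen(login_request) would send it to a running server;
# the JSON response contains "access_token" and "refresh_token" fields.
```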

Add RBAC To API
Now let's add RBAC to our endpoints. At the moment, the "/data" endpoint is not protected, so anyone can call it; you can check this with the Swagger UI at /docs or with Postman. To restrict it, add a RoleChecker dependency to the endpoint:

@app.get("/data")
def get_data(_: Annotated[bool, Depends(RoleChecker(allowed_roles=["admin"]))]):
  return {"data": "This is important data"}
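A client calls the protected endpoint by sending the access token in the Authorization header. Because /refresh also reads its token through oauth2_scheme, the refresh token goes in the same header rather than in the request body. Here is a stdlib sketch that builds both requests; BASE_URL and the auth_request helper are assumptions, and the token values are placeholders for the ones returned by /token.

```python
from urllib.request import Request

BASE_URL = "http://localhost:8000"  # assumed local uvicorn server


def auth_request(path: str, token: str, method: str = "GET") -> Request:
    """Hypothetical helper: build a request carrying 'Authorization: Bearer <token>'."""
    return Request(
        f"{BASE_URL}{path}",
        headers={"Authorization": f"Bearer {token}"},
        method=method,
    )


access_token = "<access token from /token>"    # placeholder
refresh_token = "<refresh token from /token>"  # placeholder

data_request = auth_request("/data", access_token)
refresh_request = auth_request("/refresh", refresh_token, method="POST")

print(data_request.get_method(), data_request.full_url)
# GET http://localhost:8000/data
print(refresh_request.get_method(), refresh_request.get_header("Authorization"))
# POST Bearer <refresh token from /token>
```

When such a request is sent with urllib.request.urlopen(), a rejected call raises urllib.error.HTTPError, whose code attribute holds the HTTP status.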

After this change, "/data" can only be accessed by a logged-in user with the admin role; a request made with alice's token (role "user") is rejected. You can add the same kind of dependency to any endpoint you want to protect, each with its own list of allowed roles. This is only one way to implement RBAC in FastAPI; other approaches exist. Happy coding!
