Hirusha Fernando
Posted on March 17, 2024
FastAPI is a modern, high-performance web framework for building APIs with Python 3.8 or higher. It is known for its speed compared to other Python frameworks. When you build APIs with any framework, handling authentication and authorization is crucial. In this article, we'll implement role-based access control (RBAC) using JWT in FastAPI.
Prerequisites
Python programming knowledge
Basic knowledge about FastAPI
Before you start, install these Python packages:
fastapi
pydantic
uvicorn[standard]
passlib[bcrypt]
python-jose[cryptography]
Setting Up The Environment
Let’s create two API endpoints in the main.py file.
from fastapi import FastAPI

app = FastAPI()

@app.get("/hello")
def hello_func():
    return "Hello World"

@app.get("/data")
def get_data():
    return {"data": "This is important data"}
Let’s create a User model and a Token model in models.py. The "str | None" annotations used here require Python 3.10 or later.
from pydantic import BaseModel

class User(BaseModel):
    username: str | None = None
    email: str | None = None
    role: str | None = None
    disabled: bool | None = None
    hashed_password: str | None = None

class Token(BaseModel):
    access_token: str | None = None
    refresh_token: str | None = None
For this tutorial, I will create a Python dictionary containing dummy users in data.py, plus a list to store refresh tokens. In a real application you can use any database for this, such as PostgreSQL or MongoDB.
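# Dummy users keyed by username, so get_user() can look them up directly.
# The hashed_password values are placeholders, not real bcrypt hashes;
# generate real ones with pwd_context.hash("your-password").
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "email": "john@emaik.com",
        "role": "admin",
        "hashed_password": "hdjsbdvdhxbzbsksjdbdbzjdhh45tbdbd7bdbd",
        "disabled": False,
    },
    "alice": {
        "username": "alice",
        "email": "al8ce@emaik.com",
        "role": "user",
        "hashed_password": "hdjsbdvdhxbzbsksjdbdbzjdhh45tbdbd7bdbd",
        "disabled": False,
    },
}

# Refresh tokens that are currently valid.
refresh_tokens = []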
How This Works
To set up authentication for our API, we'll follow these steps. First, users log in with their username and password through a POST request. Our backend then checks their credentials and, if they are correct, generates two tokens: an access token and a refresh token. The access token is short-lived, while the refresh token lasts longer. The backend sends both tokens back to the user. To access protected parts of the API, users include the access token in the request header. If the access token expires, users can request a new one by sending their refresh token to the backend.
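To make the flow above concrete, here is a small sketch of the decision a client has to make at any moment. The function name next_step and the two lifetime constants are made up for this illustration; they only mirror the idea of a short-lived access token and a longer-lived refresh token.

```python
from datetime import datetime, timedelta, timezone

# Lifetimes assumed for this illustration only.
ACCESS_LIFETIME = timedelta(minutes=20)
REFRESH_LIFETIME = timedelta(minutes=120)


def next_step(issued_at: datetime, now: datetime) -> str:
    """Decide what a client should do with tokens issued at `issued_at`."""
    if now < issued_at + ACCESS_LIFETIME:
        return "call the API with the access token"
    if now < issued_at + REFRESH_LIFETIME:
        return "send the refresh token to get a new access token"
    return "log in again with username and password"


issued = datetime(2024, 3, 17, 12, 0, tzinfo=timezone.utc)
for minutes in (5, 60, 180):
    print(minutes, "min:", next_step(issued, issued + timedelta(minutes=minutes)))
```

In the real API the server makes this decision by checking the "exp" claim inside each token, so the client simply reacts to a rejected request.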
Role Based Access Control (RBAC)
FastAPI provides several ways to deal with security. Here we use OAuth2 with the password flow, which the FastAPI security documentation describes in more detail. We do that using the OAuth2PasswordBearer class. We also use passlib's CryptContext to hash and verify passwords.
Let’s create auth.py. First, create instances of the above classes.
#auth.py
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
OAuth2PasswordBearer takes a tokenUrl parameter: the URL the client uses to send the username and password in order to get a token. We haven’t created this endpoint yet; we will add it later.
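When oauth2_scheme is used as a dependency, it reads the Authorization header of the request and hands the bearer token to your function. The sketch below shows roughly what that extraction amounts to. The helper extract_bearer_token is a hypothetical name written for this article, not part of FastAPI.

```python
from typing import Optional


def extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from an 'Authorization: Bearer <token>' header value.

    Returns None when the header is missing or uses another scheme.
    """
    if not authorization:
        return None
    scheme, _, token = authorization.partition(" ")
    if scheme.lower() != "bearer" or not token.strip():
        return None
    return token.strip()


print(extract_bearer_token("Bearer abc.def.ghi"))
print(extract_bearer_token("Basic dXNlcjpwYXNz"))
```

With the default settings, OAuth2PasswordBearer responds with a 401 error instead of returning None when no bearer token is present.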
Now create a function that gets the user details from the database and another that authenticates users by checking the password.
# auth.py
from models import User
from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def get_user(db, username: str):
    # Return the matching user as a User model, or None if not found.
    if username in db:
        user = db[username]
        return User(**user)
    return None

def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    # Compare the submitted password with the stored hash.
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user
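If hash-and-verify is new to you, the sketch below shows the idea using only the standard library. Note that it uses PBKDF2 rather than bcrypt, and the names hash_password and verify_password are made up for this example. In the article's code, passlib's CryptContext does the equivalent work (including salt handling) for you.

```python
import hashlib
import hmac
import os
from typing import Optional, Tuple


def hash_password(password: str, salt: Optional[bytes] = None) -> Tuple[bytes, bytes]:
    """Derive a hash from the password; a random salt is created if none is given."""
    salt = salt if salt is not None else os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 100_000)
    return salt, digest


def verify_password(password: str, salt: bytes, expected: bytes) -> bool:
    """Re-hash the candidate password with the same salt and compare in constant time."""
    _, digest = hash_password(password, salt)
    return hmac.compare_digest(digest, expected)


salt, stored = hash_password("secret")
print(verify_password("secret", salt, stored))
print(verify_password("wrong", salt, stored))
```

The important point is that the plain password is never stored: only the salt and the derived hash are kept, and login repeats the derivation.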
Now let’s handle the JWT. To do that, define a few constants and a function that creates a JWT.
# auth.py
from jose import JWTError, jwt
from datetime import datetime, timedelta, timezone
from data import refresh_tokens

SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 120

def create_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    # Use the given lifetime, or fall back to 15 minutes.
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
We pass our data and the token lifetime to this function, and it returns the encoded JWT. If no lifetime is given, the token expires after 15 minutes.
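To see what such a token actually contains, here is a standard-library sketch of HS256 signing and verification. It is an illustration of the JWT format, not how python-jose is implemented, and the helper names (b64url, encode_hs256, decode_hs256) are invented for this example.

```python
import base64
import hashlib
import hmac
import json
import time
from typing import Optional


def b64url(data: bytes) -> str:
    """Base64url-encode without padding, as the JWT format requires."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(text: str) -> bytes:
    """Restore the stripped padding, then decode."""
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def encode_hs256(payload: dict, secret: str) -> str:
    """Build header.payload.signature, signed with HMAC-SHA256."""
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        b64url(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, payload)
    )
    signature = hmac.new(
        secret.encode("utf-8"), signing_input.encode("ascii"), hashlib.sha256
    ).digest()
    return signing_input + "." + b64url(signature)


def decode_hs256(token: str, secret: str, now: Optional[float] = None) -> dict:
    """Check the signature and the 'exp' claim, then return the payload."""
    signing_input, _, signature = token.rpartition(".")
    expected = hmac.new(
        secret.encode("utf-8"), signing_input.encode("ascii"), hashlib.sha256
    ).digest()
    if not hmac.compare_digest(b64url_decode(signature), expected):
        raise ValueError("bad signature")
    payload = json.loads(b64url_decode(signing_input.split(".")[1]))
    current = time.time() if now is None else now
    if payload.get("exp", float("inf")) <= current:
        raise ValueError("token expired")
    return payload


demo_token = encode_hs256({"sub": "johndoe", "role": "admin", "exp": 2000}, "demo-secret")
print(demo_token.count("."))  # three parts, so two dots
print(decode_hs256(demo_token, "demo-secret", now=1000))
```

The payload is only encoded, not encrypted, so anyone holding the token can read the username and role. The signature is what stops them from changing those values.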
Once authentication is in place, we need a function that returns the currently logged-in user. It takes the token as input, decodes it to extract the username, and checks that the user exists in the database. If the user exists, it returns the user's details; otherwise it raises an exception. This ensures that only valid users can access protected data.
# auth.py
from typing import Annotated
from jose import JWTError, jwt
from fastapi import Depends, HTTPException, status
from data import fake_users_db

SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"

async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Decoding fails with JWTError if the token is invalid or expired.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)]
):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user
The snippet above also includes get_current_active_user(), which checks whether the user is disabled and raises an exception if so. The Depends() function declares a dependency: get_current_active_user() depends on get_current_user(), so when debugging you'll notice that get_current_user() runs before get_current_active_user().
Now, we'll introduce the "RoleChecker" class to validate user roles. If the user's role grants sufficient permissions, the method will return True. Otherwise, it will raise an exception. This class helps ensure that users only access functionalities appropriate for their assigned roles, maintaining security and access control within the system.
# auth.py
from typing import Annotated
from fastapi import Depends, HTTPException, status

class RoleChecker:
    def __init__(self, allowed_roles):
        self.allowed_roles = allowed_roles

    def __call__(self, user: Annotated[User, Depends(get_current_active_user)]):
        if user.role in self.allowed_roles:
            return True
        # The user is authenticated but lacks the role, so respond with 403.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have enough permissions")
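RoleChecker works as a dependency because Depends() accepts any callable, and an instance of a class that defines __call__ is callable. The plain-Python sketch below shows the same pattern without FastAPI. DemoUser and DemoRoleChecker are stand-ins invented for this example, and PermissionError takes the place of HTTPException.

```python
from dataclasses import dataclass


@dataclass
class DemoUser:
    """Hypothetical stand-in for the article's User model."""
    username: str
    role: str


class DemoRoleChecker:
    """Callable object configured once with the roles it accepts."""

    def __init__(self, allowed_roles):
        self.allowed_roles = allowed_roles

    def __call__(self, user: DemoUser) -> bool:
        if user.role in self.allowed_roles:
            return True
        raise PermissionError("You don't have enough permissions")


# Each instance carries its own configuration.
admin_only = DemoRoleChecker(["admin"])
staff = DemoRoleChecker(["admin", "user"])

print(admin_only(DemoUser("johndoe", "admin")))
print(staff(DemoUser("alice", "user")))
```

Because the allowed roles live on the instance, one class can protect many endpoints with different role lists.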
We need one more function to validate the refresh token. When the access token expires, the client sends its refresh token to get a new access token.
# auth.py
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import ValidationError
from data import fake_users_db, refresh_tokens

SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

async def validate_refresh_token(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    try:
        # Only refresh tokens that we issued and still track are accepted.
        if token in refresh_tokens:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            role: str = payload.get("role")
            if username is None or role is None:
                raise credentials_exception
        else:
            raise credentials_exception
    except (JWTError, ValidationError):
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user, token
The final auth.py file looks like this.
from datetime import datetime, timedelta, timezone
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import ValidationError

from data import fake_users_db, refresh_tokens
from models import User

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"

def get_user(db, username: str):
    # Return the matching user as a User model, or None if not found.
    if username in db:
        user = db[username]
        return User(**user)
    return None

def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    # Compare the submitted password with the stored hash.
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user

def create_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    # Use the given lifetime, or fall back to 15 minutes.
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Decoding fails with JWTError if the token is invalid or expired.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user

async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

async def validate_refresh_token(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    try:
        # Only refresh tokens that we issued and still track are accepted.
        if token in refresh_tokens:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            role: str = payload.get("role")
            if username is None or role is None:
                raise credentials_exception
        else:
            raise credentials_exception
    except (JWTError, ValidationError):
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user, token

class RoleChecker:
    def __init__(self, allowed_roles):
        self.allowed_roles = allowed_roles

    def __call__(self, user: Annotated[User, Depends(get_current_active_user)]):
        if user.role in self.allowed_roles:
            return True
        # The user is authenticated but lacks the role, so respond with 403.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You don't have enough permissions")
OK, we have created the authentication and authorization parts, and we can now add them to our API endpoints. Before doing that, we need two more endpoints: one for login and one for refreshing tokens. Let’s go back to main.py.
from datetime import timedelta
from typing import Annotated

from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordRequestForm

from auth import create_token, authenticate_user, RoleChecker, get_current_active_user, validate_refresh_token
from data import fake_users_db, refresh_tokens
from models import User, Token

app = FastAPI()

ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 120

@app.get("/hello")
def hello_func():
    return "Hello World"

@app.get("/data")
def get_data():
    return {"data": "This is important data"}

@app.post("/token")
async def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> Token:
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
    access_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=access_token_expires)
    refresh_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=refresh_token_expires)
    # Remember the refresh token so it can be validated later.
    refresh_tokens.append(refresh_token)
    return Token(access_token=access_token, refresh_token=refresh_token)

@app.post("/refresh")
async def refresh_access_token(token_data: Annotated[tuple[User, str], Depends(validate_refresh_token)]):
    user, token = token_data
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
    access_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=access_token_expires)
    refresh_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=refresh_token_expires)
    # Rotate: the old refresh token stops working once a new one is issued.
    refresh_tokens.remove(token)
    refresh_tokens.append(refresh_token)
    return Token(access_token=access_token, refresh_token=refresh_token)
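The /refresh endpoint above rotates refresh tokens: each one can be used once, after which it is replaced. The sketch below isolates that bookkeeping. It uses random opaque strings instead of JWTs to stay dependency-free, and the names active_refresh_tokens, issue_refresh_token and rotate_refresh_token are invented for this illustration.

```python
import secrets

# Hypothetical in-memory store, playing the role of the refresh_tokens list.
active_refresh_tokens = set()


def issue_refresh_token() -> str:
    """Create a new random token and mark it as valid."""
    token = secrets.token_urlsafe(32)
    active_refresh_tokens.add(token)
    return token


def rotate_refresh_token(old_token: str) -> str:
    """Invalidate the old token and hand out a replacement."""
    if old_token not in active_refresh_tokens:
        raise ValueError("unknown or already used refresh token")
    active_refresh_tokens.remove(old_token)
    return issue_refresh_token()


first = issue_refresh_token()
second = rotate_refresh_token(first)
print(first in active_refresh_tokens, second in active_refresh_tokens)
```

Rotation limits the damage of a leaked refresh token, because a stolen token becomes useless as soon as the legitimate client (or the attacker) uses it once.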
Add RBAC To API
Now let’s add RBAC to our endpoints. Right now the “/data” endpoint is not protected, so anyone can access it. You can check this using the Swagger docs or Postman. Let’s restrict it to admin users.
@app.get("/data")
def get_data(_: Annotated[bool, Depends(RoleChecker(allowed_roles=["admin"]))]):
    return {"data": "This is important data"}
After this change, the endpoint can only be accessed after logging in as an admin user. You can add the same dependency to any endpoint you want to protect. Now you know how to add RBAC to FastAPI. This is only one approach; there are other ways to do it, which you can find on the Internet. Happy coding!