How to efficiently use drf_social_oauth2 and django_rest_framework_simplejwt
codewitgabi
Posted on June 13, 2024
Hey guys!!
So I was working on a REST API application with Django and Django REST Framework, but I ran into a lot of issues using drf_social_oauth2 and django_rest_framework_simplejwt together. The issue was that the former strictly uses the `Bearer`
authentication header type, while the latter lets you use any
authorization header type you choose. So what I did initially was to use `JWT`
for simplejwt. This worked, but on the client side my team and I had to write extra logic to know whether the user had logged in via OAuth or via regular auth. It was okay, but I was not comfortable with it. I later found a solution, and today I will show you how I fixed this annoying issue.
# settings.py
# before: DRF tries each of these authentication classes in order, so
# simplejwt, django-oauth-toolkit and drf_social_oauth2 all compete for
# the same Authorization header.
REST_FRAMEWORK = dict(
    DEFAULT_AUTHENTICATION_CLASSES=(
        "rest_framework_simplejwt.authentication.JWTAuthentication",
        "oauth2_provider.contrib.rest_framework.OAuth2Authentication",
        "drf_social_oauth2.authentication.SocialAuthentication",
    ),
)
# settings.py
# after: simplejwt is the only authentication class left; social logins are
# converted into simplejwt tokens by the custom ConvertTokenView.
REST_FRAMEWORK = dict(
    DEFAULT_AUTHENTICATION_CLASSES=(
        "rest_framework_simplejwt.authentication.JWTAuthentication",
    ),
)
Comparing the two versions of settings.py, you can see that we removed the oauth2_provider and drf_social_oauth2 authentication classes, because we will no longer use them to authenticate requests. We fall back to using simplejwt alone, and that brings us to the next part.
# views.py
# builtin imports
from json import loads as json_loads
from datetime import datetime, timezone

# django imports
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from django.contrib.auth import get_user_model

# third party imports
from oauth2_provider.settings import oauth2_settings
from oauth2_provider.views.mixins import OAuthLibMixin
from oauthlib.oauth2.rfc6749.errors import (
    InvalidClientError,
    UnsupportedGrantTypeError,
    AccessDeniedError,
    MissingClientIdError,
    InvalidRequestError,
)
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
from rest_framework.status import HTTP_400_BAD_REQUEST
from rest_framework.request import Request
from drf_social_oauth2.serializers import ConvertTokenSerializer
from drf_social_oauth2.oauth2_backends import KeepRequestCore
from drf_social_oauth2.oauth2_endpoints import SocialTokenServer
# User model configured for the project (AUTH_USER_MODEL), resolved once at
# import time; ConvertTokenView queries it to find the owner of an issued
# OAuth2 access token.
User = get_user_model()
class CsrfExemptMixin:
    """
    Mixin that disables Django's CSRF check on the view's ``dispatch``.

    NOTE:
        List it first (left-most) among the view's bases so that its
        ``dispatch`` is the one resolved first in the MRO.
    """

    @method_decorator(csrf_exempt)
    def dispatch(self, *args, **kwargs):
        # Zero-argument super() resolves to the next class in the MRO after
        # CsrfExemptMixin, exactly like super(CsrfExemptMixin, self).
        return super().dispatch(*args, **kwargs)
class ConvertTokenView(CsrfExemptMixin, OAuthLibMixin, APIView):
    """
    Convert a social-provider token into a simplejwt refresh/access token pair.

    This re-implements drf_social_oauth2's ``ConvertTokenView``. The provider
    token is exchanged through oauthlib exactly as upstream does. The user who
    owns the newly issued OAuth2 access token then receives a simplejwt token
    pair, so social and regular logins both authenticate with the same JWT.

    The endpoint is used in the following flows:
    * Authorization code
    * Client credentials
    """

    server_class = SocialTokenServer
    validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
    oauthlib_backend_class = KeepRequestCore
    permission_classes = (AllowAny,)

    def post(self, request: Request, *args, **kwargs):
        """
        Exchange the posted provider token and return the user plus simplejwt tokens.

        Responses:
            * 400 with a single-key error object when oauthlib raises one of the
              handled errors below, or when no user owns the issued access token.
            * oauthlib's own status and JSON body when that body contains ``error``.
            * Otherwise, the user's profile fields plus ``refresh``, ``access`` and
              ``expiry`` (a timezone-aware UTC datetime), with oauthlib's status.

        Raises:
            rest_framework.exceptions.ValidationError: if the request data does
                not pass ``ConvertTokenSerializer`` validation.
        """
        serializer = ConvertTokenSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        # Use the rest framework `.data` to fake the post body of the django request.
        request._request.POST = request._request.POST.copy()
        for key, value in serializer.validated_data.items():
            request._request.POST[key] = value

        try:
            _url, _headers, body, status = self.create_token_response(request._request)
        except InvalidClientError:
            return Response(
                data={"invalid_client": "Missing client type."},
                status=HTTP_400_BAD_REQUEST,
            )
        except MissingClientIdError as ex:
            return Response(
                data={"invalid_request": ex.description},
                status=HTTP_400_BAD_REQUEST,
            )
        except InvalidRequestError as ex:
            return Response(
                data={"invalid_request": ex.description},
                status=HTTP_400_BAD_REQUEST,
            )
        except UnsupportedGrantTypeError:
            return Response(
                data={"unsupported_grant_type": "Missing grant type."},
                status=HTTP_400_BAD_REQUEST,
            )
        except AccessDeniedError:
            return Response(
                {"access_denied": "The token you provided is invalid or expired."},
                status=HTTP_400_BAD_REQUEST,
            )

        body = json_loads(body)
        if "error" in body:
            return Response(data=body, status=status)

        # Find the owner of the OAuth2 access token that was just issued.
        # Indexing an empty queryset with `[0]` raised IndexError (an HTTP 500);
        # answer with a 400 instead when there is no token or no owning user.
        token = body.get("access_token")
        user = (
            User.objects.filter(oauth2_provider_accesstoken__token=token).first()
            if token
            else None
        )
        if user is None:
            return Response(
                data={"invalid_grant": "No user is associated with the issued token."},
                status=HTTP_400_BAD_REQUEST,
            )

        # Mint the simplejwt pair. Read the access token once and take `exp`
        # from that same object instead of re-parsing its encoded string.
        refresh = RefreshToken.for_user(user)
        access = refresh.access_token
        access_token = str(access)
        # Timezone-aware UTC: a naive `fromtimestamp` result is server-local
        # time, which clients cannot interpret without knowing the server zone.
        expiration_time = datetime.fromtimestamp(access["exp"], tz=timezone.utc)

        return Response(
            {
                "id": user.id,
                "first_name": user.first_name,
                "last_name": user.last_name,
                "middle_name": user.middle_name,
                "fullname": f"{user.first_name} {user.last_name}",
                "email": user.email,
                # Guarded like profile_picture: `phone` may be empty (e.g. for
                # accounts created through social login), and calling
                # `.as_e164` on None would raise AttributeError.
                "phone": (user.phone.as_e164 if user.phone else None),
                "profile_picture": (user.profile_pic.url if user.profile_pic else None),
                "refresh": str(refresh),
                "access": access_token,
                "expiry": expiration_time,
            },
            status=status,
        )
It's a lot, but for now just copy (ctrl + c) and paste (ctrl + v) it. The code comes from the official drf_social_oauth2 codebase; I'm just overriding it.
token = body.get("access_token")
gets the access_token
from the body after the social auth access token has been converted.
user = User.objects.filter(oauth2_provider_accesstoken__token=token)[0]
gets the user that is associated with the token.
refresh = RefreshToken.for_user(user)
access_token = str(refresh.access_token)
decoded_token = AccessToken(access_token)
expiration_time = datetime.fromtimestamp(decoded_token["exp"])
This is where things change.
First, we create a simplejwt
refresh token for the user. Then we derive the access token
and its expiration_time.
This access token is what the user sends to authenticate against protected views.
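With this, everything is done. Every client, whether it signed in through social OAuth or regular login, can now authenticate by sending `Authorization: Bearer <access token>`.
If you previously set SIMPLE_JWT's AUTH_HEADER_TYPES to ("JWT",), change it back to ("Bearer",), which is simplejwt's default.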
Posted on June 13, 2024
Join Our Newsletter. No Spam, Only the good stuff.
Sign up to receive the latest update from our blog.