How to efficiently use drf_social_oauth2 and django_rest_framework_simplejwt

codewitgabi

Posted on June 13, 2024

Hey guys!!

So I was working on a REST API with Django and Django REST Framework, and I ran into a lot of issues using drf_social_oauth2 together with django_rest_framework_simplejwt. The issue was that the former strictly uses the Bearer authorization header, while the latter lets you pick the header type of your choice (Bearer by default). With both listening for Bearer, simplejwt's JWTAuthentication would try to decode the OAuth access token as a JWT and reject the request before the OAuth authentication classes got a chance, so what I did initially was to set simplejwt's header type to JWT. This worked fine, even though on the client side my team and I had to write some logic to know whether the user had logged in via OAuth or regular auth, so we could send the right header. This was okay, but I was not too comfortable with it. I later thought of a solution, and today I will be showing you what I did to fix this boring issue.
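To see why the header prefix matters, here is a simplified model of how an authentication class like simplejwt's decides whether an Authorization header is "its" header. This is a hypothetical sketch (extract_token is my own name, not simplejwt's API): an unknown prefix returns None so DRF moves on to the next class, while a matching prefix commits the class to handling the token.

```python
from typing import Optional, Sequence


def extract_token(header: str, allowed_types: Sequence[str]) -> Optional[str]:
    """Simplified model of a header-type check (not simplejwt's real code).

    Returns None when the prefix is not one we handle, so DRF would move on
    to the next authentication class. Once the prefix matches, the class
    "owns" the request: a malformed header (or, in simplejwt, a token that
    fails JWT validation) is an error, not a pass.
    """
    parts = header.split()
    if not parts or parts[0] not in allowed_types:
        return None
    if len(parts) != 2:
        raise ValueError("Authorization header must contain two space-delimited values")
    return parts[1]


oauth_header = "Bearer some-opaque-oauth-token"

# Old setup: simplejwt configured with ("JWT",) ignores OAuth's Bearer header
print(extract_token(oauth_header, ("JWT",)))     # None
# Default setup: ("Bearer",) claims the header and would try to decode it as a JWT
print(extract_token(oauth_header, ("Bearer",)))  # some-opaque-oauth-token
```

That second case is the conflict: the opaque OAuth token is picked up by the JWT class and fails validation there.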

# settings.py
# before

REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": (
        "rest_framework_simplejwt.authentication.JWTAuthentication",
        "oauth2_provider.contrib.rest_framework.OAuth2Authentication",
        "drf_social_oauth2.authentication.SocialAuthentication",
    ),
}
# settings.py
# after

REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": (
        "rest_framework_simplejwt.authentication.JWTAuthentication",
    ),
}

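So from the two snippets of our settings.py, you can see we removed the oauth2_provider and drf_social_oauth2 authentication classes from DEFAULT_AUTHENTICATION_CLASSES, because we will no longer use them to authenticate requests. We fall back to simplejwt alone, with its default Bearer header type. The apps themselves stay in INSTALLED_APPS, and the social backends in AUTHENTICATION_BACKENDS, because the token conversion below still depends on them. That brings us to the next part.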
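If, like me, you had changed simplejwt's header type to JWT, make sure it is back to Bearer. A minimal settings sketch, assuming the rest of your SIMPLE_JWT config stays as it is; the two lifetimes are simplejwt's defaults, shown only for context:

```python
# settings.py
from datetime import timedelta

SIMPLE_JWT = {
    # The default is already ("Bearer",); this only matters if you changed it earlier
    "AUTH_HEADER_TYPES": ("Bearer",),
    # simplejwt's default lifetimes, shown for context; tune them for your project
    "ACCESS_TOKEN_LIFETIME": timedelta(minutes=5),
    "REFRESH_TOKEN_LIFETIME": timedelta(days=1),
}
```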

# views.py

# standard library imports
from json import loads as json_loads
from datetime import datetime

# third party imports
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from django.contrib.auth import get_user_model

from oauth2_provider.settings import oauth2_settings
from oauth2_provider.views.mixins import OAuthLibMixin
from oauthlib.oauth2.rfc6749.errors import (
    InvalidClientError,
    UnsupportedGrantTypeError,
    AccessDeniedError,
    MissingClientIdError,
    InvalidRequestError,
)

from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from rest_framework.status import HTTP_400_BAD_REQUEST
from rest_framework.request import Request
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken

from drf_social_oauth2.serializers import ConvertTokenSerializer
from drf_social_oauth2.oauth2_backends import KeepRequestCore
from drf_social_oauth2.oauth2_endpoints import SocialTokenServer


# user object
User = get_user_model()


class CsrfExemptMixin:
    """
    Exempts the view from CSRF requirements.
    NOTE:
        This should be the left-most mixin of a view.
    """

    @method_decorator(csrf_exempt)
    def dispatch(self, *args, **kwargs):
        return super().dispatch(*args, **kwargs)


class ConvertTokenView(CsrfExemptMixin, OAuthLibMixin, APIView):
    """
    Implements an endpoint to convert a provider token to an access token

    The endpoint is used in the following flows:

    * Authorization code
    * Client credentials
    """

    server_class = SocialTokenServer
    validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
    oauthlib_backend_class = KeepRequestCore
    permission_classes = (AllowAny,)

    def post(self, request: Request, *args, **kwargs):
        serializer = ConvertTokenSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        # Use the rest framework `.data` to fake the post body of the django request.
        request._request.POST = request._request.POST.copy()
        for key, value in serializer.validated_data.items():
            request._request.POST[key] = value

        try:
            url, headers, body, status = self.create_token_response(request._request)
        except InvalidClientError:
            return Response(
                data={"invalid_client": "Missing client type."},
                status=HTTP_400_BAD_REQUEST,
            )
        except MissingClientIdError as ex:
            return Response(
                data={"invalid_request": ex.description},
                status=HTTP_400_BAD_REQUEST,
            )
        except InvalidRequestError as ex:
            return Response(
                data={"invalid_request": ex.description},
                status=HTTP_400_BAD_REQUEST,
            )
        except UnsupportedGrantTypeError:
            return Response(
                data={"unsupported_grant_type": "Missing grant type."},
                status=HTTP_400_BAD_REQUEST,
            )
        except AccessDeniedError:
            return Response(
                {"access_denied": "The token you provided is invalid or expired."},
                status=HTTP_400_BAD_REQUEST,
            )

        body = json_loads(body)

        if "error" in body:
            return Response(data=body, status=status)

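        # The OAuth access token django-oauth-toolkit just issued for this social login
        token = body.get("access_token")
        # Reverse lookup through the AccessToken.user foreign key to find its owner
        user = User.objects.filter(oauth2_provider_accesstoken__token=token)[0]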

        refresh = RefreshToken.for_user(user)
        access_token = str(refresh.access_token)
        decoded_token = AccessToken(access_token)
        expiration_time = datetime.fromtimestamp(decoded_token["exp"])

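        # middle_name, phone and profile_pic are fields on my custom user model,
        # not Django's default User; adjust this payload to match your own model.
        return Response(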
            {
                "id": user.id,
                "first_name": user.first_name,
                "last_name": user.last_name,
                "middle_name": user.middle_name,
                "fullname": f"{user.first_name} {user.last_name}",
                "email": user.email,
                "phone": user.phone.as_e164,
                "profile_picture": (user.profile_pic.url if user.profile_pic else None),
                "refresh": str(refresh),
                "access": access_token,
                "expiry": expiration_time,
            },
            status=status,
        )

It's a lot, but for now just copy and paste it. The code comes from drf_social_oauth2's own ConvertTokenView; I'm just overriding it. The interesting part is at the end, after the error handling.
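The override only takes effect if the convert-token URL resolves to this view instead of the library's. A sketch of the project urls.py, assuming drf_social_oauth2 is mounted under auth/ as in its setup instructions and that the view lives in a hypothetical accounts app. Django uses the first pattern that matches, so the override is listed before the include:

```python
# urls.py (project level); "accounts" is a hypothetical app name
from django.urls import include, path

from accounts.views import ConvertTokenView

urlpatterns = [
    # Must come before the include below so this view, not the library's,
    # answers /auth/convert-token
    path("auth/convert-token", ConvertTokenView.as_view(), name="convert_token"),
    path("auth/", include("drf_social_oauth2.urls", namespace="drf")),
]
```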

token = body.get("access_token") gets the access token that was issued once the social provider's token had been converted.
user = User.objects.filter(oauth2_provider_accesstoken__token=token)[0] gets the user associated with that token, through the reverse relation from django-oauth-toolkit's AccessToken model to the user.

refresh = RefreshToken.for_user(user)
access_token = str(refresh.access_token)
decoded_token = AccessToken(access_token)
expiration_time = datetime.fromtimestamp(decoded_token["exp"])

Here is where we change things.
First, we create a simplejwt refresh token for the user. Then we take its access token and read the token's expiration time from its exp claim. This simplejwt access token, not the OAuth one, is what the client will send to authenticate against protected views.
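One detail about the expiry: a JWT's exp claim is a Unix timestamp (seconds since the epoch, in UTC), and datetime.fromtimestamp without a tz argument converts it to the server's local time with no timezone attached, so the client receives a time without an offset. Passing tz=timezone.utc (and importing timezone from datetime) is a small change worth considering; a standalone illustration with a fixed, made-up timestamp:

```python
from datetime import datetime, timezone

# exp in a JWT is seconds since 1970-01-01T00:00:00Z (RFC 7519 "NumericDate")
exp = 1718236800

naive_local = datetime.fromtimestamp(exp)                 # server local time, no tzinfo
aware_utc = datetime.fromtimestamp(exp, tz=timezone.utc)  # unambiguous for clients

print(naive_local.tzinfo)     # None
print(aware_utc.isoformat())  # 2024-06-13T00:00:00+00:00
```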

With this, everything is done. Both login flows now hand out simplejwt tokens, so every client can authenticate with Bearer in the Authorization header.
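On the client, the branching we used to need disappears. A sketch using only Python's standard library, assuming a dev server at localhost:8000, the convert-token URL under auth/, and a hypothetical /api/me/ endpoint; the form fields are the ones drf_social_oauth2's convert-token endpoint expects. It only builds the requests, so nothing is sent:

```python
from urllib.parse import urlencode
from urllib.request import Request

BASE_URL = "http://localhost:8000"  # assumption: local development server


def convert_token_request(client_id: str, client_secret: str, backend: str, provider_token: str) -> Request:
    """Build the POST that swaps a social provider's token for our simplejwt tokens."""
    body = urlencode(
        {
            "grant_type": "convert_token",
            "client_id": client_id,          # from the django-oauth-toolkit Application
            "client_secret": client_secret,
            "backend": backend,              # e.g. "google-oauth2"
            "token": provider_token,         # token the provider gave the client
        }
    ).encode()
    return Request(f"{BASE_URL}/auth/convert-token", data=body, method="POST")


def authed_request(path: str, access: str) -> Request:
    """Every user now sends the same header, however they logged in."""
    return Request(f"{BASE_URL}{path}", headers={"Authorization": f"Bearer {access}"})


# Usage: "/api/me/" is a hypothetical protected endpoint
login = convert_token_request("my-client-id", "my-client-secret", "google-oauth2", "provider-token")
profile = authed_request("/api/me/", "<access from the login response>")
print(login.get_method())                   # POST
print(profile.get_header("Authorization"))  # Bearer <access from the login response>
```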
