Drf simple_jwt实现单点登录

227 阅读3分钟

rest_framework_simplejwt实现单点登录,期望实现的功能是同一时间内只能有一台设备登录

自定义认证类

from rest_framework.exceptions import APIException
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework.authentication import BaseAuthentication
from rest_framework_simplejwt.tokens import AccessToken
from rest_framework_simplejwt.exceptions import InvalidToken
from .models import CustomUser as User
from config import IS_SINGLE_TOKEN, REDIS_SINGLE_TOKEN_DB_NAME, REDIS_TOKEN_KEY_PREFIX
from django_redis import get_redis_connection
from django.contrib.auth.models import AnonymousUser


class SingleLoginAuthenticate(BaseAuthentication):
    """
    单点登录认证,
    reids缓存登录过的token,多设备登录只能存储一个token
    校验登录,非匿名用户
    token过期,提示登录失败
    token对不上,提示已在其他地方登录
    """

    def authenticate(self, request):
        jwt_token = request.META.get('HTTP_AUTHORIZATION', None)
        if jwt_token:
            try:
                # 拿到了token,需要检验token是否合法,是否被篡改/伪造,是否过期,如果都通过,根据payload中的user_id取出当前用户。
                '''我们可以去simplejwt的认证类中看它是怎么写的'''
                '''validated_token =get_validated_token(token)--->返回了AccessToken(token)传入的对象'''
                validated_token = AccessToken(jwt_token)  
            except Exception as e:
                raise APIException({'code': 999, 'message': str(e)})

            '''拿到token串对应的用户id'''
            # user = validated_token.payload['user_id']
            # user_id = validated_token['user_id']
            user: User = User.objects.filter(pk=validated_token['user_id']).first()

            # 这里进行单点登录校验
            if IS_SINGLE_TOKEN:
                # 不为空且不是匿名用户
                if user and user.is_authenticated():
                    # if user and not isinstance(user, AnonymousUser):
                    redis_conn = get_redis_connection(REDIS_SINGLE_TOKEN_DB_NAME)
                    # 从redis获取最近登录的token
                    k = REDIS_TOKEN_KEY_PREFIX + str(user.id)
                    cache_token = redis_conn.get(k)
                    if not cache_token:
                        raise InvalidToken('登录已失效,请重新登录')
                    if str(jwt_token) != str(cache_token):
                        raise InvalidToken('已在其他地方登录,请重新登录')

            return user, jwt_token
        else:
            raise APIException({'code': 101, 'message': 'token必须携带'})

config.py

IS_SINGLE_TOKEN = True  # 是否只允许单用户单一地点登录(只有一个人在线上)(默认多地点登录),只针对后台用户生效
REDIS_SINGLE_TOKEN_DB_NAME = 'singletoken'
REDIS_TOKEN_KEY_PREFIX = "single-token"  # 单点登录token存储到redis的前缀

在登录视图类中存储登录的token,在另外一台设备登录了,覆盖redis存储的token。我这里是写在序列化器的,你可以在视图类的进行校验

serializers.py

def store_single_token_if_single_login(user_id, access):
    try:
        if IS_SINGLE_TOKEN:
            redis_conn = get_redis_connection(REDIS_SINGLE_TOKEN_DB_NAME)
            k = REDIS_TOKEN_KEY_PREFIX + str(user_id)
            # 根据setting 设置的jwt过期时间设置expire的过期时间
            token_expire_config = getattr(settings, 'SIMPLE_JWT', None)
            if token_expire_config:
                token_expire = token_expire_config['ACCESS_TOKEN_LIFETIME']
                redis_conn.set(k, access, token_expire)
    except Exception as e:
        print('存储单点登录token失败')
        print(e)


class LoginSerializer(TokenObtainPairSerializer):
    """
    重写TokenObtainPairSerializer登录验证的序列化类
    """

    def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
        # todo 这里其实可以改改的,直接复用父类的方法,省去很多代码
        # 重写validate,实现单点登录
        # 现在是使用手机号码加密码登录的
        mobile = attrs['mobile']
        # username = attrs['username']
        password = attrs['password']
        user: User = User.objects.filter(mobile=mobile).first()

        if not user:
            raise ValidationError('账户\密码错误')
        if user and not user.is_active:
            raise ValidationError("该该账号已禁用,请联系管理人员")
        if user and user.check_password(password):
            data = super().validate(attrs)
            # 实现单点登录, 缓存token到redis,到中间件进行校验
            store_single_token_if_single_login(user_id=user.id, access=data['access'])

            return data
        else:
            raise ValidationError('账户\密码错误')
views.py 

class LoginView(TokenObtainPairView):
    queryset = User.objects.all()
    authentication_classes = []
    serializer_class = LoginSerializer


    def post(self, request: Request, *args, **kwargs) -> Response:
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            info_serializer = UserSerializer(instance=serializer.user, context={'request': request})
            data = {
                'token': serializer.validated_data,
                'userInfo': info_serializer.data,
            }
            return Response(data, status=status.HTTP_200_OK)
        else:
            return Response(serializer.errors, status=status.HTTP_200_OK)

settings配置认证类

settings.py 

REST_FRAMEWORK = { 
	    'DEFAULT_AUTHENTICATION_CLASSES': [
	        'authentications.SingleLoginAuthenticate'
	    ],
	    'DEFAULT_PERMISSION_CLASSES': [
	        'rest_framework.permissions.IsAuthenticated',
	    ],
	}

不需要认证的视图,重新配置视图类的permission_class属性