drf——源码分析

drf——源码分析

  • 认证源码分析
  • 权限源码分析
  • 频率类源码分析

三大认证的源码分析

之前读取的APIView的源码的执行流程中包装了新的request,执行了三大认证,执行视图类的方法,处理了全局异常

  • 查看源码的入口

    ​ APIView的dispatch

  • 进入后在APIView的dispatch的496行上下

    ​ self.initial(request, *args, **kwargs)中

  • 查看APIView的initial

    413行上下有三句话,分别是认证、权限、频率

    self.perform_authentication(request)
    self.check_permissions(request)
    self.check_throttles(request)
    

    这三个分别是三大认证的源码分析的读取入口

认证源码分析

  • 读取认证类源码——APIView的perform_authentication(request)

    def perform_authentication(self, request):
        request.user  # 新的request
    

    request是新的request

    Request类中找user属性(方法),是个方法包装成了数据属性

  • 点击request类找到user属性(方法)

    def user(self):
        if not hasattr(self, '_user'): # Request类的对象中反射_user
            with wrap_attributeerrors():
                self._authenticate()  # 第一次会走这个代码
       	return self._user
    

    查找到user属性后先进行判断函数是否包含_user属性,不包含则进行if内操作调用 self._authenticate()

  • 点击 Request的self._authenticate()

    def _authenticate(self):
        for authenticator in self.authenticators: # 配置在视图类中所有的认证类的对象 
            try:
                #(user_token.user, token)
                user_auth_tuple = authenticator.authenticate(self) 
                # 调用认证类对象的authenticate
            except exceptions.APIException:
                self._not_authenticated()
                raise
    
            if user_auth_tuple is not None:
                self._authenticator = authenticator # 忽略
                self.user, self.auth = user_auth_tuple # 解压赋值
                return  self._not_authenticated()
                    # 认证类可以配置多个,但是如果有一个返回了两个值,后续的就不执行了
                
    
  • 总结

    认证类,要重写authenticate方法,认证通过返回两个值或None,

    认证不通过抛AuthenticationFailed(继承了APIException)异常

权限源码分析

  • 读取权限类源码 APIView的check_permissions(request)

    def check_permissions(self, request):
        for permission in self.get_permissions():
            # permission是咱们配置在视图类中权限类的对象,对象调用它的绑定方法has_permission
            # 对象调用自己的绑定方法会把自己传入(权限类的对象,request,视图类的对象)
            if not permission.has_permission(request, self):
                self.permission_denied(
                    request,
                    message=getattr(permission, 'message', None),
                    code=getattr(permission, 'code', None)
                )
    
  • 读取APIVIew的 self.get_permissions()

    return [permission() for permission in self.permission_classes]
    """self.permission_classes 就是咱们在视图类中配的权限类的列表"""
    

    所以这个get_permissions返回的是 在视图类中配的权限类的对象列表[UserTypePermession(),]

  • 总结

    为什么要写一个类,重写has_permission方法,

    has_permission有三个参数,分别是 权限类的对象,request,视图类的对象

    为什么一定要return True或False:

    ​ 返回True是通过权限,Flase是未通过权限

    messgage的作用:

    ​ 调用self.messgage可以更改提示信息

频率类源码

  • 读取频率类源码 APIView的check_throttles

    def check_throttles(self, request):
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())
               
    def get_throttles(self):
        return [throttle() for throttle in self.throttle_classes]
    
    

    ​ 要写频率类,必须重写allow_request方法,然后结束for循环返回值

  • 读取 allow_request 源码

    源码里执行的频率类的allow_request,读SimpleRateThrottle的allow_request

    class SimpleRateThrottle(BaseThrottle):
        cache = default_cache
        timer = time.time
        cache_format = 'throttle_%(scope)s_%(ident)s'
        scope = None
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
        def __init__(self):  # 只要类实例化得到对象就会执行,一执行,self.rate就有值了,而且self.num_requests和self.duration
            if not getattr(self, 'rate', None): # 去频率类中反射rate属性或方法,发现没有,返回了None,这个if判断就符合,执行下面的代码
                self.rate = self.get_rate()  #返回了  '3/m'
            #  self.num_requests=3
            #  self.duration=60
            self.num_requests, self.duration = self.parse_rate(self.rate)
    
        def get_rate(self):
             return self.THROTTLE_RATES[self.scope] # 字典取值,配置文件中咱们配置的字典{'ss': '3/m',},根据ss取到了 '3/m'
    
        def parse_rate(self, rate):
            if rate is None:
                return (None, None)
            # rate:字符串'3/m'  根据 / 切分,切成了 ['3','m']
            # num=3,period=m
            num, period = rate.split('/')
            # num_requests=3  数字3
            num_requests = int(num)
            # period='m'  ---->period[0]--->'m'
            # {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
            # duration=60
            duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
            # 3     60
            return (num_requests, duration)
    
        def allow_request(self, request, view):
            if self.rate is None:
                return True
            # 咱们自己写的,返回什么就以什么做限制  咱们返回的是ip地址
            # self.key=当前访问者的ip地址
            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True
            # self.history 访问者的时间列表,从缓存中拿到,如果拿不到就是空列表,如果之前有 [时间2,时间1]
            self.history = self.cache.get(self.key, [])
            # 当前时间
            self.now = self.timer()
            while self.history and self.history[-1] <= self.now - self.duration:
                self.history.pop()
            if len(self.history) >= self.num_requests:
                return self.throttle_failure()
            return self.throttle_success()
    
  • 总结

    要写频率类,必须重写allow_request方法

    返回True(没有到频率的限制)或False(到了频率的限制)

    以后要再写频率类,只需要继承SimpleRateThrottle,重写get_cache_key,配置类属性scope,配置文件中配置一下就可以了

排序和过滤源码分析

继承了GenericAPIView+ListModelMixin,只要在视图类中配置filter_backends它就能实现过滤和排序

  • drf内置的过滤类(SearchFilter),排序类(OrderingFiler)

    from rest_framework.filters import SearchFilter,OrderingFiler,BaseFilterBackend
    

    在类中只需继承了GenericAPIView+ListModelMixin,配置filter_backends即可直接使用模块的过滤和排序

  • 排序和过滤源码剖析

    排序和过滤只有涉及到查看多个(list)才起作用

    故需继承ListAPIView(继承了GenericAPIView+ListModelMixin)

    ListModelMixin:
    def list(self, request, *args, **kwargs):
        # self.get_queryset()所有数据,经过了self.filter_queryset返回了qs
         # self.filter_queryset完成的过滤
        queryset = self.filter_queryset(self.get_queryset())
           # 如果有分页,走的分页----》视图类中配置了分页类
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
         # 如果没有分页,走正常的序列化,返回
        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)
    

    self.filter_queryset完成了过滤,当前在视图类中,self是视图类的对象

    视图类中没找到去其父类找 ,找到 GenericAPIView 下的 filter_queryset

    def filter_queryset(self, queryset):
        for backend in list(self.filter_backends):
            queryset = backend().filter_queryset(self.request, queryset, self)
            return queryset
    

重写 filter_queryset

drf内置的过滤类(SearchFilter),排序类(OrderingFiler)

即重写过滤类(SearchFilter),排序类(OrderingFiler)内部的filter_queryset方法即可自定义

排序

from rest_framework.filters import SearchFilter,OrderingFiler,BaseFilterBackend

class 类名(BaseFilterBackend):
    def filter_queryset(self, request, queryset, view):
       重写方法
    	return queryset
  • 总结:

    写的过滤类要重写filter_queryset,返回qs(过滤或排序后)对象

    后期如果不写过滤类,只要在视图类中重写filter_queryset,在里面实现过滤也可以

restframework-jwt执行流程

restframework-jwt 就是签发流程

本质就是登录接口,为了校验用户是否正确

如果正确签发token,写到了序列化类中,如果不正确返回错误

读取源码的入口:

obtain_jwt_token:核心代码--ObtainJSONWebToken.as_view()

  • ObtainJSONWebToken

    视图类,实现了登录功能

    class ObtainJSONWebToken(JSONWebTokenAPIView):
        serializer_class = JSONWebTokenSerializer
    

    找其父类

    class JSONWebTokenAPIView(APIView):
        # 局部禁用掉权限和认证
        permission_classes = () 
        authentication_classes = ()
    
        def get_serializer_context(self):
            return {
                'request': self.request,
                'view': self,
            }
    
        def get_serializer_class(self):
            return self.serializer_class
    
        def get_serializer(self, *args, **kwargs):
            serializer_class = self.get_serializer_class()
            kwargs['context'] = self.get_serializer_context()
            return serializer_class(*args, **kwargs)
    
        def post(self, request, *args, **kwargs):
            # JSONWebTokenSerializer实例化得到一个序列号类的对象,传入前端传的只
            serializer = self.get_serializer(data=request.data)
    
            if serializer.is_valid(): # 校验前端传入的数据是否合法:
                #1 字段自己的规则 2 局部钩子 3 全局钩子(序列化类的validate方法)
                # 获取当前登录用户和签发token是在序列化类中完成的
                # 从序列化类对象中取出了当前登录用户
                user = serializer.object.get('user') or request.user
                # # 从序列化类对象中取出了token
                token = serializer.object.get('token')
                # 自定义过
                response_data = jwt_response_payload_handler(token, user, request)
                response = Response(response_data)
                return response
    
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
    
  • 序列化类 JSONWebTokenSerializer

    class JSONWebTokenSerializer(Serializer):
        def validate(self, attrs):
            credentials = {
                'username': attrs.get('username'),
                'password': attrs.get('password')
            }
    
            if all(credentials.values()):
                # auth的校验用户名和密码是否正确
                user = authenticate(**credentials)
    
                if user:
                    # 通过用户获得payload:{}
                    payload = jwt_payload_handler(user)
                    return {
                        'token': jwt_encode_handler(payload),
                        'user': user
                    }
                else:
                    # 根据用户名和密码查不到用户
                    raise serializers.ValidationError(msg)
                    else:	
                        # 用户名和密码不传,传多了都不行
                        raise serializers.ValidationError(msg)
    
  • 认证类 JSONWebTokenAuthentication

    class JSONWebTokenAuthentication(BaseJSONWebTokenAuthentication):
        def get_jwt_value(self, request):
            # get_authorization_header(request)根据请求头中HTTP_AUTHORIZATION,取出token
            # jwt adsfasdfasdfad
            # auth=['jwt','真正的token']
            auth = get_authorization_header(request).split()
            auth_header_prefix = api_settings.JWT_AUTH_HEADER_PREFIX.lower()
            if not auth:
                if api_settings.JWT_AUTH_COOKIE:
                    return request.COOKIES.get(api_settings.JWT_AUTH_COOKIE)
                return None
            if smart_text(auth[0].lower()) != auth_header_prefix:
                return None
            if len(auth) == 1:
                msg = _('Invalid Authorization header. No credentials provided.')
                raise exceptions.AuthenticationFailed(msg)
                elif len(auth) > 2:
                    msg = _('Invalid Authorization header. Credentials string '
                            'should not contain spaces.')
                    raise exceptions.AuthenticationFailed(msg)
                    return auth[1]
    

    其父类 BaseJSONWebTokenAuthentication---》authenticate

    class BaseJSONWebTokenAuthentication(BaseAuthentication):
        def authenticate(self, request):
            # jwt_value前端传入的token
            jwt_value = self.get_jwt_value(request)
            # 前端没有传入token,return None,没有带token,认证类也能过,所有咱们才加权限类
            if jwt_value is None:
                return None
            try:
                payload = jwt_decode_handler(jwt_value) # 验证token,token合法,返回payload
                except jwt.ExpiredSignature:
                    msg = _('Signature has expired.')
                    raise exceptions.AuthenticationFailed(msg)
                    except jwt.DecodeError:
                        msg = _('Error decoding signature.')
                        raise exceptions.AuthenticationFailed(msg)
                        except jwt.InvalidTokenError:
                            raise exceptions.AuthenticationFailed()
    
                            user = self.authenticate_credentials(payload) # 通过payload得到当前登录用户
    
                            return (user, jwt_value) # 后期的request.user就是当前登录用户
    

    它这个认证类:只要带了token,request.user就有只,如果没带token,不管了,继续往后走

posted @   Nirvana*  阅读(242)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
点击右上角即可分享
微信分享提示