DRF-组件与源码

一. DRF源码的基本

1.1 调用as_view

# APIView中调用as_view函数
class APIView(View):
    @classmethod
    def as_view(cls, **initkwargs):
        if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
            def force_evaluation():
                raise RuntimeError(
                    'Do not evaluate the `.queryset` attribute directly, '
                    'as the result will be cached and reused between requests. '
                    'Use `.all()` or call `.get_queryset()` instead.'
                )
            cls.queryset._fetch_all = force_evaluation
        # mark:关于super,在python2中必须写成super(APIView, self), 
        # 而在python3中可以直接省略写成super()
        # 调用父类的as_view方法,as_view方法,调用当前类的dispatch方法
        view = super().as_view(**initkwargs) 
        view.cls = cls
        view.initkwargs = initkwargs

        # Note: session based authentication is explicitly CSRF validated,
        # all other authentication is CSRF exempt.
        # 基于session的认证已经验证了CSRF,所以这里面为所有的view函数,调用SCRF exempt
        return csrf_exempt(view)

1.2 调用dispatch

as_view返回的view中调用了dispatch方法用来处理用户请求。

class APIView(View):    
  	def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.
        # 整体类似于django的dispatch但是加了些在执行前,最终处理时及异常的handling<钩子>
        """
        self.args = args
        self.kwargs = kwargs

        # 1. 使用DRF的Request类对request对象进行二层封装
        # 新增:parsers, authenticators, negotiator, parser_context这几个属性
        # 新增:_request属性,指向原request对象
        # 新增:一些方法
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        # 将默认的response_headers赋予allow-method进去
        self.headers = self.default_response_headers  # deprecate?

        try:
            # Runs anything that needs to occur prior to calling the method handler.
            # 在调用handler之前,把所有需要处理的都处理了
            # 包括格式化kwarg,执行content-type协商,解析version及验证request是否允许(登录与权限);
            self.initial(request, *args, **kwargs)

            # Get the appropriate handler method
            # 通过反射获取允许的执行方法, 类似于django内部
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
          	# 为response添加了一些额外的Header
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

二. 认证组件

在dispatch方法调用self.initial(request, *args, **kwargs)验证认证/授权

2.1 调用initial

class APIView(View):    
  	def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.
        # 运行所有需要在handler执行前的操作;
        """
        # 格式化kwarg
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        # 执行内容协商,并存储最终的content类型到request中
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        # 如果version被使用,则按照REST规范去解析API的version;
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted
        # 确保request所请求的是被允许的
        # 1. 使用self._authenticate()做登录验证,需要放回用户
        # 1.1. 遍历authenticators的认证方法,如果成功则返回user
        self.perform_authentication(request)
        # 2. 检查当前用户请求是否有权限访问
        self.check_permissions(request)
        # 3. 检查当前用户请求是否应该被限制-限制访问次数
        self.check_throttles(request)

2.2 调用perform_authentication

在dispath中调用self.perform_authentication(request)进行认证;

class APIView(View):    
  	def perform_authentication(self, request):
        """
        Perform authentication on the incoming request.
        # 对进入的request执行认证
        """
        # 调用request的user属性
        request.user

2.3 调用request.user

在perform_authentication中调用request.user属性;

class Request:
		@property
    def user(self):
        """
        Returns the user associated with the current request, as authenticated
        by the authentication classes provided to the request.
        """
        if not hasattr(self, '_user'):
           # 这个wrap_attributeerrors是用于来with内部的异常
            with wrap_attributeerrors():
               # 调用self._authenticate()方法
                self._authenticate()
        return self._user

2.4 调用request._authenticate

在request.user中调用了 self._authenticate()方法;

class Request:
		def _authenticate(self):
        """
        Attempt to authenticate the request using each authentication instance
        in turn.
        # 按照顺序使用每种authentication的实例尝试认证request;
        """
        # 遍历所有的authenticator对象;
        for authenticator in self.authenticators:
            try:
                # 调用authenticator对象的authenticate方法,
                # 如果返回(user,token)二元组,则成功
                # 如果返回None则说明当前authenticator不做任何处理,进入下一个authenticator
                # 如果raise异常,则表明挡墙authenticator拒绝用户登录,认证结束,返回上层
                user_auth_tuple = authenticator.authenticate(self)
            except exceptions.APIException:
                # 如果认知功能中报接口异常,直接放回未认证
                self._not_authenticated()
                raise

            # 将认证成功的信息赋值给request中的一些属性并返回
            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                return

        # 如果所有的认证方式都没通过,则返回未认证即可
        self._not_authenticated()

如果认证失败,则调用self._not_authenticated()来处理;

class Request:
      def _not_authenticated(self):
        """
        Set authenticator, user & authtoken representing an unauthenticated request.

        Defaults are None, AnonymousUser & None.
        """
        # 认证者设置为None
        self._authenticator = None

        # 将setting中的默认用户赋值给当前user
        if api_settings.UNAUTHENTICATED_USER:
            self.user = api_settings.UNAUTHENTICATED_USER()
        else:
            self.user = None

        # 将setting中的默认token赋值给当前auth
        if api_settings.UNAUTHENTICATED_TOKEN:
            self.auth = api_settings.UNAUTHENTICATED_TOKEN()
        else:
            self.auth = None

self.user与self.auth都是Request中的属性方法

class Request:
    @user.setter
    def user(self, value):
        self._user = value
        self._request.user = value
        
    @auth.setter
    def auth(self, value):
        self._auth = value
        self._request.auth = value

如果认证异常则最终会触发捕获,捕获后执行异常返回,由handle_exception执行(需特定报错)

class APIView(View):   
  	def handle_exception(self, exc):
        """
        Handle any exception that occurs, by returning an appropriate response,
        or re-raising the error.
        # 放回适当的响应来处理异常,或者再次raise异常
        """
        # 如果是NotAuthenticated和AuthenticationFailed异常,则调用authenticator的方法来获取异常返回header
        if isinstance(exc, (exceptions.NotAuthenticated,
                            exceptions.AuthenticationFailed)):
            # WWW-Authenticate header for 401 responses, else coerce to 403
            auth_header = self.get_authenticate_header(self.request)

            # 如果获取到则将其赋值给exc.auth_header
            if auth_header:
                exc.auth_header = auth_header
            else:
                # 如果没有则返回403异常的
                exc.status_code = status.HTTP_403_FORBIDDEN

        # 调用配置的默认处理机制开始处理
        exception_handler = self.get_exception_handler()

        context = self.get_exception_handler_context()
        response = exception_handler(exc, context)

        if response is None:
            self.raise_uncaught_exception(exc)

        response.exception = True
        return response

2.5 request.authenticators的初始化

在对WSGI的request对象进行restframe封装时,调用了一个函数来封装authenticators属性;

class APIView(View):    
  	def initialize_request(self, request, *args, **kwargs):
        """
        Returns the initial request object.
        """
        parser_context = self.get_parser_context(request)

        return Request(
            request,
            parsers=self.get_parsers(),
            # 在此处调用self.get_authenticators()生成authenticators属性
            authenticators=self.get_authenticators(),
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context
        )

这个方法就是APIView.get_authenticators()方法;

class APIView(View):     
  	def get_authenticators(self):
        """
        Instantiates and returns the list of authenticators that this view can use.
        # 实例化并返回一个authenticators的列表,使用的列表生成式;
        """
        return [auth() for auth in self.authentication_classes]

其中self.authentication_classes为类变量设置了默认值,直接卸载APIView下面,所以可以重写;


class APIView(View):
		# APIView的类变量设置
    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES

DEFAULT里总共两个类,分别是rest.authenticaton.py内的SessionAuthentication和BasicAuthentication;

class SessionAuthentication(BaseAuthentication):
    """
    Use Django's session framework for authentication.
    # 使用django的session去认证;
    """
    def authenticate(self, request):
        """
        Returns a `User` if the request session currently has a logged in user.
        Otherwise returns `None`.
        # 返回一个用户,如果request的session中有一个已登录的user;
        """

        # Get the session-based user from the underlying HttpRequest object
        # 获取WSGI的request中是否有user
        user = getattr(request._request, 'user', None)

        # Unauthenticated, CSRF validation not required
        # 没有就返回未认证
        if not user or not user.is_active:
            return None
      
        self.enforce_csrf(request)
        
        # CSRF passed with authenticated user
        return (user, None)

    def enforce_csrf(self, request):
        ...

2.6 自构造authenticator类

一个包含authenticate方法并返回(user, auth)的元组-成功,None-失败的类;

class BruceAuth(object):
  	# 可以基于object,也可以基于authtication.py下面的Base,只要有authenticate方法,并返回正确;
    def authenticate(self, request):
        return ('bruce', 'Brucetoken')

class MyCbv(APIView):
    # 覆盖父类中的类属性,调用自定义认证模块
    authentication_classes = (BruceAuth, )
    
    def get(self, request):
        return HttpResponse(f"get request..{request.user}")

2.7 构建API的Token认证

自建authenticator-authentication.py内的BruceAuth类,需要实现两个方法;

from rest_framework.exceptions import AuthenticationFailed
from bruceApp import models


class BruceAuth(object):

    def authenticate(self, request):
        request_token = request.query_params.get('token', None)

        if not request_token:
            raise AuthenticationFailed("认证失败")

        auth = models.auth.objects.filter(token=request_token).first()
        if auth:
            return auth.user, auth.token

        raise AuthenticationFailed("认证失败")

    # 这个在认证失败的时候会调用,其他场景调用,暂未发现;
    def authenticate_header(self, request):
        pass

在app的views中新建login用于用户登录与token获取;

class BruceResponse(JsonResponse):

    def __init__(self, code, msg=None, data=None, *args, **kwargs):
        body = {
            "code": code,
            "msg": msg or "OK",
            "data": data or {}
        }
        super(BruceResponse, self).__init__(body, *args, **kwargs)


class Login(APIView):
    authentication_classes = []

    def get(self, request):
        return BruceResponse(1, "login不允许get")

    def post(self, request):
        request_data = request.data or {}
        user = models.Users.objects.filter(username=request_data.get('username', None),
                                           password=request_data.get('password', None)).first()
        if user:
            user_token = get_token(username=user.username)
            if models.auth.objects.filter(user=user):
                models.auth.objects.update(user=user, token=user_token)
            else:
                models.auth.objects.create(user=user, token=user_token)
            return BruceResponse(0, msg="认证成功", data={'token': user_token})
        return BruceResponse(999, "认证失败,请检查用户民与密码", data=request.data)


class MyCbv(APIView):
    # authentication_classes = (BruceAuth,)

    def dispatch(self, request, *args, **kwargs):
        return super(MyCbv, self).dispatch(request, *args, **kwargs)

    def get(self, request):
        return BruceResponse(0, "hello", {
            "login_user": request.user.username,
            "token": request.auth
        })

# token的生成
from hashlib import md5
import random
import string

def _get_salt(len_of_salt=32):
    len_of_salt = len_of_salt if isinstance(len_of_salt, int) and len_of_salt < 121 else 120
    return ''.join(
        random.sample(string.ascii_letters + string.digits + string.punctuation + string.ascii_lowercase, len_of_salt))

def get_token(username):
    return md5(f"{_get_salt()}{username}".encode(encoding='ASCII')).hexdigest()
  

在setting中配置应用于默认。

REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'bruceApp.authentication.BruceAuth', # 如果直接放在views函数里面无法import,奇怪....
    ],
}

2.8 认证总结

使用:

  1. 新建一个类,必须实现两个方法,三种返回的一种或多种,建议继承BaseAuthentication类;

    # 可以直接自荐类
    class MyAuth(object):
      def authenticate(self, request, *argv, **kwargs):
        pass
      def authenticate_header(self, request, *argv, **kwargs):
        pass
    
    # 建议-直接继承rest_framework中的BaseAuthenticaton类, 只要实现au
    from rest_framework.authentication import BaseAuthentication
    
    class BruceAuth(BaseAuthentication):
    
        def authenticate(self, request):
          return None # 返回类型1
          return username, usertoken # 返回类型2 
        	raise AuthenticationFailed("认证失败") # 返回类型3
    
  2. 调用方式,分为全局调用和局部调用,同时解析为何setting中配置REST_FRAMEWORK就可以影响DRF;

    # 局部调用方式,在每个类中赋值给authentication_classes
    class MyCbv(APIView):
        authentication_classes = [BruceAuth,]
        ...
    # 全局使用,则在setting中直接设置
    # tips: 如login之类的需要豁免,则可以使用局部方式,覆盖为空即可;
    REST_FRAMEWORK = {
        'DEFAULT_AUTHENTICATION_CLASSES': [
            'bruceApp.authentication.BruceAuth', # 如果直接放在views函数里面无法import,奇怪....
        ],
      	# 是个函数,默认如下,可以更改,可以为None
      	'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
        # 是个函数,默认如下,可以更改,可以为None
      	'UNAUTHENTICATED_TOKEN': None
    }
    # 为什么django的setting中需要加REST_FRAMEWORK呢
    # 从drf的源码中来看:
    
    # 将一些api的默认设置初始化为api_settings
    class APISettings:
       ...
        def reload(self):
          for attr in self._cached_attrs:
            delattr(self, attr)
            self._cached_attrs.clear()
            if hasattr(self, '_user_settings'):
              delattr(self, '_user_settings')
              
    api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
    
    # connect中调用该函数,使用原生django中settings中REST_FRAMEWORK来重写api_settings的一些参数;
    def reload_api_settings(*args, **kwargs):
        setting = kwargs['setting']
        if setting == 'REST_FRAMEWORK':
            api_settings.reload()
    
    # 第一次import setting时,会调用该方法,重新加载配置
    # connect方法加载了django原生setting并通过kwargs传递给reload_api_settings函数
    setting_changed.connect(reload_api_settings)
    

源码:

  1. 封装:drf在其APIViews中重新封装Request对象,对于auth来说,其重点操作是使用authentication_classes及列表生成式生成了authenticators的列表;
  2. 认证:dispatch中的initial方法->调用perform_authentication方法->调用request.user这个属性方法->调用request._authenticate方法使用authenticators列表来认证,成功返回 (user, auth),拒绝raise,忽略None.

三. 授权组件

在dispatch方法调用self.initial(request, *args, **kwargs)验证认证/授权;

3.1 授权组件的使用

  1. 基于BasePermission来构建认证类,必须实现has_permission和has_object_permission方法:

    # 在BruceApp.utils.pemissions
    from rest_framework.permissions import BasePermission
    
    
    class BrucePermission(BasePermission):
      	message = "这个是在没有权限的时候的返回"
        def has_permission(self, request, view):
            if hasattr(request.user, "usertype") and request.user.usertype == 2:
                return True
            return False
        def has_object_permission(self, request, view, obj):
            """
            Return `True` if permission is granted, `False` otherwise.
            """
            return True
    # 返回值有2中:
    # 1. 如果返回True则标记认证通过
    # 2. 如果返回False则标记认证未通过
    
    # 内置的:
    # 1. 在rest——framework.permissions中有些内置的授权类,可参考但最好不直接使用;
    
  2. 使用的方法类似于认证,分为局部调用和全局调用;

    # 局部调用
    from bruceApp.models import Switches
    class SW(APIView):
        permission_classes = [BrucePermission, ] # 局部重写permission_classes即可
        # 获取所有的models对象
        ...
    
    # 全局调用
    # 1. 配置settings.py文件
    REST_FRAMEWORK = {
        "DEFAULT_PERMISSION_CLASSES": [ # 配置默认的PERMISSION为当前自定义的权限组件即可
          'bruceApp.utils.permissions.BrucePermission',
        ],
    }
    # 2. 对某些位置比如login设置豁免
    class Login(APIView):
        authentication_classes = []
        permission_classes = []
    
  3. has_object_permission方法在GenericAPIView中的get_object中使用,在用户访问单条数据时生效;

        def get_object(self):
            """
            Returns the object the view is displaying.
    
            You may want to override this if you need to provide non-standard
            queryset lookups.  Eg if objects are referenced using multiple
            keyword arguments in the url conf.
            """
            queryset = self.filter_queryset(self.get_queryset())
    
            # Perform the lookup filtering.
            lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
    
            assert lookup_url_kwarg in self.kwargs, (
                'Expected view %s to be called with a URL keyword argument '
                'named "%s". Fix your URL conf, or set the `.lookup_field` '
                'attribute on the view correctly.' %
                (self.__class__.__name__, lookup_url_kwarg)
            )
    
            filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
            obj = get_object_or_404(queryset, **filter_kwargs)
    
            # May raise a permission denied
            # 没有权限则返回异常;
            self.check_object_permissions(self.request, obj)
    
            return obj
    

3.2 授权组件的源码

permisssons的流程高度类似于认证的流程;

当用户的一个请求进来时,先经过认证后,开始进入permission流程:

  1. dispatch调用initial方法, 在initial中处理认证后,check_permissions;

    class APIViews:   
      	def initial(self, request, *args, **kwargs):
            """
            Runs anything that needs to occur prior to calling the method handler.
            """
            ...
            # 1. 处理用户的认证情况
            self.perform_authentication(request)
            # 2. 检查当前用户请求是否有权限访问
            self.check_permissions(request)
            # 3. 检查当前用户请求是否应该被限制
            self.check_throttles(request)
    
  2. 在check_permissions的处理方式,高度类似认证,区别在于permissions需要所有都通过才通过;

    class APIViews:    
      	def check_permissions(self, request):
            """
            # 测试这个request是否有权限准入,如果没有则raise一个授权exception,针对url级别;
            """
            # 认证1:遍历get_permissions的permisstion中的类,来测试是否有权限;
            for permission in self.get_permissions():
                # 1.1: 调用permission的has_permission方法, 返回True则
                if not permission.has_permission(request, self):
                    self.permission_denied(
                        request,
                        message=getattr(permission, 'message', None),
                        code=getattr(permission, 'code', None)
                    )
    
  3. 如果认证没通过调用permission_deined方法,raise异常,msg为当前permission对象的message属性;

    class APIViews:
    		def permission_denied(self, request, message=None, code=None):
            """
            If request is not permitted, determine what kind of exception to raise.
            """
            if request.authenticators and not request.successful_authenticator:
                raise exceptions.NotAuthenticated()
            raise exceptions.PermissionDenied(detail=message, code=code)
    

四. 访问频次

在dispatch方法调用self.initial(request, *args, **kwargs)验证认证/授权/访问是否超频;

4.1 访问频次组件-自建

可以直接在视图类内新建相关机制:

class SW(APIView):
    # 获取所有的models对象
    import time
    visit_history = {}

    def get(self, request):
        current_host = request._request.META.get("REMOTE_ADDR", None)
        now = int(time.time())

        # 如果没有记录,则添加一个空列表
        if not current_host in self.visit_history:
            self.visit_history[current_host] = []

        # 从整体内获取当前host的list
        histories = self.visit_history[current_host]
        while len(histories) > 0 and histories[-1] + 10 < now:
            histories.pop()

        if len(histories) > 2:
            return BruceResponse(2, f"too many visits in 10s, waiting {histories[-1] + 10 - now}s")
        # 只有访问成功才计数,如果没有成功则不计数;
        print(histories)
        histories.insert(0, now)

        switches = models.Switches.objects.all().values_list('id', 'model_type', 'name', 'ip')
        return BruceResponse(0, "OK", switches)

封装到自建的类中,并按照rest_framework的方式来调用

import time

# 所有记录存储在全局变量中
visit_history = {}
class BruceThrottle(object):

    def __init__(self):
        # 此值放置在这方便类内其他位置调用
        self.histories = []

    def allow_request(self, request, views):
        current_host = request._request.META.get("REMOTE_ADDR", None)
        now = int(time.time())
        # 如果没有记录,则添加一个空列表
        if not current_host in visit_history:
            visit_history[current_host] = [now]
            return True

        # 从整体内获取当前host的list
        self.histories = visit_history[current_host]
        while len(self.histories) > 0 and self.histories[-1] + 10 < now:
            self.histories.pop()
        print(visit_history)
        # 只有访问成功才计数,如果没有成功则不计数;
        if len(self.histories) < 3:
            self.histories.insert(0, now)
            return True
        return False

    def wait(self):
        return self.histories[-1] + 10 - int(time.time())
      
# 在正常的views类内调用,就会通过dispatch->initial->实例化并调用
class SW(APIView):
    throttle_classes = [BruceThrottle,]
    def get(self, request):
        ...

继承rest_framework中的BaseThrottle来构建

import time
from rest_framework.throttling import BaseThrottle
visit_history = {}
class BruceThrottle(BaseThrottle):

    def __init__(self):
        self.histories = []

    def allow_request(self, request, views):
        # 可以使用这个父类中的get_ident来获取终端信息,作为key
        current_host = self.get_ident(request)
        now = int(time.time())
        
        # 如果没有记录,则添加一个空列表
        if not current_host in visit_history:
            visit_history[current_host] = [now]
            return True

        # 从整体内获取当前host的list
        self.histories = visit_history[current_host]
        while len(self.histories) > 0 and self.histories[-1] + 10 < now:
            self.histories.pop()

        # 只有访问成功才计数,如果没有成功则不计数;
        if len(self.histories) < 3:
            self.histories.insert(0, now)
            return True
        return False

    def wait(self):
        return self.histories[-1] + 10 - int(time.time())

4.2 访问频次组件-内置

在rest_frameworks.throtting.py的模块中,有些内置的throttle,继承后少量配置即可使用;

# 如下两个继承,重写了SimpleRateThrottle的get_cache_key和scope属性
from rest_framework.throttling import SimpleRateThrottle
class BruceThrottle(SimpleRateThrottle):
    scope = "Bruce"  # 重写scope的属性-调用setting中的速率

    def get_cache_key(self, request, view):
        return self.get_ident(request)

class AlvinThrottle(SimpleRateThrottle):
    scope = "Alvin" # 重写scope的属性-调用setting中的速率

    def get_cache_key(self, request, view):
        return request.user.username # 基于用户的信息做控制,而非IP地址,这样也OK
      
# 需要在project的setting中新增rates
REST_FRAMEWORK = {
    # 设置默认的限速类们
  	"DEFAULT_THROTTLE_CLASSES": ["bruceApp.utils.throttling.BruceThrottle", ],
    # 设置各个key对应的速率
    "DEFAULT_THROTTLE_RATES": { 
        "Bruce": "3/m",
        "Alvin": "10/m",
    }
}

整个SimpleRateThrottle的实现中可以看到为什么需要重写那些参数:

class SimpleRateThrottle(BaseThrottle):
    cache = default_cache # 使用django的缓存来缓存所有的history
    timer = time.time # 获取当前时间的方法
    cache_format = 'throttle_%(scope)s_%(ident)s' # cache内存储的格式
    scope = None # 当key用于去setting中读取,速率的相关配置,各种scope可以个性化多个限速;
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES # 获取设置中所有的RATES

    def __init__(self):
        # 如果当前类没有速率,则使用get_rate以scope为key从setting中的rates来获取rate;
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()

        # 解析获取到的rate为请求数及在什么时间范围内
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        # 这个方法必须重新写,需要返回一个唯一额cache-key - 比如用ip地址或用户名
        # 这个cache-key将被用于取出或者赋予cache,服务于throttling;
        # 可以返回None,则不会进行限速;
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        # 通过scope为key从setting中的rates的dict中来获取rate;
        """
        # 如果当前类的scope没有设置,则raise异常;
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        # 如果通过scope没有在setting的rates中找到rate,则raise异常;
        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """

        # 如果没有rate,则不限速
        if self.rate is None:
            return True

        # 返回一个key,用于标记当先的限速类及cache寻址
        self.key = self.get_cache_key(request, view)
        # 如果返回None,则不限速
        if self.key is None:
            return True

        # 使用self.key获取之前的访问history
        self.history = self.cache.get(self.key, [])
        self.now = self.timer() # 获取当前的时间

        # Drop any requests from the history which have now passed the
        # throttle duration
        # 判断在now至duration之间的次数是否是超了,超了就失败,否则成功;
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure() # return False
        return self.throttle_success() # 在history插入最新访问并设置cache,最后return True

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        返回在此之前需要等待的时间
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

内置的其他类

# User的控制
class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.

    The user id will be used as a unique cache key if the user is
    authenticated.  For anonymous requests, the IP address of the request will
    be used.
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }
# 其他的类,都类似于以上这种方式

4.3 访问频次组件-源码

Throttle的流程高度类似于认证的流程;

当用户的一个请求进来时,先经过认证后,permission后,让后进去throttling过程中;

  1. 请求进来后调用dispatch;

        def initial(self, request, *args, **kwargs):
            # 1. 使用self._authenticate()做登录验证,需要放回用户
            # 1.1. 遍历authenticators的认证方法,如果成功则返回user
            self.perform_authentication(request)
            # 2. 检查当前用户请求是否有权限访问
            self.check_permissions(request)
            # 3. 检查当前用户请求是否应该被限制
            self.check_throttles(request)
    
  2. 调用chek_throttles方法;

        def check_throttles(self, request):
            """
            Check if request should be throttled.
            Raises an appropriate exception if the request is throttled.
            # 验证请求的速度是否超限,如果需要则raise异常;
            """
            throttle_durations = []
            for throttle in self.get_throttles(): # 类似于以上两种办法
                if not throttle.allow_request(request, self):
                    throttle_durations.append(throttle.wait())
    
            if throttle_durations:
                # Filter out `None` values which may happen in case of config / rate
                # changes, see #1438
                durations = [
                    duration for duration in throttle_durations
                    if duration is not None
                ]
    
                duration = max(durations, default=None)
                self.throttled(request, duration) # 这个会raise相同的异常;
    

五. 版本控制

5.1 版本控制的实现

rest_framework的版本控制很类似于之前几个组件:

  1. 新建类并继承对应的BaseVersioning类/或使用原生的类;

    # 自建的类
    class BruceVersion(BaseVersioning):
    
        version_param = api_settings.VERSION_PARAM
    
        def determine_version(self, request, *args, **kwargs):
            # 根据设置内的VERSION_PARAM从参数中获取版本信息
            version = request.query_params.get(self.version_param, 'v1')
            if not self.is_allowed_version(version):
                raise NotFound(f"{version}不是一个合法的版本..")
            return version
          
    # 使用原生的类
    class URLPathVersioning(BaseVersioning):
        """
        To the client this is the same style as `NamespaceVersioning`.
        The difference is in the backend - this implementation uses
        Django's URL keyword arguments to determine the version.
    
        An example URL conf for two views that accept two different versions.
    		
    		# 在写表明了,是用这种方式需要使用re_path并且指明参数哦,这个是最常用的方式啦
        urlpatterns = [
            re_path(r'^(?P<version>[v1|v2]+)/users/$', users_list, name='users-list'),
            re_path(r'^(?P<version>[v1|v2]+)/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
        ]
    
        GET /1.0/something/ HTTP/1.1
        Host: example.com
        Accept: application/json
        """
        invalid_version_message = _('Invalid version in URL path.')
    
        def determine_version(self, request, *args, **kwargs):
            version = kwargs.get(self.version_param, self.default_version)
            if version is None:
                version = self.default_version
    
            if not self.is_allowed_version(version):
                raise exceptions.NotFound(self.invalid_version_message)
            return version
    		
        # 这里面重写了reverse方法,这里面一样只需要两个参数-viewname和request
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            if request.version is not None:
                kwargs = {} if (kwargs is None) else kwargs
                kwargs[self.version_param] = request.version
    
            return super().reverse(
                viewname, args, kwargs, request, format, **extra
            )
    
    # url如下:
    from django.urls import path, re_path
    from bruceApp import views as bruce_view
    
    app_name = "bruceApp"
    
    urlpatterns = [
        re_path('^(?P<v>v\d+)/info/$', bruce_view.Info.as_view(), name="bruce"),
    ]
    
  2. 在setting中设置一些参数;

    REST_FRAMEWORK = {
        # 设置默认的CLASS,注意解析版本的只有一个-这个是自定义的
        "DEFAULT_VERSIONING_CLASS": "bruceApp.utils.versioning.BruceVersion",
        # 这个是默认种类中的一个,还有其他几个,都可以直接拿来用,至需要在seting中配置些信息
     	  "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning",
      	# 配置默认的版本
        "DEFAULT_VERSION": "v1000",
      	# 配置允许的版本信息
        "ALLOWED_VERSIONS": ["v1000", "v2000", "v3000"],
      	# 配置用于解析的字段
        "VERSION_PARAM": "v"
    }
    
  3. 之后就可以在request的参数中取得版本相关的信息

    # 在bruceApp下这么配置
    app_name = "bruceApp"
    
    urlpatterns = [
        path('info/', bruce_view.Info.as_view(), name="bruce"),
    ]
    # 在视图中调用v1参数,并且reverse生成链接
    class Info(APIView):
    
        def get(self, request, *args, **kwargs):
            print(request.version) # 获取当前的版本
            print(request.versioning_scheme) # 获取当前版本的用于解析的对象
            # 调用对象的reverse方法,重新生成url
            print(request.versioning_scheme.reverse("bruceApp:bruce", request=request))
            users = models.Users.objects.all().values_list('id', 'username', 'password')
            return BruceResponse(0, "OK", users)
    

5.2 版本控制的源码

版本控制方面的信息也在请求进入后再initial方法中处理;

当用户的一个请求进来时,先经过认证后,开始进入permission流程:

  1. dispatch调用initial方法, 在initial中调用determine_version方法;

        def initial(self, request, *args, **kwargs):
    				# Determine the API version, if versioning is in use.
            # 如果version被使用,则按照REST规范去解析API的version;
            version, scheme = self.determine_version(request, *args, **kwargs)
            # 会将版本解析的信息存放在version和versioning_scheme
            request.version, request.versioning_scheme = version, scheme
    
  2. determine_version方法中需要实例化versioning_class方法;

        def determine_version(self, request, *args, **kwargs):
            """
            If versioning is being used, then determine any API version for the
            incoming request. Returns a two-tuple of (version, versioning_scheme)
            # 如果使用版本控制,则判断请求进来的API版本,返回一个tuple(版本,版本控制方案)。
            """
            if self.versioning_class is None:
                return (None, None)
            scheme = self.versioning_class()
            return (scheme.determine_version(request, *args, **kwargs), scheme)
    
  3. 而DEFAULT_VERSIONING_CLASS可以指向BaseVersioning的类或子类;

    class BaseVersioning:
        # 配置里的默认的version
        default_version = api_settings.DEFAULT_VERSION
        # 配置里的允许的version
        allowed_versions = api_settings.ALLOWED_VERSIONS
        # version的参数名字,比如就叫version
        version_param = api_settings.VERSION_PARAM
    
        # versioning的主要工作点,被determine_version调用;
        def determine_version(self, request, *args, **kwargs):
            msg = '{cls}.determine_version() must be implemented.'
            raise NotImplementedError(msg.format(
                cls=self.__class__.__name__
            ))
    
        # 反向生成带版本的URL ???? 这个功能测试有问题,下次看到的时候记得补充...
        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            return _reverse(viewname, args, kwargs, request, format, **extra)
    
        # 判断是否是允许的版本
        def is_allowed_version(self, version):
            if not self.allowed_versions:
                return True
            return ((version is not None and version == self.default_version) or
                    (version in self.allowed_versions))
    

六. 解析器

rest_framework的解析器是对请求体(body)的解析;

6.1 前戏

在django中,如果通过post传入的x-www-form-urlencoded的数据则WSGI可以完成解析,这样解析前的数据放在request.body中,解析后的数据放置在request.POST,而JSON的无法解析;

image-20210822205821526

所以在django中需要满足两个条件才能完成body的解析,POST中才有值;

# 请求头:
# 1. Content-Type: application/x-www-form-urlencoded
# 请求体:
# 2. 满足x-www-form-urlencoded格式要求的字符串:value=11123&name=bruce

而在django中提交数据的场景中:

# 1. form表单提交 - 内部转换为value=11123&name=bruce
# 2. ajax方法
$.ajax(
	url: ...
  type: POST,
  data: {name:bruce,age:18}
) ===> 内部转换为value=11123&name=bruce

$.ajax(
	url: ...
  type: POST,
  headers:{'Content-Type': "application/json"},
  data: {name:bruce,age:18}
) ===> 内部转换为value=11123&name=bruce

$.ajax(
	url: ...
  type: POST,
  headers:{'Content-Type': "application/json"},
  data: JSON.stringfy({name:bruce,age:18})
) ===> 传过去的是字符串{name:bruce,age:18}

6.2 解析器的使用

# 如果为某个view特殊指定parser,可以使用这种方式
from rest_framework.parsers import JSONParser, FormParser
class Info(APIView):

    parser_classes = [JSONParser, FormParser] # 覆盖默认配置

    def post(self, request, *args, **kwargs):
        print(request.body)
        print(request.data)
        return BruceResponse(0, "OK")

# 全局的解析器配置在这里,默认情况下即有这些配置,所以默认情况下即可识别这几种的数据类型
REST_FRAMEWORK = {
 'DEFAULT_PARSER_CLASSES': [
        'rest_framework.parsers.JSONParser',
        'rest_framework.parsers.FormParser',
        'rest_framework.parsers.MultiPartParser'
    ]}

如JSON解析器

class JSONParser(BaseParser):
    """
    Parses JSON-serialized data.
    # 解析JSON的数据
    """
    media_type = 'application/json' # Content-type必须为application/json
    renderer_class = renderers.JSONRenderer # 渲染器使用的是JSONRenderer
    strict = api_settings.STRICT_JSON #

    def parse(self, stream, media_type=None, parser_context=None):
        """
        Parses the incoming bytestream as JSON and returns the resulting data.
        # 解析一个进来的bytes为json格式的数据
        """
        parser_context = parser_context or {}
        # 获取默认的编码格式
        encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)

        try:
            decoded_stream = codecs.getreader(encoding)(stream)
            parse_constant = json.strict_constant if self.strict else None
            # json.load...
            return json.load(decoded_stream, parse_constant=parse_constant)
        except ValueError as exc:
            raise ParseError('JSON parse error - %s' % str(exc))

6.3 解析器的源码

解析器的整体流程:

1. 获取用户请求
2. 获取用户请求体
3. 根据用户请求头和parser_classes = [JSONParser, FormParser]中支持的请求头进行比较
4. Parser对象解析请求体
5. 返回request.data

首先parser由dispatch调用的initialize_request来实例化PARSER_CLASSES并封装进request中。

class APIView:
  	 # 由dispatch调用initialize_request封装
     def dispatch(self, request, *args, **kwargs):
            request = self.initialize_request(request, *args, **kwargs)
        
     # 由initialize_request来封装进request
     def initialize_request(self, request, *args, **kwargs):

        return Request(
            # 2. 将WSGI生成的request放入
            request,
            # 3. 将生成的解析器对象列表封装进Reuqest对象的parsers属性;
            parsers=self.get_parsers(),
        )

区别于其他,parser的触发是在request.data被调用的时候,其会调用request中的data属性方法:

class request:
      @property
      def data(self):
          if not _hasattr(self, '_full_data'):
              self._load_data_and_files()
          return self._full_data

调用self._load_data_and_files

class request:
			def _load_data_and_files(self):
        """
        Parses the request content into `self.data`.
        """
        if not _hasattr(self, '_data'):
            # 关键点:
            self._data, self._files = self._parse()
            if self._files:
                self._full_data = self._data.copy()
                self._full_data.update(self._files)
            else:
                self._full_data = self._data

调用_parse来解析并安徽

    def _parse(self):
        """
        Parse the request content, returning a two-tuple of (data, files)

        May raise an `UnsupportedMediaType`, or `ParseError` exception.
        """
        # 获取头部的content-type
        media_type = self.content_type
        try:
            # 获取整体流,从self.body中获取,支持文件获取的?需要测试
            stream = self.stream
        except RawPostDataException:
            if not hasattr(self._request, '_post'):
                raise
            # If request.POST has been accessed in middleware, and a method='POST'
            # request was made with 'multipart/form-data', then the request stream
            # will already have been exhausted.
            # 如果已经被处理过了,直接返回
            if self._supports_form_parsing():
                return (self._request.POST, self._request.FILES)
            stream = None

        # 如果stream和media_type任意为空,则返回empty
        if stream is None or media_type is None:
            if media_type and is_form_media_type(media_type):
                empty_data = QueryDict('', encoding=self._request._encoding)
            else:
                empty_data = {}
            empty_files = MultiValueDict()
            return (empty_data, empty_files)

        # 关键点1:选择某个解析器
        parser = self.negotiator.select_parser(self, self.parsers)
				
        # 如果没有获取到parser,那么就是不支持的格式,raise异常
        if not parser:
            raise exceptions.UnsupportedMediaType(media_type)

        try:
          	# 关键点2:使用获取的parser来输出一个
            parsed = parser.parse(stream, media_type, self.parser_context)
        except Exception:
            # If we get an exception during parsing, fill in empty data and
            # re-raise.  Ensures we don't simply repeat the error when
            # attempting to render the browsable renderer response, or when
            # logging the request or similar.
            self._data = QueryDict('', encoding=self._request._encoding)
            self._files = MultiValueDict()
            self._full_data = self._data
            raise
.
        try:
            return (parsed.data, parsed.files)
        except AttributeError:
            empty_files = MultiValueDict()
            return (parsed, empty_files)

在negotiator.select_parser对比

class DefaultContentNegotiation(BaseContentNegotiation):
    settings = api_settings

    def select_parser(self, request, parsers):
        """
        Given a list of parsers and a media type, return the appropriate
        parser to handle the incoming request.
        """
        for parser in parsers:
            # 关键点1:根据请求头对象与parser对象来选择返回哪个parser
            if media_type_matches(parser.media_type, request.content_type):
                return parser
        return None

七. 序列化

序列号主要有两大功能:对QuerySet的序列化, 对请求数据的验证;

7.1 对QuerySet序列化

7.1.1 简单的使用

from rest_framework import serializers

class UserSerializer(serializers.Serializer):
  	# 这些字段必须在users中有的字段名和同种字段类型才OK
    username = serializers.CharField()
    password = serializers.CharField()

class Info(APIView):

    def get(self, request, *args, **kwargs):
        users = models.Users.objects.all()
        # 多个many=True,单个为False
        ser = UserSerializer(instance=users, many=True)
        # 序列化后的值,可以直接传给Json去处理了
        return BruceResponse(0, "OK", ser.data)

7.1.2 较复杂使用

支持source,支持.多个,支持可执行的自动执行,支持自动method等;

from rest_framework import serializers

class UserSerializer(serializers.Serializer):
    No = serializers.CharField(source="id") # 执行的动作就是obj.id
    name = serializers.CharField(source='username')
    pwd = serializers.CharField(source='password')
    # 所以可以obj.get_usertype_display,类似于上面的方式来调用obj的方法或属性
    # get_usertype_display是一个方法,在Serializer内部会进行一个判断
    # 如果是callable则会执行一次,否则不执行
    utype = serializers.CharField(source='get_usertype_display')
    # 在内部是会一直调用属性,如这里面的obj.group.title
    gp = serializers.CharField(source='group.title')

    # many2many的字段通过source无法取到过于细的字段了
    rs = serializers.CharField(source='role.all')
    # 所以需要使用自定义字段
    # 自定义字段-的定义方式:
    # 1. 定义字段中方法需要为SerializerMethodField
    rls = serializers.SerializerMethodField()
    # 2. 定义row的处理方法,方法就是get_ + 字段名即可
    def get_rls(self, row):
        role_obj_list = row.role.all()
        return [role.title for role in role_obj_list]

class Info(APIView):

    def get(self, request, *args, **kwargs):
        users = models.Users.objects.all()
        ser = UserSerializer(instance=users, many=True)
        return BruceResponse(0, "OK", ser.data)

7.1.3 ModelSerializer的使用

from rest_framework import serializers

class UserSerializer(serializers.ModelSerializer):
    No = serializers.CharField(source='id')
    utype = serializers.CharField(source="get_usertype_display")
    gp = serializers.CharField(source='group.title')
    # 在源码中lookyp_field默认为pk, lookup_url_kwarg默认也为pf
    glink = serializers.HyperlinkedIdentityField(view_name="bruceApp:gp", 
                                                 lookup_field='group_id',             		
                                                 lookup_url_kwarg='gp_id')
    rls = serializers.SerializerMethodField()
    def get_rls(self, row):
        all_roles = row.role.all()
        return [role.title for role in all_roles]
    class Meta:
        # 需要在meta中指定这俩个值
        # fields = "__all__" or ["id", "name"]
        exclude = ["id", "usertype", "group", "role"]
        model = models.Users

class Info(APIView):

    def get(self, request, *args, **kwargs):
        users = models.Users.objects.all()
        ser = UserSerializer(instance=users, many=True)
        return BruceResponse(0, "OK", ser.data)

7.2 对请求数据验证

7.2.1 调用的方式

从源码中来看HOOK的定义方式总共有两处,分别是在field的init时传入钩子实例及serializers中定义的特定名称函数-如验证username使用validate_username

# 自定义的钩子类
class BruceValidator(object):
    def __init__(self, base):
        self.base = base
    def __call__(self, value, *args, **kwargs):
        if not value.startswith(self.base):
            raise serializers.ValidationError(f"必须以{self.base}开头.")
    def set_content(self, serializer_field):
        pass

# django内部的钩子类
from django.core.validators import MaxLengthValidator, validate_ipv4_address

class InfoSerializer(serializers.Serializer):
    username = serializers.CharField(error_messages={'required': "名字不能为空"}, 
                                     # 调用钩子的方式1
                                     validators=[BruceValidator, ])
	
  	# 调用钩子的方式2
    # 两种方式可以并存,方式1咸鱼方式2调用
    def validate_username(self, value):
        print(f"hello, I am a function...{value}")
        # 验证是否为
        # BruceValidator("Bruce")(value)
        # MaxLengthValidator(10)(value)
        validate_ipv4_address(value)
        raise serializers.ValidationError(f"{value}", code="bruceDefined")
        return value
      
class Info(APIView):
    def get(self, request, *args, **kwargs):
        users = models.Users.objects.all()
        ser = UserSerializer(instance=users, many=True, context={'request': request})
        return BruceResponse(0, "OK", ser.data)

    def post(self, request, *args, **kwargs):
        ser = InfoSerializer(data=request.data)
        if ser.is_valid():
            print(ser.validated_data['username'])
        else:
            print(ser.errors)
        return BruceResponse(0, "OK", ser.errors)

7.2.2 内置的钩子

# 1. django内部的钩子类, 在django.core.validators中,也有些基类可以用来继承与处理
from django.core.validators import MaxLengthValidator, validate_ipv4_address

# 2. rest_framework.validators中有些可以验证唯一的方法;
from rest_framework.validators import BaseUniqueForValidator

# 3. Validator这类钩子的基类
@deconstructible
class BaseValidator:
    # 如果没有message,则给出这条msg吧,猜测是。.
    message = _('Ensure this value is %(limit_value)s (it is %(show_value)s).')
    # 是的limit_value,说明这个hook的作用
    code = 'limit_value'
		
    # 使用限制的值来实例化对象
    def __init__(self, limit_value, message=None):
        self.limit_value = limit_value
        if message:
            self.message = message
		
    # 调用该对象,传入value与limit_value或者limit_value()的返回值进行比较
    # 如果调用compare中规则后,返回True,则说明验证异常,raise异常;
    def __call__(self, value):
        cleaned = self.clean(value)
        limit_value = self.limit_value() if callable(self.limit_value) else self.limit_value
        params = {'limit_value': limit_value, 'show_value': cleaned, 'value': value}
        if self.compare(cleaned, limit_value):
            raise ValidationError(self.message, code=self.code, params=params)
		
    # 如果调用 = 时调用当前方法
    def __eq__(self, other):
        return (
            isinstance(other, self.__class__) and
            self.limit_value == other.limit_value and
            self.message == other.message and
            self.code == other.code
        )
		
    # 对clean后的值的验证,True为不通过,False为通过
    def compare(self, a, b):
        return a is not b
		
    # 对传入值的处理
    def clean(self, x):
        return x

7.3 序列化的源码

7.3.1 Serializer的实例化

Serializer的起始在view中的调用:


class Info(APIView):
    def get(self, request, *args, **kwargs):
        users = models.Users.objects.all()
        # 实例化一个Serialzer对象
        ser = UserSerializer(instance=users, many=True, context={'request': request})
        return BruceResponse(0, "OK", ser.data)

实例化对象调用__new__方法,内部根据many=True/False分别实例话出两个不同的对象:

  1. True:实例化出的是ListSerializer对象,内部的child为UserSerializer对象;
  2. False:实例化出的直接是UserSerializer对象;
class BaseSerializer(Field):
		def __new__(cls, *args, **kwargs):
        # 根据many为True或False来决定实例化的类
        # 1. True-ListSerializer
        # 2. False=使用Field的New方法创建实例
        if kwargs.pop('many', False):
            return cls.many_init(*args, **kwargs)
        return super().__new__(cls, *args, **kwargs)
      
    @classmethod
    def many_init(cls, *args, **kwargs):
        allow_empty = kwargs.pop('allow_empty', None)
        # 关键点1:在kwargs.pop("many")后,可以进行单实例的初始化
        child_serializer = cls(*args, **kwargs)
        list_kwargs = {
            'child': child_serializer,
        }
        if allow_empty is not None:
            list_kwargs['allow_empty'] = allow_empty
        list_kwargs.update({
            key: value for key, value in kwargs.items()
            if key in LIST_SERIALIZER_KWARGS
        })
        # 获取meta
        meta = getattr(cls, 'Meta', None)
        # 关键点2:如果meta中没有特殊指定,则使用ListSerializer
        list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer)
        # 关键点3:使用过滤后的list_kwargs初始化则使用ListSerializer
        return list_serializer_class(*args, **list_kwargs)
     

而ListSerializer的特殊点在于其重写了to_representation方法,循环调用child的方法生成列表;

class ListSerializer(BaseSerialzer):    
  	def to_representation(self, data):
        """
        List of object instances -> List of dicts of primitive datatypes.
        """
        # Dealing with nested relationships, data can be a Manager,
        # so, first get a queryset from the Manager if needed
        iterable = data.all() if isinstance(data, models.Manager) else data
        # 列表生成式中调用原Serializer实例
        return [
            self.child.to_representation(item) for item in iterable
        ]

7.3.2 Serializer取数据

ListSerializer比UserSerializer多了一步自身的to_representation调用,在ListSerializer.to_representation中多次调用了UserSerializer的to_representation方法;

Class BaseSerializer(Field):    
  	def data(self):
            # 关键点1:调用当前类的o_representation方法,这里直接跳过ListSerializer,从User的开始
            self._data = self.to_representation(self.instance)
        return self._data

在父类Serializer中找到了to_representation方法

    def to_representation(self, instance):
        """
        Object instance -> Dict of primitive datatypes.
        """
        ret = OrderedDict() # 返回的是一个有序字典
        fields = self._readable_fields  # 获取可读的fields

        # 循环当前row的所有fields去解析内部的属性值
        for field in fields: # field即field对象,如CharField(source='id')
            try:
                # 去数据库中取到当前field对应的值
                # 如id=1
                # 但是如果是HyperlinkedIdentityField:返回PKOnlyObject(pk=value)
                attribute = field.get_attribute(instance)
            except SkipField:
                continue

            if check_for_none is None:
                ret[field.field_name] = None
            else:
              	# 对各个字段取出的值,进行representation处理,让其可以被json处理之类
                # 调用field字段的to_representation,比如charfield的就是
                # def to_representation(self, value):
                # 	   return str(value)
                ret[field.field_name] = field.to_representation(attribute)
        return ret

7.3.3 Serializer验证数据

由Serializer.is_valid()触发,会调用BaseSerializer的is_valid()方法:

class BaseSerializer(Field):
    def is_valid(self, raise_exception=False):
        assert hasattr(self, 'initial_data'), (
            'Cannot call `.is_valid()` as no `data=` keyword argument was '
            'passed when instantiating the serializer instance.'
        )

        if not hasattr(self, '_validated_data'):
            try:
                # 关键点:调用obj的run_validation来验证数据
                self._validated_data = self.run_validation(self.initial_data)
            except ValidationError as exc:
                self._validated_data = {}
                self._errors = exc.detail
            else:
                self._errors = {}

        if self._errors and raise_exception:
            raise ValidationError(self.errors)

        return not bool(self._errors)

在Serializer中找到run_validation方法

class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
    def run_validation(self, data=empty):
        """
        We override the default `run_validation`, because the validation
        performed by validators and the `.validate()` method should
        be coerced into an error dictionary with a 'non_fields_error' key.
        """
        (is_empty_value, data) = self.validate_empty_values(data)
        if is_empty_value:
            return data
				
        # 关键点
        value = self.to_internal_value(data)
        try:
            self.run_validators(value)
            value = self.validate(value)
            assert value is not None, '.validate() should return the validated data'
        except (ValidationError, DjangoValidationError) as exc:
            raise ValidationError(detail=as_serializer_error(exc))

        return value

在Serializer中找到to_internal_value

class Serializer(BaseSerializer, metaclass=SerializerMetaclass):    
  	    def to_internal_value(self, data):
        """
        Dict of native values <- Dict of primitive datatypes.
        """
        if not isinstance(data, Mapping):
            message = self.error_messages['invalid'].format(
                datatype=type(data).__name__
            )
            raise ValidationError({
                api_settings.NON_FIELD_ERRORS_KEY: [message]
            }, code='invalid')

        ret = OrderedDict()
        errors = OrderedDict()
        fields = self._writable_fields

        for field in fields:
            # 获取自定义的钩子函数
            validate_method = getattr(self, 'validate_' + field.field_name, None)
            # 原始数据
            primitive_value = field.get_value(data)
            try:
                # 验证配置的validation或者基于field类型自动生成validation
                validated_value = field.run_validation(primitive_value)
                if validate_method is not None:
                    # 这个是我们自定义额外添加的
                    validated_value = validate_method(validated_value)
            except ValidationError as exc:
                errors[field.field_name] = exc.detail
            except DjangoValidationError as exc:
                errors[field.field_name] = get_error_detail(exc)
            except SkipField:
                pass
            else:
                set_value(ret, field.source_attrs, validated_value)

        if errors:
            raise ValidationError(errors)

        # 如果中间有异常则抛出异常,否则返回结果
        return ret

八. 分页

在rest_framework中共有3种分页,分别是普通,位置及隐藏;

8.1 普通分页

看第n页,每页显示m条数据;

from bruceApp.utils.serializers.myserializer import RoleSerializer
from rest_framework import pagination

class MyPageNumberPagination(pagination.PageNumberPagination):
    # 一页的显示数量
    #  page_size = api_settings.PAGE_SIZE # 父类中调用的是配置,所以可以在setting中配置
    page_size = 2
    # 可以通过get携带的参数来变更page_size
    # 默认为空,重写为size,则可以通过?size=10这种来更改
    page_size_query_param = 'size'
    # 限制最大page_size的大小,默认没有限制
    max_page_size = 5
    # 用于指定页码的参数名,默认就是page
    page_query_param = "page"


class Role(APIView):

    def get(self, request, *args, **kwargs):
        roles = models.Role.objects.all()
        # 实例化对象
        pg = MyPageNumberPagination()
        # 调用paginate_queryset返回分页后的对象
        pager_roles = pg.paginate_queryset(queryset=roles, request=request, view=self)
        # queryset序列化
        ser = RoleSerializer(instance=pager_roles, many=True)
        # 返回结果
        return Response(ser.data)
        # 也可以通过pg中的一个函数来返回值
        # 这种返回会带上一些额外参数比如count,next及previous对应的url
        return pg.get_paginated_response(ser.data)

8.2 位置分页

在n条这个位置,向后查看n条数据;

from rest_framework import pagination
from bruceApp.utils.serializers.myserializer import RoleSerializer
from django.utils.translation import ugettext_lazy as _
class MyLimitOffsetPagination(pagination.LimitOffsetPagination):
    # default_limit = api_settings.PAGE_SIZE 默认在PAGE_SIZE中配置
    default_limit = 2
    limit_query_param = 'limit' # 区别于page
    limit_query_description = _('Number of results to return per page.')
    offset_query_param = 'offset' # 偏移量
    offset_query_description = _('The initial index from which to return the results.')
    max_limit = 10 # 最大limit默认没写
    template = 'rest_framework/pagination/numbers.html'

class Role(APIView):

    def get(self, request, *args, **kwargs):
        roles = models.Role.objects.all()
        # 实例化对象
        pg = MyLimitOffsetPagination()
        # 调用paginate_queryset返回分页后的对象
        pager_roles = pg.paginate_queryset(queryset=roles, request=request, view=self)
        # queryset序列化
        ser = RoleSerializer(instance=pager_roles, many=True)
        # 返回结果
        # return Response(ser.data)
        # 也可以通过pg中的一个函数来返回值
        # 这种返回会带上一些额外参数比如count,next及previous对应的url
        return pg.get_paginated_response(ser.data)

8.3 隐藏分页

仅可看上一页和下一页, 页码被加密,好处就是用户只能按照顺序来获取,加快读取速度;

from rest_framework import pagination
from bruceApp.utils.serializers.myserializer import RoleSerializer
from django.utils.translation import ugettext_lazy as _

class MyCursorPagination(pagination.CursorPagination):
    cursor_query_param = 'cursor' # 默认的字段
    cursor_query_description = _('The pagination cursor value.')
    # page_size = api_settings.PAGE_SIZE, 分页大小默认也是这么大
    page_size = 5
    invalid_cursor_message = _('Invalid cursor')
    # ordering = '-created' # 以什么来排序, 加-,意为反序
    ordering = 'id'
    template = 'rest_framework/pagination/previous_and_next.html'

    # Client can control the page size using this query parameter.
    # Default is 'None'. Set to eg 'page_size' to enable usage.
    page_size_query_param = None
    page_size_query_description = _('Number of results to return per page.')

    # Set to an integer to limit the maximum page size the client may request.
    # Only relevant if 'page_size_query_param' has also been set.
    max_page_size = 100

    # 限制最大偏移量,防止用户恶意调用
    offset_cutoff = 1000

class Role(APIView):

    def get(self, request, *args, **kwargs):
        roles = models.Role.objects.all()
        # 实例化对象
        pg = MyCursorPagination()
        # 调用paginate_queryset返回分页后的对象
        pager_roles = pg.paginate_queryset(queryset=roles, request=request, view=self)
        # queryset序列化
        ser = RoleSerializer(instance=pager_roles, many=True)
        # 返回结果
        # 这种返回会带上一些额外参数比如count,next及previous对应的url-对隐藏的分页方式尤为重要
        return pg.get_paginated_response(ser.data)

image-20210823224242887

8.4 分页数据库查询

Question:如果数据量大的话,如何分页?

1. 使用数据库主键,记住当前页的最大值与最小值,在前一页或者后一页的时候可以迅速加载;
2. 使用隐藏分页的方式,限制用户的恶意访问,降低数据库的压力;

九. 视图View

在django中我们使用View,在rest_framework中有多个View的实现

9.1 APIView

APIView是rest_framework的基础view里面主要体现:

  1. 调用api_setting中的一些配置,将其封装进APIView中;

  2. 重写了def dispatch(self, request, *args, **kwargs)方法,在其中添加了些额外功能(组件):

    class APIView(View):
          def dispatch(self, request, *args, **kwargs):
            """
            `.dispatch()` is pretty much the same as Django's regular dispatch,
            but with extra hooks for startup, finalize, and exception handling.
            # 整体类似于django的dispatch但是加了些在启动,最终能处理及异常的handling
            """
            self.args = args
            self.kwargs = kwargs
    
            # 1. 使用DRF的Request类对request对象进行二层封装
            # 新增:parsers, authenticators, negotiator, parser_context这几个属性
            # 新增:_request属性,指向原request对象
            # 新增:一些方法
            request = self.initialize_request(request, *args, **kwargs)
            self.request = request
            # 获取request请求头,并封装allow-method进去
            self.headers = self.default_response_headers  # deprecate?
    
            try:
                # Runs anything that needs to occur prior to calling the method handler.
                # 在调用handler处理请求前,把所有需要处理的都处理了
                # 包括格式化kwarg,执行content-type协商,解析version及验证request是否允许(登录与权限);
                self.initial(request, *args, **kwargs)
    
                # Get the appropriate handler method
                # 通过反射获取允许的执行方法, 类似于django内部
                if request.method.lower() in self.http_method_names:
                    handler = getattr(self, request.method.lower(),
                                      self.http_method_not_allowed)
                else:
                    handler = self.http_method_not_allowed
    
                #
                response = handler(request, *args, **kwargs)
    
            except Exception as exc:
                # 异常处理的handle
                response = self.handle_exception(exc)
    
            self.response = self.finalize_response(request, response, *args, **kwargs)
            return self.response
    

9.2 GenericAPIView

通用APIView是对APIView的又一层封装,本身并不是很强大,但是封装了一些属性与方法,方便其他的APIView基于它来实现;

from rest_framework.generics import GenericAPIView
from bruceApp.utils.serializers.myserializer import RoleSerializer

class Role(GenericAPIView):
    queryset = models.Role.objects.all()
    serializer_class = RoleSerializer
    pagination_class = MyCursorPagination

    def get(self, request, *args, **kwargs):
        roles = self.get_queryset()
        # 实例化对象
        pager_roles = self.paginate_queryset(roles)

        # queryset序列化
        ser = self.get_serializer(instance=pager_roles, many=True)
        # 返回结果
        # 这种返回会带上一些额外参数比如count,next及previous对应的url-对隐藏的分页方式尤为重要
        return self.paginator.get_paginated_response(ser.data)

9.3 GenericViewSet

GenericViewSet是在GenericAPIView的基础上mix了ViewSetMixin所生成的一个新类,在这个新类中as-view方法被ViewSetMixin所重写,需要传入以method:funname键值对组成的字典,用于在相关的请求进来时候的处理,也就是说get请求不应要在View中写get了,可以写成其他的;

class ViewSetMixin:
    """
    This is the magic.

    Overrides `.as_view()` so that it takes an `actions` keyword that performs
    the binding of HTTP methods to actions on the Resource.

    For example, to create a concrete view binding the 'GET' and 'POST' methods
    to the 'list' and 'create' actions...

    view = MyViewSet.as_view({'get': 'list', 'post': 'create'})
    """

    @classonlymethod
    def as_view(cls, actions=None, **initkwargs):
        """
        Because of the way class based views create a closure around the
        instantiated view, we need to totally reimplement `.as_view`,
        and slightly modify the view function that is created and returned.
        """
        ...

对上面的GenericAPIView进行一个简单的重写

# 在url中, 路由系统开始发生变化,GenericAPIView有个存在的意义
urlpatterns = [
		# 需要在as_view中添加method:all_get字典
    re_path('^(?P<v>v\d+)/role/$', bruce_view.Role.as_view({"get": "all_get"}), 
            name="role"),
]

# 在view中
from rest_framework.viewsets import GenericViewSet
from bruceApp.utils.serializers.myserializer import RoleSerializer

class Role(GenericViewSet):
    queryset = models.Role.objects.all()
    serializer_class = RoleSerializer
    pagination_class = MyCursorPagination

    def all_get(self, request, *args, **kwargs):
        roles = self.get_queryset()
        # 实例化对象
        pager_roles = self.paginate_queryset(roles)

        # queryset序列化
        ser = self.get_serializer(instance=pager_roles, many=True)
        # 返回结果
        # 这种返回会带上一些额外参数比如count,next及previous对应的url-对隐藏的分页方式尤为重要
        return self.paginator.get_paginated_response(ser.data)

9.4 ModelViewSet

相较于GenericViewSet视图,ModelViewSet更进一步的mix了多个Mixin,每个添加额外的方法:

class ModelViewSet(mixins.CreateModelMixin, # create方法-创建单个instance
                   mixins.RetrieveModelMixin, # retrieve方法-取回单个instance----需要id
                   mixins.UpdateModelMixin, # update方法-更新单个instance------需要id
                   mixins.DestroyModelMixin, # destroy方法-删除单个instance----需要id
                   mixins.ListModelMixin, # list方法-展示一组instance
                   GenericViewSet):
    """
    A viewset that provides default `create()`, `retrieve()`, `update()`,
    `partial_update()`, `destroy()` and `list()` actions.
    """
    pass

ModelViewSet的使用需要与路由配合,从上面可以看出有三个方法需要id,有两个方法不需要id,那么我们就可以为这个model设计两个URL来实现增删改查;

# 在url中, 针对不需要id和需要id的分为两种,然后分别映射对应名称的方法,完成增删改查
# !!!! 在这里面可以看到可以同一个get被分开了,get一个或get多个
urlpatterns = [
    re_path('^(?P<v>v\d+)/role/$',
            bruce_view.Role.as_view({"get": "list",
                                     "post": "create"}), name="roles"),
    re_path('^(?P<v>v\d+)/role/(?P<pk>\d+)/$',
            bruce_view.Role.as_view({"get": "retrieve",
                                     "put": "update",
                                     "delete": "destroy"}), name="role"),
]

from rest_framework.viewsets import ModelViewSet
from bruceApp.utils.serializers.myserializer import RoleSerializer

class Role(ModelViewSet):
    queryset = models.Role.objects.all()
    serializer_class = RoleSerializer
    pagination_class = MyCursorPagination

9.5 使用总结

1. 如果你只需要实现简单的增删改查,可以直接进程ModelViewSet;
2. 如果你需要实现基于增删改查的但是较复杂的功能,可以继承GenericViewSet然后调用/重写mixin来适配;
3. 建议继承GenericViewSet,这样可以实现不同get请求的区分;

十. 路由与渲染

10.1 路由

结合渲染器一个module的路由可以写出4中格式:

urlpatterns = [
    re_path('^(?P<v>v\d+)/role/$',
            bruce_view.Role.as_view({"get": "list",
                                     "post": "create"}), name="roles"),
    re_path('^(?P<v>v\d+)/role\.(?P<format>\w+)/$',
            bruce_view.Role.as_view({"get": "list",
                                     "post": "create"}), name="roles"),
    re_path('^(?P<v>v\d+)/role/(?P<pk>\d+)/$',
            bruce_view.Role.as_view({"get": "retrieve",
                                     "put": "update",
                                     "delete": "destroy"}), name="role"),
    re_path('^(?P<v>v\d+)/role/(?P<pk>\d+)\.(?P<format>\w+)/$',
            bruce_view.Role.as_view({"get": "retrieve",
                                     "put": "update",
                                     "delete": "destroy"}), name="role"),
]

# 主要结合渲染器,通过format传入渲染器支持的格式,比如json;
# 如果不写也ok,可以通过?format=json这种方式传入
# 现在可以通过http://127.0.0.1:8002/api/v1000/role.json/方式传入

drf提供自动生成路由的方式,可以用省事,也可以不用-因为生成的路由过多;

from django.urls import re_path, include
from bruceApp import views as bruce_view

app_name = "bruceApp"
from rest_framework import routers

router = routers.DefaultRouter()
router.register(r'role', bruce_view.Role)

urlpatterns = [
    re_path('', include(router.urls)),
]

image-20210824003028435

10.2 渲染器

通过不同的方法来对返回的数据进行渲染,默认的是JSONRenderer和BrowsableAPIRenderer。

根据 用户请求URL 或 用户可接受的类型,筛选出合适的 渲染组件。
用户请求URL:

用户请求头:

  • Accept:text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,/;q=0.8
from rest_framework.viewsets import ModelViewSet
from bruceApp.utils.serializers.myserializer import RoleSerializer
from rest_framework.renderers import JSONRenderer, JSONOpenAPIRenderer, BrowsableAPIRenderer, AdminRenderer
class Role(ModelViewSet):
    renderer_classes = [JSONRenderer, BrowsableAPIRenderer, AdminRenderer]
    queryset = models.Role.objects.all()
    serializer_class = RoleSerializer
    pagination_class = MyCursorPagination

如admin的显示的页面admin,这个目前没大用吧,将来可能有用;

image-20210824004602264

参考:Django Rest Framework

posted @ 2021-08-24 00:54  FcBlogs  阅读(54)  评论(0编辑  收藏  举报