深入解析当下大热的前后端分离组件django-rest_framework系列三

三剑客之认证、权限与频率组件

认证组件

局部视图认证

在app01.service.auth.py:

复制代码
复制代码
class Authentication(BaseAuthentication):

    def authenticate(self,request):
        token=request._request.GET.get("token")
        token_obj=UserToken.objects.filter(token=token).first()
        if not token_obj:
            raise exceptions.AuthenticationFailed("验证失败!")
        return (token_obj.user,token_obj)
复制代码
复制代码

在views.py:

复制代码
复制代码
def get_random_str(user):
    import hashlib,time
    ctime=str(time.time())

    md5=hashlib.md5(bytes(user,encoding="utf8"))
    md5.update(bytes(ctime,encoding="utf8"))

    return md5.hexdigest()


from app01.service.auth import *

from django.http import JsonResponse
class LoginViewSet(APIView):
    authentication_classes = [Authentication,]
    def post(self,request,*args,**kwargs):
        res={"code":1000,"msg":None}
        try:
            user=request._request.POST.get("user")
            pwd=request._request.POST.get("pwd")
            user_obj=UserInfo.objects.filter(user=user,pwd=pwd).first()
            print(user,pwd,user_obj)
            if not user_obj:
                res["code"]=1001
                res["msg"]="用户名或者密码错误"
            else:
                token=get_random_str(user)
                UserToken.objects.update_or_create(user=user_obj,defaults={"token":token})
                res["token"]=token

        except Exception as e:
            res["code"]=1002
            res["msg"]=e

        return JsonResponse(res,json_dumps_params={"ensure_ascii":False})
复制代码
复制代码

备注:一个知识点:update_or_create(参数1,参数2...,defaults={‘字段’:'对应的值'}),这个方法使用于:如果对象存在,则进行更新操作,不存在,则创建数据。使用时会按照前面的参数进行filter,结果为True,则执行update  defaults中的值,否则,创建。

使用方式: 

  使用认证组件时,自定制一个类,在自定制的类下,必须实现方法:authenticate(self,request),必须接收一个request参数,结果必须返回一个含有两个值的元组,认证失败,就抛一个异常rest_framework下的exceptions.APIException。然后在对应的视图类中添加一个属性:authentication_classes=[自定制类名]即可。

 备注:from rest_framework.exceptions import AuthenticationFailed   也可以抛出这个异常(这个异常是针对认证的异常)

使用这个异常时,需要在自定制的类中,定制一个方法:def authenticate_header(self,request):pass,每次这样定制一个方法很麻烦,我们可以在自定义类的时候继承一个类: BaseAuthentication

调用方式:from rest_framework.authentication import BaseAuthentication

 

为什么是这个格式,我们看看这个组件内部的 实现吧!
不管是认证,权限还是频率都发生在分发之前:

复制代码
class APIView(View):
        def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.
        """
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            self.initial(request, *args, **kwargs)

            # Get the appropriate handler method
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
复制代码

所有的实现都发生在       self.initial(request, *args, **kwargs)        这行代码中

复制代码
class APIView(View):
     def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
       #处理版本信息 version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme # Ensure that the incoming request is permitted
      # 认证 self.perform_authentication(request)
      # 权限 self.check_permissions(request)
      # 用户访问频率的限制 self.check_throttles(request)
复制代码

我们先看认证相关   self.perform_authentication

复制代码
class APIView(View):
        def perform_authentication(self, request):
        """
        Perform authentication on the incoming request.

        Note that if you override this and simply 'pass', then authentication
        will instead be performed lazily, the first time either
        `request.user` or `request.auth` is accessed.
        """
        request.user
复制代码

整个这个方法就执行了一句代码:request.user   ,这个request是新的request,是Request()类的实例对象,  .user 肯定在Request类中有一个user的属性方法。

复制代码
class Request(object):
    @property
    def user(self):
        """
        Returns the user associated with the current request, as authenticated
        by the authentication classes provided to the request.
        """
        if not hasattr(self, '_user'):
            with wrap_attributeerrors():
                self._authenticate()
        return self._user
复制代码

这个user方法中,执行了一个 self._authenticate()方法,继续看这个方法中执行了什么:

复制代码
class Request(object):
        def _authenticate(self):
        """
        Attempt to authenticate the request using each authentication instance
        in turn.
        """
        for authenticator in self.authenticators:
            try:
                user_auth_tuple = authenticator.authenticate(self)
            except exceptions.APIException:
                self._not_authenticated()
                raise

            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                return

        self._not_authenticated()
复制代码

这个方法中,循环一个东西:self.authenticators  ,这个self,是新的request,我们看看这个self.authenticators是什么。

复制代码
class Request(object):
    """
    Wrapper allowing to enhance a standard `HttpRequest` instance.

    Kwargs:
        - request(HttpRequest). The original request instance.
        - parsers_classes(list/tuple). The parsers to use for parsing the
          request content.
        - authentication_classes(list/tuple). The authentications used to try
          authenticating the request's user.
    """

    def __init__(self, request, parsers=None, authenticators=None,
                 negotiator=None, parser_context=None):
       
        self._request = request
        self.parsers = parsers or ()
        self.authenticators = authenticators or ()
复制代码

这个self.authenticators是实例化时初始化的一个属性,这个值是实例化新request时传入的参数。

复制代码
class APIView(View):
   
     def initialize_request(self, request, *args, **kwargs):
        """
        Returns the initial request object.
        """
        parser_context = self.get_parser_context(request)

        return Request(
            request,
            parsers=self.get_parsers(),
            authenticators=self.get_authenticators(),
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context
        )
复制代码

实例化传参时,这个参数来自于self.get_authenticators(),这里的self就不是request,是谁调用它既是谁,看看吧:

复制代码
class APIView(View):
        def get_authenticators(self):
        """
        Instantiates and returns the list of authenticators that this view can use.
        """
        return [auth() for auth in self.authentication_classes]
复制代码

这个方法,返回了一个列表推导式   [auth() for auth in self.authentication_classes] ,到了这里,是不是很晕,这个self.authentication_classes又是啥,这不是我们在视图函数中定义的属性吗,值是一个列表,里面放着我们自定义认证的类

class APIView(View):

    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES

如果我们不定义这个属性,会走默认的值  api_settings.DEFAULT_AUTHENTICATION_CLASSES

  走了一圈了,我们再走回去呗,首先,这个列表推导式中放着我们自定义类的实例,回到这个方法中def _authenticate(self),当中传的值self指的是新的request,这个我们必须搞清楚,然后循环这个包含所有自定义类实例的列表,得到一个个实例对象。

 

for authenticator in self.authenticators:
try: user_auth_tuple = authenticator.authenticate(self) except exceptions.APIException: self._not_authenticated() raise if user_auth_tuple is not None: self._authenticator = authenticator self.user, self.auth = user_auth_tuple return

 

然后执行每一个实例对象的authenticate(self)方法,也就是我们在自定义类中必须实现的方法,这就是原因,因为源码执行时,会找这个方法,认证成功,返回一个元组 ,认证失败,捕捉一个异常,APIException。认证成功,这个元组会被self.user,self.auth接收值,所以我们要在认证成功时返回含有两个值的元组。这里的self是我们新的request,这样我们在视图函数和模板中,只要在这个request的生命周期,我们都可以通过request.user得到我们返回的第一个值,通过request.auth得到我们返回的第二个值。这就是认证的内部实现,很牛逼。

备注:所以我们认证成功后返回值时,第一个值,最好时当前登录人的名字,第二个值,按需设置,一般是token值(标识身份的随机字符串)

 这种方式只是实现了对某一张表的认证,如果我们有100张表,那这个代码我就要写100遍,复用性很差,所以需要在全局中定义。

全局视图认证组件

settings.py配置如下:

1
2
3
REST_FRAMEWORK={
    "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.Authentication",]
}

为什么要这样配置呢?我们看看内部的实现吧:

  从哪开始呢?在局部中,我们在视图类中加一个属性authentication_classes=[自定义类],那么在内部,肯定有一个默认的值:

class APIView(View):

    authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES

默认值是api_settings中的一个属性,看看这个默认值都实现了啥

api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)

首先,这个api_settings是APISettings类的一个实例化对象,传了三个参数,那这个DEFAULTS是啥,看看:

复制代码
#rest_framework中的settings.py

DEFAULTS = {
    # Base API policies
    'DEFAULT_RENDERER_CLASSES': (
        'rest_framework.renderers.JSONRenderer',
        'rest_framework.renderers.BrowsableAPIRenderer',
    ),
    'DEFAULT_PARSER_CLASSES': (
        'rest_framework.parsers.JSONParser',
        'rest_framework.parsers.FormParser',
        'rest_framework.parsers.MultiPartParser'
    ),
    'DEFAULT_AUTHENTICATION_CLASSES': (
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.BasicAuthentication'
    ),
    'DEFAULT_PERMISSION_CLASSES': (
        'rest_framework.permissions.AllowAny',
    ),
    'DEFAULT_THROTTLE_CLASSES': (),
    'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
    'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
    'DEFAULT_VERSIONING_CLASS': None,

    # Generic view behavior
    'DEFAULT_PAGINATION_CLASS': None,
    'DEFAULT_FILTER_BACKENDS': (),

    # Schema
    'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',

    # Throttling
    'DEFAULT_THROTTLE_RATES': {
        'user': None,
        'anon': None,
    },
    'NUM_PROXIES': None,

    # Pagination
    'PAGE_SIZE': None,

    # Filtering
    'SEARCH_PARAM': 'search',
    'ORDERING_PARAM': 'ordering',

    # Versioning
    'DEFAULT_VERSION': None,
    'ALLOWED_VERSIONS': None,
    'VERSION_PARAM': 'version',

    # Authentication
    'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
    'UNAUTHENTICATED_TOKEN': None,

    # View configuration
    'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
    'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',

    # Exception handling
    'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
    'NON_FIELD_ERRORS_KEY': 'non_field_errors',

    # Testing
    'TEST_REQUEST_RENDERER_CLASSES': (
        'rest_framework.renderers.MultiPartRenderer',
        'rest_framework.renderers.JSONRenderer'
    ),
    'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',

    # Hyperlink settings
    'URL_FORMAT_OVERRIDE': 'format',
    'FORMAT_SUFFIX_KWARG': 'format',
    'URL_FIELD_NAME': 'url',

    # Input and output formats
    'DATE_FORMAT': ISO_8601,
    'DATE_INPUT_FORMATS': (ISO_8601,),

    'DATETIME_FORMAT': ISO_8601,
    'DATETIME_INPUT_FORMATS': (ISO_8601,),

    'TIME_FORMAT': ISO_8601,
    'TIME_INPUT_FORMATS': (ISO_8601,),

    # Encoding
    'UNICODE_JSON': True,
    'COMPACT_JSON': True,
    'STRICT_JSON': True,
    'COERCE_DECIMAL_TO_STRING': True,
    'UPLOADED_FILES_USE_URL': True,

    # Browseable API
    'HTML_SELECT_CUTOFF': 1000,
    'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",

    # Schemas
    'SCHEMA_COERCE_PATH_PK': True,
    'SCHEMA_COERCE_METHOD_NAMES': {
        'retrieve': 'read',
        'destroy': 'delete'
    },
}
复制代码

很直观的DEFAULTS是一个字典,包含多组键值,key是一个字符串,value是一个元组,我们需要的数据:

'DEFAULT_AUTHENTICATION_CLASSES': ( 'rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.BasicAuthentication' ), 这有啥用呢,api_settings是怎么使用这个 参数的呢

复制代码
class APISettings(object):
    """
    A settings object, that allows API settings to be accessed as properties.
    For example:

        from rest_framework.settings import api_settings
        print(api_settings.DEFAULT_RENDERER_CLASSES)

    Any setting with string import paths will be automatically resolved
    and return the class, rather than the string literal.
    """
    def __init__(self, user_settings=None, defaults=None, import_strings=None):
        if user_settings:
            self._user_settings = self.__check_user_settings(user_settings)
        self.defaults = defaults or DEFAULTS
        self.import_strings = import_strings or IMPORT_STRINGS
        self._cached_attrs = set()
复制代码

好像也看不出什么有用的信息,只是一些赋值操作,将DEFAULTS赋给了self.defaults,那我们再回去看,

在视图函数中,我们不定义authentication_classes 就会执行默认的APIView下的authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES,通过这个api_settings对象调用值时,调不到会怎样,我们看看怎么取值吧,首先会在__getattribute__方法中先找,找不到,取自己的属性中找,找不到去本类中找,在找不到就去父类中找,再找不到就去__getattr__中找,最后报错。按照这个逻辑,执行默认 值时最后会走到类中的__getattr__方法中

 

复制代码
class APISettings(object):    
    def __getattr__(self, attr):
        if attr not in self.defaults:
            raise AttributeError("Invalid API setting: '%s'" % attr)

        try:
            # Check if present in user settings
            val = self.user_settings[attr]
        except KeyError:
            # Fall back to defaults
            val = self.defaults[attr]

        # Coerce import strings into classes
        if attr in self.import_strings:
            val = perform_import(val, attr)

        # Cache the result
        self._cached_attrs.add(attr)
        setattr(self, attr, val)
        return val
复制代码

 

备注:调用__getattr__方法时,会将取的值,赋值给__getattr__中的参数attr,所有,此时的attr是DEFAULT_AUTHENTICATION_CLASSES,这个值在self.defaluts中,所以,代码会执行到try中,val=self.user_settings[attr],这句代码的意思是,val的值是self.user_settings这个东西中取attr,所以self.user_settings一定是个字典,我们看看这个东西做l什么

 

复制代码
class APISettings(object):
    @property
    def user_settings(self):
        if not hasattr(self, '_user_settings'):
            self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
        return self._user_settings
复制代码

这个属性方法,返回self._user_settings,这个值从哪里来,看代码逻辑,通过反射,来取settings中的REST_FRAMEWORK的值,取不到,返回一个空字典。那这个settings是哪的,是我们项目的配置文件。

  综合起来,意思就是如果在api_settings对象中找不到这个默认值,就从全局settings中找一个变量REST_FRAMEWORK,这个变量是个字典,从这个字典中找attr(DEFAULT_AUTHENTICATION_CLASSES)的值,找不到,返回一个空字典。而我们的配置文件settings中并没有这个变量,很明显,我们可以在全局中配置这个变量,从而实现一个全局的认证。怎么配?

  配置格式:REST_FRAMEWORK={“DEFAULT_AUTHENTICATION_CLASSES”:("认证代码所在路径","...")}

代码继续放下执行,执行到if attr in self.import_strings:       val = perform_import(val, attr)   这两行就开始,根据取的值去通过字符串的形式去找路径,最后得到我们配置认证的类,在通过setattr(self, attr, val)   实现,调用这个默认值,返回配置类的执行。

 

复制代码
def perform_import(val, setting_name):
    """
    If the given setting is a string import notation,
    then perform the necessary import or imports.
    """
    if val is None:
        return None
    elif isinstance(val, six.string_types):
        return import_from_string(val, setting_name)
    elif isinstance(val, (list, tuple)):
        return [import_from_string(item, setting_name) for item in val]
    return val


def import_from_string(val, setting_name):
    """
    Attempt to import a class from a string representation.
    """
    try:
        # Nod to tastypie's use of importlib.
        module_path, class_name = val.rsplit('.', 1)
        module = import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
        raise ImportError(msg)
复制代码

 

备注:通过importlib模块的import_module找py文件。

 

权限组件

局部视图权限

在app01.service.permissions.py中:

复制代码
from rest_framework.permissions import BasePermission
class SVIPPermission(BasePermission):
    message="SVIP才能访问!"
    def has_permission(self, request, view):
        if request.user.user_type==3:
            return True
        return False
复制代码

在views.py:

复制代码
from app01.service.permissions import *

class BookViewSet(generics.ListCreateAPIView):
    permission_classes = [SVIPPermission,]
    queryset = Book.objects.all()
    serializer_class = BookSerializers
复制代码

全局视图权限

settings.py配置如下:

1
2
3
4
REST_FRAMEWORK={
    "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.Authentication",],
    "DEFAULT_PERMISSION_CLASSES":["app01.service.permissions.SVIPPermission",]
}

权限组件一样的逻辑,同样的配置,要配置权限组件,就要在视图类中定义一个变量permission_calsses = [自定义类]

同样的在这个自定义类中,要定义一个固定的方法   def has_permission(self,request,view)  必须传两个参数,一个request,一个是

当前视图类的对象。权限通过返回True,不通过返回False,我们可以定制错误信息,在自定义类中配置一个静态属性message="错误信息",也可以继承一个类BasePermission。跟认证一样。

复制代码
    def check_permissions(self, request):
        """
        Check if the request should be permitted.
        Raises an appropriate exception if the request is not permitted.
        """
        for permission in self.get_permissions():
            if not permission.has_permission(request, self):
                self.permission_denied(
                    request, message=getattr(permission, 'message', None)
                )
复制代码

逻辑很简单,循环这个自定义类实例对象列表,从每一个对象中找has_permission,if not  返回值:表示返回False,,就会抛一个异常,这个异常的信息,会先从对象的message属性中找。

  同样的,配置全局的权限,跟认证一样,在settings文件中的REST_FRAMEWORK字典中配一个键值即可。

 

throttle(访问频率)组件

局部视图throttle

在app01.service.throttles.py中:

复制代码
复制代码
from rest_framework.throttling import BaseThrottle

VISIT_RECORD={}
class VisitThrottle(BaseThrottle):

    def __init__(self):
        self.history=None

    def allow_request(self,request,view):
        remote_addr = request.META.get('REMOTE_ADDR')
        print(remote_addr)
        import time
        ctime=time.time()

        if remote_addr not in VISIT_RECORD:
            VISIT_RECORD[remote_addr]=[ctime,]
            return True

        history=VISIT_RECORD.get(remote_addr)
        self.history=history

        while history and history[-1]<ctime-60:
            history.pop()

        if len(history)<3:
            history.insert(0,ctime)
            return True
        else:
            return False

    def wait(self):
        import time
        ctime=time.time()
        return 60-(ctime-self.history[-1])
复制代码
复制代码

在views.py中:

复制代码
from app01.service.throttles import *

class BookViewSet(generics.ListCreateAPIView):
    throttle_classes = [VisitThrottle,]
    queryset = Book.objects.all()
    serializer_class = BookSerializers
复制代码

全局视图throttle

REST_FRAMEWORK={
    "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.Authentication",],
    "DEFAULT_PERMISSION_CLASSES":["app01.service.permissions.SVIPPermission",],
    "DEFAULT_THROTTLE_CLASSES":["app01.service.throttles.VisitThrottle",]
}

同样的,局部使用访问频率限制组件时,也要在视图类中定义一个变量:throttle_classes = [自定义类],同样在自定义类中也要定义一个固定的方法def allow_request(self,request,view),接收两个参数,里面放我们频率限制的逻辑代码,返回True通过,返回False限制,同时要定义一个def wait(self):pass   放限制的逻辑代码。

内置throttle类

在app01.service.throttles.py修改为:

复制代码
class VisitThrottle(SimpleRateThrottle):

    scope="visit_rate"
    def get_cache_key(self, request, view):

        return self.get_ident(request)
复制代码

settings.py设置:

复制代码
REST_FRAMEWORK={
    "DEFAULT_AUTHENTICATION_CLASSES":["app01.service.auth.Authentication",],
    "DEFAULT_PERMISSION_CLASSES":["app01.service.permissions.SVIPPermission",],
    "DEFAULT_THROTTLE_CLASSES":["app01.service.throttles.VisitThrottle",],
    "DEFAULT_THROTTLE_RATES":{
        "visit_rate":"5/m",
    }
}
复制代码

总结

   restframework的三大套件中为我们提供了多个粒度的控制。局部的管控和全局的校验,都可以很灵活的控制。下一个系列中,将会带来restframework中的查漏补缺。

 

posted @ 2018-08-09 17:26  一抹浅笑  阅读(927)  评论(0编辑  收藏  举报