drf之权限,频率的使用以及认证、权限、频率源码分析、鸭子类型

权限类使用

之前学习了认证类的使用:校验用户是否登录,用token进行认证。

用户登录之后,某些接口可能只有超级管理员才能访问,普通用户不能访问

我们可以设置为出版社的所有接口,必须登录后访问,并且必须是超级管理员才能访问

使用步骤

  • 写一个类,继承BasePermission
  • 重写has_permission方法
  • 在方法中校验是否有权限(request.user就是当前登录用户)
  • 如果有权限,返回True,没有权限,返回False
  • self.message是给前端的提示信息
  • 局部使用,全局使用,局部禁用

代码演示

表模型准备

class User(models.Model):
    username=models.CharField(max_length=32)
    password=models.CharField(max_length=32)
    user_type=models.IntegerField(default=3,choices=((1,'超级管理员'),(2,'普通管理员'),(3,'普通用户')))

自定义权限类

from rest_framework.permissions import BasePermission

class UserTypePermission(BasePermission):
    # 只有超级管理员才有权限
    def has_permission(self, request, view):
        if request.user.user_type==1:
            return True  # 返回True即有权限,后面源码会解释
        else:
            self.message=f'您当前为{request.user.get_user_type_display()},无访问权限'
            return False

出版社视图类使用

class PublishView(ModelViewSet):
    queryset = Publish.objects.all()
    serializer_class =PublishSerializer
    
    authentication_classes = [LoginAuth,]  # 认证类,导入在视图类中为局部使用
    permission_classes = [UserTypePermission,]  # 权限类,导入在视图类中为局部使用

频率类使用

一般来说,无论是否登录和是否有权限,都要限制访问的频率,比如一分钟访问3次

使用步骤

  • 写一个类:继承SimpleRateThrottle
  • 重写get_cache_key,返回唯一的字符串,会以这个字符串做频率限制
  • 写一个类属性scop='xxx',这个'xxx'后续就是配置文件中的键
  • 配置文件中写
'DEFAULT_THROTTLE_RATES': {
        'xxx':'3/m'
    },
  • 局部配置,全局配置,局部禁用

自定义频率类

from rest_framework.throttling import SimpleRateThrottle,BaseThrottle
class LimitThrottle(SimpleRateThrottle):
  # 我们继承SimpleRateThrottle去写,而不是继承BaseThrottle去写,因为SimpleRateThrottle基于BaseThrottle做了封装,代码更为简洁
    scope = 'frequency'
    # 类属性,这个类属性可以随意命名,但是跟配置文件对应
    def get_cache_key(self, request, view):
      # 返回什么,频率就以什么做限制
      # 可以通过用户id限制
      # 可以通过ip地址限制
        return request.META.get('REMOTE_ADDR')

settings中配置:全局配置

REST_FRAMEWORK={
    'DEFAULT_THROTTLE_RATES': {
        'frequency':'3/m'
    },
    'DEFAULT_THROTTLE_CLASSES': [
        'app01.throttling.LimitThrottle'
    ],
}

认证源码分析

在我们写认证类的时候,必须要重写authenticate方法,然后将该类配置在视图类上,该视图就有认证了。

request.user就是当前登录用户,所以认证类的执行, 是在视图类的方法之前执行的,我们可以通过读APIView的源码可以分析出先走三大认证再走视图中方法。

源码分析

APIView的执行流程

  • 包装了新的request
  • 执行了三大认证
  • 执行了视图类的方法
  • 处理了全局异常

认证源码分析的入口:APIView的dispatch

# 在APIView中
self.initial(request, *args, **kwargs)
'分别执行了认证,权限,频率的方法'
# 点进该方法
self.perform_authentication(request)  # 认证
self.check_permissions(request)       # 权限
self.check_throttles(request)         # 频率

读认证类的源码

    def perform_authentication(self, request):
        request.user  # 因为dispatch的执行流程,所以该request是新的request对象

因为该request是新的request,我们去Request类中找user属性(方法),是个方法包装成了数据属性

# 来到Request类中找user:发现是一个property修饰的方法
 @property
    def user(self):
        # 一开始没有_user
        if not hasattr(self, '_user'):  #Request类的对象中反射_user
            with wrap_attributeerrors():
                self._authenticate()  # 第一次会走这个代码
        return self._user

Request类的self._authenticate()

    def _authenticate(self):
        for authenticator in self.authenticators:  # 配置在视图类中所有的认证类的对象 
            try:
                # 这个user_auth_tuple接受的就是我们重写authenticate方法校验用户登录成功的返回值
                # return user_token.user,token
                user_auth_tuple = authenticator.authenticate(self)  # 调用认证类对象的authenticate
            except exceptions.APIException: # AuthenticationFailed的父类捕获
                self._not_authenticated()
                raise

            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple  # self.user 赋值给request:我们用request.user
                # 认证类可以配置多个,但是如果有一个返回了两个值,后续的就不执行了
                return

        self._not_authenticated()

总结

认证类,要重写authenticate方法,认证通过返回两个值或None,认证不通过抛AuthenticationFailed(继承了APIException)异常。

权限源码分析

APIView中三大认证的initial方法中check_permissions(request)方法

    def check_permissions(self, request):
    # permission是我们配置在视图类中权限类的对象,对象调用它的绑定方法has_permission
    # 对象调用自己的绑定方法会把自己传入(权限类的对象,request,视图类的对象)
        for permission in self.get_permissions():
        # 因为该方法在APIView中,所以self是视图类对象
            if not permission.has_permission(request, self): 
                self.permission_denied(
                    request,
                    message=getattr(permission, 'message', None),
                    code=getattr(permission, 'code', None)
                )

APIView的self._permissions()

    def get_permissions(self):
        return [permission() for permission in self.permission_classes]
    """
    self.permission_classes:就是我们在视图类中配的权限类的列表
    所以这个get_peimissions返回的是,我们在视图类中配置的权限类的对象
    """

总结

为什么要写一个类?重写has_permission方法,有三个参数,为什么一定要return True或False,message可以做什么用

  • APIView的self._permissions()会调用这个类,并将该类实例化成对象

  • 实例化后的对象在源码需要调用has_permission方法,不重写直接抛异常

  • 三个参数,因为在APIView中权限类对象调用类中方法,不需要带self,第三个参数的self是视图类view

  • return True或False是因为源码中根据该方法的返回值进行判断,是True就不会走 self.permission_denied()

  • message,是抛异常的提示信息,通过反射判断对象是否有该属性,可以定义给对象,也可以定义在类属性...

简单读频率类源码

 def check_throttles(self, request):
    throttle_durations = []
    for throttle in self.get_throttles():
      if not throttle.allow_request(request, self):  # Flase的话就会走
        throttle_durations.append(throttle.wait())

查看allow_request源码

def allow_request(self, request, view):
    raise NotImplementedError('.allow_request() must be overridden')
    
"可以看出我们使用继承该频率类的时候必须重写allow_request方法"

总结

要写频率类,必须重写allow_request方法,返回True(没有频率的限制)或False(到了频率的限制)

鸭子类型

指的是面向对象中,子类不需要显示的继承某个类,只要有某个的方法和属性,那子类就属于这个类

走路像鸭子,说话像鸭子,它就是鸭子

假设有一个鸭子Duck类,有两个方法,run,speak方法。只要继承了Duck类就是鸭子类,如果不继承就不是鸭子这种类型

但是在python中不推崇这个,它推崇鸭子类型指的是:

不需要显示的继承某个类,只要我的类中有run和speak方法,我就是鸭子这个类

如果使用鸭子类型的写法,一旦方法写错了,它就不是这个类型了,容易有问题

python为了解决这个问题:

  • 方法1:abc模块,装饰后,必须重写方法,不重写就报错
  • 方法2:drf源码中使用的,父类中写个方法, 但没有具体实现,直接抛异常

重写频率类

from djangoProject10.settings import REST_FRAMEWORK
class LimitThrottling(BaseThrottle):
    class_dict = {}
    def __init__(self):
        self.user_temp=None
        self.timer=time.time
        # 通过配置文件找到对应的字典
        self.count=REST_FRAMEWORK.get('DEFAULT_THROTTLE_RATES')
    def allow_request(self, request, view):
        if not self.sets():
            return True
        user_ip=request.META.get('REMOTE_ADDR')
        self.user_temp=self.class_dict.get(user_ip)
	count,duration=self.split_count()
        if self.user_temp:
            if (self.timer()-self.user_temp[0])>duration:
                self.user_temp.pop(0)
                self.user_temp.append(self.timer())
                return True
            else:
                if len(self.user_temp) >= count:
                    return False
                else:
                    self.user_temp.append(self.timer())
                    return True
        else:
            self.class_dict[user_ip]=[]
            self.class_dict[user_ip].append(self.timer())
            return True
    def split_count(self):
        try:
            # 将5/m 以/拆分
            count, times = self.sets().split('/')
            # 给出一个时间的字典,通过拆分出来的时间取出对应的秒数
            duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[times[0]]
        except Exception:
            raise ValidationError('配置文件错误,请按正确格式配:5/min  5/m')
        return int(count),duration
    # 通过对象属性的配置文件的字典找到对应的值:5/m
    def sets(self):
        if self.count:
            return self.count.get('frequency')
        else:
            return None
posted @ 2022-10-09 22:53  荀飞  阅读(59)  评论(0编辑  收藏  举报