drf三大认证源码分析及异常捕获

drf三大认证源码分析及异常捕获

三大认证分析源头:drf的APIView中,重写了dispatch方法,在分发按请求方式分发之前,进行了运行了initial函数,其中就有以下代码,并且整体的将initial和视图分发放在同一个try的子代码中进行异常捕获,那么当认证不通过时,只要抛出异常,后续的视图函数就不会执行。

# APIView中的initial函数
def initial(self, request, *args, **kwargs)
	...解析编码、版本控制相关代码
    self.perform_authentication(request)
    self.check_permissions(request)
    self.check_throttles(request)

# APIView中的dispatch函数
def dispatch(self, request, *args, **kwargs):
    。。。
    try:
        self.initial(request, *args, **kwargs)
       	视图分发代码
    except Exception as exc:
        response = self.handle_exception(exc)  # 处理三大认证和视图中遇到的异常
     。。。

权限类的源码分析(check_permissions)

def check_permissions(self, request):
	# get_permissions()得到的是[permission() for permission in self.permission_classes]
    # 即是视图对象中注册的权限类进行初始化得到的权限对象列表
    for permission in self.get_permissions():
        # 对权限列表的has_permission依次执行,如果有一个权限没有通过则执行permission_denied
        # permission_denied中会直接报错,交由最外层的try处理
        if not permission.has_permission(request, self): 
            self.permission_denied(
                request,
                message=getattr(permission, 'message', None),
                code=getattr(permission, 'code', None)
            )

总结:

  1. 视图类中注册的permission_classes权限类列表,会被遍历执行初始化拿到一个个权限对象

    如果没有在视图类中注册permission_classes,则用APIView中默认配置的即全局配置的。

  2. 通过权限对象依次执行所有注册的权限类中覆写的has_permission方法

  3. 如果所有权限对象.has_permission的执行结果是True,则check_permission不会报错

  4. 如果有一个权限对象执行结果不为True,则抛出异常拦截到后面的视图代码。

认证类源码分析(perform_authentication)

就一句:

def perform_authentication(self, request):
    request.user  # request.user实际上是被包装成数据的功能

首先,request在APIView.dispatch首部就已经被替换为Request类产生的新的request了,所以要去drf的Request方法里面去找user方法。

@property
def user(self):
    if not hasattr(self, '_user'):  # 仅第一次执行,执行后会将结果保存到_user中,不必重复执行
        with wrap_attributeerrors():
            self._authenticate()
    return self._user

在user中会仅执行一次_authenticate,最终可以猜测是将内容存到了_user属性中

def _authenticate(self):
	# 这里是request内部的函数,authenticators是__init__中初始化传入的
    # 而这里是dispatch中对Request进行传入的get_authenticators(self)
    # 得到的内容是:[auth() for auth in self.authentication_classes]
    # 即视图中注册的,或者全局配置的
    for authenticator in self.authenticators:
        try:
            # 执行认证类的authenticate方法,返回的是(user,token)的形式
            user_auth_tuple = authenticator.authenticate(self)
        except exceptions.APIException:
            # 如果捕获到drf的异常,那么就执行以下函数
            self._not_authenticated()  # 将request.user设置为未认证用户和token
            raise  # 执行到这里说明是有些认证没有通过,直接报错拦住
		
        # 正常拿到结果,则判断是否是None
        if user_auth_tuple is not None:
            self._authenticator = authenticator
            self.user, self.auth = user_auth_tuple  # 拿到user和token
            return  # 直接结束函数
        # 是None则还继续认证下一个认证类

    self._not_authenticated()

总结:

  1. user被包装成数据属性,第一次拿值时会执行_authenticate,最终认证通过时会通过user.setter将user对象放到self._userself._request.user中,后续拿值时直接通过_user
  2. 认证类拿取的路线更曲折一些,视图类中注册的类会先保存到APIView对象中,其重写的dispatch中初始化Resquest类产生新的request时,将认证类传入到request.authenticators中,最终还是在_authenticate方法中遍历执行覆写的authenticate方法
  3. 覆写的authenticate要求必须返回(user,token)才会算认证成功,如果返回None,则虽然没有认证成功,但会允许走下一个认证类。
  4. 当覆写的authenticate抛出APIException时,则会被异常捕获,设置为未认证的用户和token,拦截后面的视图函数。

频率类源码分析(check_throttles)

def check_throttles(self, request):
    throttle_durations = []
    # 这个方法是视图类中的方法,self就说视图类,
    # get_throttles最终会拿到视图中注册的频率类或者全局配置的频率类
    for throttle in self.get_throttles():
        # allow_request需要覆写,当频率达到限制时,就返回false,添加到throttle_durations列表中
        if not throttle.allow_request(request, self):
            throttle_durations.append(throttle.wait())
	# 过滤掉为none的间隔值
    if throttle_durations:
        durations = [
            duration for duration in throttle_durations
            if duration is not None
        ]
		# 取最大的等待间隔
        duration = max(durations, default=None)
        self.throttled(request, duration)  # 报错,交由全局异常捕获处理

check_throttles源码中,可以分析出只需要覆写allow_request,我们就可以书写限制频率的逻辑了。但是明显这个自由度太大了,并没有频率的框架,我们可以直接让allow_request按任意的规则返回true或者false。

这里,我们先尝试通过直接覆写allow_request来编辑频率认证的逻辑,然后再看一下SimpleRateThrottle是怎么通过配置参数就简单实现评率认证的限制的。

覆写allow_request

class SuperThrottle(BaseThrottle):
    visited_dict = {}
    
    def __init__(self):
        self.history = None

    def allow_request(self, request, view):
        # 取出访问者的ip地址
        user_ip = request.META.get('REMOTE_ADDR')
        # 判断这个判断标准如ip是否在频率字典里,这个字典应该所有这个类产生对象共同能访问到的
        # 如果是不在字典,则直接添加表示第一次访问
        visited_dic = self.visited_dict
        if user_ip not in visited_dic:
            visited_dic[user_ip] = [time.time(), ]
            return True
        # 如果在字典中,那么我们先将它的时间列表调整到只保留60s内
        now = time.time()
        visit_time_list = visited_dic.get(user_ip)
        visited_dic[user_ip] = [visit_time for visit_time in visit_time_list if now - visit_time < 60]
        visited_dic[user_ip].append(now)
        # 保存一份到对象,以便其他函数使用
        self.history = visited_dic[user_ip]
        # 判断当列表小于等于5,则放行
        if len(visited_dic[user_ip]) <= 5:
            return True
        else: 
            # 否则不放行 
            return False
        
    # 用于提示还有下一次能访问的时间
    def wait(self):
        ctime = time.time()
        return 60 - (ctime - self.history[0])

这份代码就完成了访问频次的限制,其中,有两个数字是可以编辑成配置数字,就是表示一分钟的60,和表示限制5次的5,SimpleRateThrottle就是采取的这个思路。

SimpleRateThrottle源码分析

# SimpleRateThrottle重写的allow_request
def allow_request(self, request, view):
    # 如果rate或者key属性没有值,那么不做限制,直接返回True
    if self.rate is None:
        return True
    # 取判断依据,所以要在SimpleRateThrottle覆写get_cache_key
    self.key = self.get_cache_key(request, view)  
    if self.key is None:
        return True
    # 从缓存中取时间列表
    self.history = self.cache.get(self.key, [])
    self.now = self.timer()

    # 调整列表,只保留设定duration内的时间
    while self.history and self.history[-1] <= self.now - self.duration:
        self.history.pop()
    # 比较列表内时间次数
   	if len(self.history) >= self.num_requests:
        # 其实就是return了False,但是允许覆写来做一些额外的事,如记录日志
        return self.throttle_failure()  
    return self.throttle_success()  # 内部返回True,并且将新时间插入缓存

我们可以充分的感受到相对于我们自己覆写的allow_request,模块提供的allow_request更加的有扩展性,体现在:

  • 频次可以被配置
  • 访问频次的判断依据可以被设置(get_cache_key)
  • 在返回放行和拦截时,给了扩展的可能性(throttle_failure)

scope配置的原理:

# SimpleRateThrottle的init方法
    def __init__(self):
        if not getattr(self, 'rate', None):
            # self.rate= '5/h'
            self.rate = self.get_rate()
        # 5  3600
        self.num_requests, self.duration = self.parse_rate(self.rate)

# SimpleRateThrottle的get_rate() 方法
    def get_rate(self):
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__)
            raise ImproperlyConfigured(msg)
        try:
            #  self.scope 是字符串
            # return '5/h'
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)
                        
# SimpleRateThrottle的parse_rate 方法
	def parse_rate(self, rate):
        # '5/h'
        if rate is None:
            return (None, None)
        # num =5
        # period= 'hour'
        num, period = rate.split('/')
        # num_requests=5
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        # (5,36000)
        return (num_requests, duration)
  • init中,用get_rate将rate的配置拿出来,并用parse_rate解析rate
  • get_rate中,判断若没有scope配置就报错提醒,有的话则按照这个字符串去注册的频次字典中去取值,拿到rate如'5/h'这种。
  • parse_rate中,将'5/h'这种字符串按斜杠分割,变为请求次数和间隔时间

drf异常捕获(handle_exception)

在进行三大认证和视图函数时,dispatch函数一直监听着异常,并会捕获异常对象exc传入self.handle_exception

首先我们可以找到在APIView中已经写了默认的handle_exception去处理异常,但是通过self对象来找,也意味着我们的继承APIView的视图类可以覆写这个方法,做到派生,也就是可以添加记录日志等相关功能。(不过,其内侧提供了可改写的接口,这个思路了解即可)

其次我们来分析以下APIView本身怎么去处理异常的。

handle_exception源码分析

def handle_exception(self, exc):
	# 如果是认证权限方面的错误,那就设置403或401
    if isinstance(exc, (exceptions.NotAuthenticated,
                        exceptions.AuthenticationFailed)):
        auth_header = self.get_authenticate_header(self.request)
        if auth_header:
            exc.auth_header = auth_header
        else:
            exc.status_code = status.HTTP_403_FORBIDDEN

    # 如果是其他错误则执行以下函数
    # 1.取处理错误的函数
    exception_handler = self.get_exception_handler()
    # 2.取处理错误的内容
    context = self.get_exception_handler_context()
    # 3.传入错误和内容,执行处理错误函数
    response = exception_handler(exc, context)  
    # 防止没有response,有一个默认的response作为上一层dispatch的返回
    if response is None:
        self.raise_uncaught_exception(exc)
    response.exception = True
    return response

def get_exception_handler(self):
    # 配置中的内容,也就是说处理异常的函数是可以配置的
    # self.settings可以在django项目的REST_FRAMEWORK字典中配置
    return self.settings.EXCEPTION_HANDLER  

def get_exception_handler_context(self):
	# 返回一些错误信息,以字典的形式
    return {
        'view': self,
        'args': getattr(self, 'args', ()),
        'kwargs': getattr(self, 'kwargs', {}),
        'request': getattr(self, 'request', None)
    }
    
# drf默认的处理异常的函数配置
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',

def exception_handler(exc, context):
    # 处理404,权限错误
    if isinstance(exc, Http404):
        exc = exceptions.NotFound()
    elif isinstance(exc, PermissionDenied):
        exc = exceptions.PermissionDenied()
	# 如果是APIException(drf所有异常都继承这个异常),做响应处理
    if isinstance(exc, exceptions.APIException):
        # 处理响应头
        headers = {}
        if getattr(exc, 'auth_header', None):
            headers['WWW-Authenticate'] = exc.auth_header
        if getattr(exc, 'wait', None):
            headers['Retry-After'] = '%d' % exc.wait
		# APIException都有detail属性,将其捕捉到data中
        if isinstance(exc.detail, (list, dict)):
            data = exc.detail
        else:
            data = {'detail': exc.detail}

        set_rollback()
        # 响应,并添加data,响应状态码,响应头
        return Response(data, status=exc.status_code, headers=headers)

    return None

总结:

  1. handle_exception是捕获异常的总函数,内部通过先取配置好的异常处理函数和异常信息,再执行,意味着异常处理函数是可配置的,在REST_FRAMEWORK中就可以配置。
  2. 默认的异常处理函数是views.exception_handle函数,总体思路是对APIException异常及其子类进行捕获和返回正常的drf的Response响应,其他的异常则不处理。

利用接口增加异常处理的功能

# 配置'EXCEPTION_HANDLER': 'app01.exceptions.common_exception_handler',
from rest_framework.views import exception_handler

def common_exception_handler(exc, context):
    # 假装记录了日志
    now = time.time()  # 当前时间
    request = context.get('request')  
    user_id = request.user.id  # 用户id,匿名用户无id
    user_ip = request.META.get('REMOTE_ADDR')   # 设备ip
    method = request.method  # 请求方式
    addr = '127.0.0.1:8000' + request.path  # 请求地址
    view = context.get('view')  # 请求视图
    print(now, user_id, user_ip, method, addr, view, exc)
    return exception_handler(exc, context)

APIView编写分页

class BookPageView(viewsets.ViewSet):
    paginator = CommonPaginator()

    def list(self, request):
        queryset = Book.objects.all()
        paginate_queryset = self.paginator.paginate_queryset(queryset, request, self)
        ser = BookSerializer(instance=paginate_queryset, many=True)
        return Response(ser.data)
posted @ 2023-02-08 19:44  leethon  阅读(29)  评论(0编辑  收藏  举报