07 drf源码剖析之节流
07 drf源码剖析之节流
1. 节流简述
-
节流类似于权限,它确定是否应授权请求。节流指示临时状态,并用于控制客户端可以向API发出的请求的速率。
-
还有情况可能是 ,由于某些服务特别耗费资源,因此您需要在API的不同部分施加不同的约束。
-
频率限制在认证、权限之后
2. 节流使用
-
在settings配置文件中设置规定时间段内可以访问的次数
REST_FRAMEWORK = { "DEFAULT_THROTTLE_RATES": {"anon": '10/m'}, }
-
在需要节流的类中加throttle_classes
from rest_framework.views import APIView from rest_framework.response import Response from rest_framework.throttling import AnonRateThrottle,BaseThrottle class ArticleView(APIView): throttle_classes = [AnonRateThrottle,] def get(self,request,*args,**kwargs): return Response('文章列表') class ArticleDetailView(APIView): def get(self,request,*args,**kwargs): return Response('文章列表')
3. 源码剖析
-
请求过来先执行dispatch方法
class APIView(View): permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES def dispatch(self, request, *args, **kwargs): # 封装request对象... self.initial(request, *args, **kwargs) # 通过反射执行视图中的方法...
-
initial方法过渡
def initial(self, request, *args, **kwargs): # 版本的处理... # 认证... # 权限判断 self.check_throttles(request) # 节流
-
将所有的节流类实例化成对象列表
def get_throttles(self): return [throttle() for throttle in self.throttle_classes]
-
循环执行每个对象的allow_request方法
def check_throttles(self, request): throttle_durations = [] for throttle in self.get_throttles(): if not throttle.allow_request(request, self): throttle_durations.append(throttle.wait())
-
执行allow_request方法,AnonRateThrottle没有allow_request,去父类找
class AnonRateThrottle(SimpleRateThrottle): scope = 'anon' def get_cache_key(self, request, view): if request.user.is_authenticated: return None # Only throttle unauthenticated requests. return self.cache_format % { 'scope': self.scope, 'ident': self.get_ident(request) }
-
执行SimpleRateThrottle类中allow_request方法
- 获取请求用户的IP
- 根据IP获取他的所有访问记录
- 获取当前时间
- 将不在规定时间的记录删除掉
- 判断规定时间段时间内访问了多少次
- 和设定次数对比,判断用户是否可以继续访问
class SimpleRateThrottle(BaseThrottle): cache = default_cache timer = time.time cache_format = 'throttle_%(scope)s_%(ident)s' scope = None THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): if not getattr(self, 'rate', None): self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) def parse_rate(self, rate): if rate is None: return (None, None) num, period = rate.split('/') num_requests = int(num) duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] return (num_requests, duration) def allow_request(self, request, view): if self.rate is None: return True # 获取请求用户的IP self.key = self.get_cache_key(request, view) if self.key is None: return True # 根据IP获取他的所有访问记录,[] self.history = self.cache.get(self.key, []) self.now = self.timer() while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() if len(self.history) >= self.num_requests: return self.throttle_failure() return self.throttle_success() def throttle_success(self): self.history.insert(0, self.now) self.cache.set(self.key, self.history, self.duration) return True def throttle_failure(self): return False def wait(self): if self.history: remaining_duration = self.duration - (self.now - self.history[-1]) else: remaining_duration = self.duration available_requests = self.num_requests - len(self.history) + 1 if available_requests <= 0: return None return remaining_duration / float(available_requests)
总结:
- 实现原理
来访问时:
1.获取当前时间 100121280
2.100121280-60 = 100121220,小于100121220所有记录删除
3.判断1分钟以内已经访问多少次了? 4
4.无法访问
停一会
来访问时:
1.获取当前时间 100121340
2.100121340-60 = 100121280,小于100121280所有记录删除
3.判断1分钟以内已经访问多少次了? 0
4.可以访问 - 具体流程
- 请求来时会执行allow_follow方法,
- 会用self.key获取请求用户的ip,再用self.history根据用户的ip获取其访问的记录,
- 获取当前的时间,用当前的时间减去设定的时间段,
- 循环该用户访问的记录,将不在该时间段的记录pop掉,
- 通过len判定该时间段已经访问了多少次,超过限定次数会返回false
- 匿名用户是通过ip进行访问限制,登录用户通过用户的id进行访问限制