Django框架(二十一)--Django rest_framework-频率组件
一、作用
为了控制用户对某个url请求的频率,比如,一分钟以内,只能访问三次
二、自定义频率类
# 写一个频率认证类 class MyThrottle: visit_dic = {} visit_time = None def __init__(self): self.ctime = time.time() # 重写allow_request()方法 # request是request对象,view是视图类,可以对视图类进行操作 def allow_request(self, request, view): ''' (1)取出访问者ip (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走 (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间, (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过 (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败 visit_dic = {ip1:[time2, time1, time0], ip2:[time1, time0], } ''' # 取出访问者ip,ip可以从请求头中取出来 ip = request.META.get('REMOTE_ADDR') # 判断该次请求的ip是否在地点中 if ip in self.visit_dic: # 当存在字典中时,取出这个ip访问时间的列表 visit_time = self.visit_dic[ip] self.visit_time = visit_time while visit_time: # 当访问时间列表中有值,时间间隔超过60,就将那个历史时间删除 if self.ctime - visit_time[-1] > 60: visit_time.pop() else: # 当pop到一定时,时间间隔不大于60了,退出循环,此时得到的是60s内访问的时间记录 break # while循环等价于 # while visit_time and ctime - visit_time[-1] > 60: # visit_time.pop() # 列表长度可表示访问次数,根据源码,可以得出,返回值是Boolean类型 if len(visit_time) >= 3: return False else: # 如果60秒内访问次数小于3次,将当前访问的时间记录下来 visit_time.insert(0, self.ctime) return True else: # 如果字典中没有当前访问ip,将ip加到字典中 self.visit_dic[ip] = [self.ctime, ] return True # 获取下次距访问的时间 def wait(self): return 60 - (self.ctime - self.visit_time[-1])
# view层 from app01 import MyAuth from rest_framework import exceptions class Book(APIView): # 局部使用频率控制 throttle_classes = [MyAuth.MyThrottle, ] def get(self,request): return HttpResponse('ok') # 重写抛出异常的方法 throttled def throttled(self, request, wait): class MyThrottled(exceptions.Throttled): default_detail = '下次访问' extra_detail_singular = '还剩 {wait} 秒.' extra_detail_plural = '还剩 {wait} 秒' raise MyThrottled(wait)
三、内置的访问频率控制类
from rest_framework.throttling import SimpleRateThrottle # 写一个频率控制类,继承SimpleRateThrottle类 class MyThrottle(SimpleRateThrottle): # 配置scope,通过scope到setting中找到 3/m scope = 'ttt' def get_cache_key(self, request, view): # 返回ip,效果和 get_ident() 方法相似 # ip = request.META.get('REMOTE_ADDR') # return ip # get_ident 返回的就是ip地址 return self.get_ident(request)
# view层视图类 class Book(APIView): throttle_classes = [MyAuth.MyThrottle, ] def get(self, request): return HttpResponse('ok') def throttled(self, request, wait): class MyThrottled(exceptions.Throttled): default_detail = '下次访问' extra_detail_singular = '还剩 {wait} 秒.' extra_detail_plural = '还剩 {wait} 秒' raise MyThrottled(wait)
# setting中配置 REST_FRAMEWORK = { 'DEFAULT_THROTTLE_RATES': { 'ttt': '10/m' } }
- 因此,要实现10分钟允许访问六次,可以继承
SimpleRateThrottle
类,然后重写parse_rate()
方法,将duration中key对应的值改为自己需要的值
四、全局、局部使用
1、全局使用
在setting中配置
REST_FRAMEWORK = { 'DEFAULT_THROTTLE_CLASSES': ['app01.MyAuth.MyThrottle', ], }
2、局部使用
在视图类中重定义throttle_classes
throttle_classes = [MyAuth.MyThrottle, ]
3、局部禁用
在视图类中重定义throttle_classes
为一个空列表
throttle_classes = []
五、源码分析
1、as_view -----> view ------> dispatch ------> initial ----> check_throttles 频率控制
2、self.check_throttles(request)
def check_throttles(self, request): """ Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ # (2-----1) get_throttles 由频率类产生的对象组成的列表 for throttle in self.get_throttles(): if not throttle.allow_request(request, self): # (4)异常信息的处理 self.throttled(request, throttle.wait())
(2-----1) self.get_throttles()
def get_throttles(self): """ Instantiates and returns the list of throttles that this view uses. """ return [throttle() for throttle in self.throttle_classes]
3、allow_request()
自身、所在类找都没有,去父类中找
class SimpleRateThrottle(BaseThrottle): cache = default_cache timer = time.time cache_format = 'throttle_%(scope)s_%(ident)s' scope = None THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): if not getattr(self, 'rate', None): self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) def parse_rate(self, rate): if rate is None: return (None, None) num, period = rate.split('/') num_requests = int(num) duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] return (num_requests, duration) def allow_request(self, request, view): if self.rate is None: return True # (3-----1) get_cache_key就是要重写的方法,若不重写,会直接抛出异常 self.key = self.get_cache_key(request, view) if self.key is None: return True self.history = self.cache.get(self.key, []) self.now = self.timer() # Drop any requests from the history which have now passed the # throttle duration while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() if len(self.history) >= self.num_requests: return self.throttle_failure() return self.throttle_success() # 返回距下一次能请求的时间 def wait(self): """ Returns the recommended next request time in seconds. """ if self.history: remaining_duration = self.duration - (self.now - self.history[-1]) else: remaining_duration = self.duration
(3-----1) self.get_cache_key(request, view)
def get_cache_key(self, request, view): """ Should return a unique cache-key which can be used for throttling. Must be overridden. May return `None` if the request should not be throttled. """ raise NotImplementedError('.get_cache_key() must be overridden')
4、self.throttled(request, throttle.wait()) --------> 抛出异常
def throttled(self, request, wait): """ If request is throttled, determine what kind of exception to raise. """ raise exceptions.Throttled(wait)
(4------1)raise exceptions.Throttled(wait) -------> 异常信息
class Throttled(APIException): status_code = status.HTTP_429_TOO_MANY_REQUESTS # 重写下面三个变量就可以修改显示的异常信息,例如用中文显示异常信息 default_detail = _('Request was throttled.') extra_detail_singular = 'Expected available in {wait} second.' extra_detail_plural = 'Expected available in {wait} seconds.' default_code = 'throttled' def __init__(self, wait=None, detail=None, code=None): if detail is None: detail = force_text(self.default_detail) if wait is not None: wait = math.ceil(wait) detail = ' '.join(( detail, force_text(ungettext(self.extra_detail_singular.format(wait=wait), self.extra_detail_plural.format(wait=wait), wait)))) self.wait = wait super(Throttled, self).__init__(detail, code)