频率组件及源码分析
频率组件
他的作用是限制接口访问的频率
频率类的编写
- 写一个类,继承SimpleRateThrottle
- 重写get_cache_key,返回唯一标识,返回什么就以什么做限制
- 重写类属性rate 控制频率
from rest_framework.throttling import BaseThrottle, SimpleRateThrottle
class CommonThrottling(SimpleRateThrottle):
# 每分钟限制访问三次
rate = '3/m'
def get_cache_key(self, request, view):
# 返回什么就以什么为限制
return request.META.get('REMOTE_ADDR')
频率类的使用
- 局部使用
- 全局使用
- 局部禁用
# 局部使用
class BookView(ModelViewSet):
queryset = Book.objects.all()
serializer_class = BookSerializer
# 在视图类里重写throttle_classes属性,列表里填自己定义的频率类
throttle_classes = [CommonThrottling]
# 全局使用
REST_FRAMEWORK = {
# 频率
'DEFAULT_THROTTLE_CLASSES': [],
}
# 局部禁用
class BookView(ModelViewSet):
queryset = Book.objects.all()
serializer_class = BookSerializer
throttle_classes = []
频率类源码分析
执行流程
当请求来的时候,首先会执行APIView里面的dispath里面的initial方法
在这里面会进行三大认证,频率限制就在这里执行
依次点进来看看什么情况
check_permissions
def check_throttles(self, request):
# 首先初始化一个throttle_durations列表
throttle_durations = []
# 点进get_throttles看就会知道
# 他就是对我们定义的频率类列表进行遍历,然后分别生成对象
for throttle in self.get_throttles():
# 所以这里的意思就是如果allow_request返回了false
# 就把throttle.wait()这个玩意的返回值加入到throttle_durations列表里
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
# 如果throttle_durations不为空
if throttle_durations:
# 把throttle_durations列表的None值过滤掉,然后把结果赋值给 durations列表
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
# 这里是抛出了一个异常
self.throttled(request, duration)
接下来看看内置的频率类源码
BaseThrottle
class BaseThrottle:
"""
Rate throttling of requests.
请求的速率限制
"""
def allow_request(self, request, view):
"""
返回True就是通过校验,反之不通过抛出异常
"""
raise NotImplementedError('.allow_request() must be overridden')
# 这个方法就是用于确定请求来源
def get_ident(self, request):
xff = request.META.get('HTTP_X_FORWARDED_FOR')
remote_addr = request.META.get('REMOTE_ADDR')
num_proxies = api_settings.NUM_PROXIES
if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs = xff.split(',')
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()
return ''.join(xff.split()) if xff else remote_addr
# 用来定义返回的描述的
def wait(self):
"""
Optionally, return a recommended number of seconds to wait before
the next request.
"""
return None
SimpleRateThrottle
class SimpleRateThrottle(BaseThrottle):
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
# 这个方法必须被重写,且返回什么就什么为标准
def get_cache_key(self, request, view):
raise NotImplementedError('.get_cache_key() must be overridden')
def get_rate(self):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
# 判断是否为None,如果为None则返回两个None
if rate is None:
return (None, None)
# 如果不为空就切分一下
num, period = rate.split('/')
# 把次数转成数字整数类型
num_requests = int(num)
# 定义了一个时间单位字典
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
# 返回 一定时间内的次数 和 一定时间
return (num_requests, duration)
def allow_request(self, request, view):
# 如果rate为None,就是不做限制
if self.rate is None:
return True
# 获取限制标志
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
# 先把它当个列表看
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# 当这个列表不为空,代表着有访问,且没有超过持续限制时间
# 且 列表的最后一个数据,也就是最早访问的时间小于现在的时间减去持续时间,也就是超过了持续时间
# 就把这条访问记录踢出列表
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
# 如果列表的长度,也就是访问记录条数大于规定的访问次数,就执行throttle_failure方法
if len(self.history) >= self.num_requests:
return self.throttle_failure()
# 否则执行throttle_success方法
return self.throttle_success方法()
def throttle_success(self):
# 代表访问成功,然后往访问列表插入一条数据
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
# 访问不成功 返回False
return False
# 看三大认证的源码可以发现,wait()方法是在访问不成功之后执行的
# 也就是allow_request 返回False执行的
def wait(self):
# 如果列表有值,也就是之前有访问记录
if self.history:
# 就计算出最早的一次访问还有多久解除限制,把值赋给remaining_duration
remaining_duration = self.duration - (self.now - self.history[-1])
else:
# 如果列表没有值,就代表自己定制的持续时间就是接触访问限制的时间
remaining_duration = self.duration
# 这里就是计算了一下还剩下多少次数可以访问
available_requests = self.num_requests - len(self.history) + 1
# 没有访问机会了就返回None
# 在三大认证里面会把None去除掉
if available_requests <= 0:
return None
# 最后返回 下次可以访问的时间 和 访问次数
return remaining_duration / float(available_requests)