Loading

频率组件及源码分析

频率组件

​ 他的作用是限制接口访问的频率

频率类的编写

  1. 写一个类,继承SimpleRateThrottle
  2. 重写get_cache_key,返回唯一标识,返回什么就以什么做限制
  3. 重写类属性rate 控制频率
from rest_framework.throttling import BaseThrottle, SimpleRateThrottle

class CommonThrottling(SimpleRateThrottle):
    # 每分钟限制访问三次
    rate = '3/m'
    def get_cache_key(self, request, view):
        # 返回什么就以什么为限制
        return request.META.get('REMOTE_ADDR')

频率类的使用

  1. 局部使用
  2. 全局使用
  3. 局部禁用
# 局部使用
class BookView(ModelViewSet):
    queryset = Book.objects.all()
    serializer_class = BookSerializer
    # 在视图类里重写throttle_classes属性,列表里填自己定义的频率类
    throttle_classes = [CommonThrottling]

# 全局使用
REST_FRAMEWORK = {
    # 频率
    'DEFAULT_THROTTLE_CLASSES': [],
}

# 局部禁用
class BookView(ModelViewSet):
    queryset = Book.objects.all()
    serializer_class = BookSerializer
    throttle_classes = []

频率类源码分析

执行流程

​ 当请求来的时候,首先会执行APIView里面的dispath里面的initial方法

​ 在这里面会进行三大认证,频率限制就在这里执行

​ 依次点进来看看什么情况

check_permissions

def check_throttles(self, request):
	# 首先初始化一个throttle_durations列表
    throttle_durations = []
    # 点进get_throttles看就会知道
    # 他就是对我们定义的频率类列表进行遍历,然后分别生成对象
    for throttle in self.get_throttles():
        # 所以这里的意思就是如果allow_request返回了false
        # 就把throttle.wait()这个玩意的返回值加入到throttle_durations列表里
        if not throttle.allow_request(request, self):
            throttle_durations.append(throttle.wait())
	
    # 如果throttle_durations不为空
    if throttle_durations:
        # 把throttle_durations列表的None值过滤掉,然后把结果赋值给 durations列表
        durations = [
            duration for duration in throttle_durations
            if duration is not None
        ]

        duration = max(durations, default=None)
        # 这里是抛出了一个异常
        self.throttled(request, duration)

​ 接下来看看内置的频率类源码

BaseThrottle

class BaseThrottle:
    """
    Rate throttling of requests.
    请求的速率限制
    """
	
    def allow_request(self, request, view):
        """
        返回True就是通过校验,反之不通过抛出异常
        """
        raise NotImplementedError('.allow_request() must be overridden')
	
    # 这个方法就是用于确定请求来源
    def get_ident(self, request):
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr

    # 用来定义返回的描述的
    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None

SimpleRateThrottle

class SimpleRateThrottle(BaseThrottle):

    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    # 这个方法必须被重写,且返回什么就什么为标准    
    def get_cache_key(self, request, view):

        raise NotImplementedError('.get_cache_key() must be overridden')


    def get_rate(self):

        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
		# 判断是否为None,如果为None则返回两个None
        if rate is None:
            return (None, None)
        # 如果不为空就切分一下
        num, period = rate.split('/')
        # 把次数转成数字整数类型
        num_requests = int(num)
        # 定义了一个时间单位字典
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
		# 返回 一定时间内的次数 和 一定时间
        return (num_requests, duration)

    def allow_request(self, request, view):
		# 如果rate为None,就是不做限制
        if self.rate is None:
            return True
		# 获取限制标志
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True
		# 先把它当个列表看
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

		# 当这个列表不为空,代表着有访问,且没有超过持续限制时间
        # 且 列表的最后一个数据,也就是最早访问的时间小于现在的时间减去持续时间,也就是超过了持续时间
        # 就把这条访问记录踢出列表
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # 如果列表的长度,也就是访问记录条数大于规定的访问次数,就执行throttle_failure方法
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        # 否则执行throttle_success方法
        return self.throttle_success方法()

    def throttle_success(self):
		# 代表访问成功,然后往访问列表插入一条数据
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
		# 访问不成功 返回False
        return False
    
	# 看三大认证的源码可以发现,wait()方法是在访问不成功之后执行的
    # 也就是allow_request 返回False执行的
    def wait(self):
		# 如果列表有值,也就是之前有访问记录
        if self.history:
            # 就计算出最早的一次访问还有多久解除限制,把值赋给remaining_duration
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            # 如果列表没有值,就代表自己定制的持续时间就是接触访问限制的时间
            remaining_duration = self.duration
		
        # 这里就是计算了一下还剩下多少次数可以访问
        available_requests = self.num_requests - len(self.history) + 1
        # 没有访问机会了就返回None
        # 在三大认证里面会把None去除掉
        if available_requests <= 0:
            return None
		# 最后返回 下次可以访问的时间 和 访问次数
        return remaining_duration / float(available_requests)
posted @ 2024-04-21 22:00  HuangQiaoqi  阅读(3)  评论(0编辑  收藏  举报