DRF 源码剖析——限流

点击查看代码
# Step 1: the login/ route is served by the callable that
# LoginView.as_view() returns.
urlpatterns = [path('login/', views.LoginView.as_view())]
点击查看代码
# Both `detail` and `code` can be customised by the caller.
class Throttled(APIException):
    """Exception raised when a request is rejected by a throttle.

    Carries an HTTP 429 status and, when a wait time is supplied, appends a
    pluralised "Expected available in N second(s)." hint to the detail text.
    """

    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    default_detail = _('Request was throttled.')
    extra_detail_singular = _('Expected available in {wait} second.')
    extra_detail_plural = _('Expected available in {wait} seconds.')
    default_code = 'throttled'

    def __init__(self, wait=None, detail=None, code=None):
        """Build the exception.

        Args:
            wait: Seconds until the next request may succeed, or None when
                unknown. Rounded up to a whole number before use.
            detail: Custom message; defaults to ``default_detail``.
            code: Custom error code, forwarded to the parent constructor.
        """
        message = force_str(self.default_detail) if detail is None else detail

        if wait is not None:
            wait = math.ceil(wait)
            singular = self.extra_detail_singular.format(wait=wait)
            plural = self.extra_detail_plural.format(wait=wait)
            # ngettext picks the singular or plural form based on `wait`.
            hint = force_str(ngettext(singular, plural, wait))
            message = ' '.join((message, hint))

        # Stored after rounding so readers see the same value as the message.
        self.wait = wait
        super().__init__(message, code)
点击查看代码
class APIView(View):
    """Abridged view base class showing the request path through throttling.

    Only the members involved in the throttle check are reproduced here.
    """

    # Class-level default; a subclass may override it with its own list.
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES

    def get_throttles(self):
        """Instantiate every class listed in ``self.throttle_classes``.

        Step 7: attribute lookup finds the subclass's ``throttle_classes``
        first (e.g. LoginView's) and only then the global default above.
        """
        return [throttle_cls() for throttle_cls in self.throttle_classes]

    def throttled(self, request, wait):
        """Raise ``exceptions.Throttled`` carrying the given wait time."""
        raise exceptions.Throttled(wait)

    def check_throttles(self, request):
        """Run every throttle and reject the request if any of them refuses.

        Step 6: ``get_throttles()`` yields the throttle instances.
        Step 8: ``allow_request`` is resolved on each instance; when the
        concrete throttle does not define it, the inherited one is used.
        """
        # One wait() value per throttle that refused the request.
        refused_waits = [
            throttle.wait()
            for throttle in self.get_throttles()
            if not throttle.allow_request(request, self)
        ]
        if not refused_waits:
            return

        # Ignore throttles that cannot say how long to wait, then report
        # the longest remaining wait (None when none is known).
        known_waits = [wait for wait in refused_waits if wait is not None]
        self.throttled(request, max(known_waits, default=None))

    def initial(self, request, *args, **kwargs):
        """Run the pre-handler checks in order: auth, permissions, throttles."""
        self.perform_authentication(request)
        self.check_permissions(request)
        # Step 5: the throttling component.
        self.check_throttles(request)

    @classmethod
    def as_view(cls, **initkwargs):
        """Return the CSRF-exempt view callable for this class.

        Step 2: delegates to the parent class's ``as_view`` and annotates the
        resulting callable with ``cls`` and ``initkwargs``.
        """
        view_func = super().as_view(**initkwargs)
        view_func.cls = cls
        view_func.initkwargs = initkwargs
        return csrf_exempt(view_func)

    def dispatch(self, request, *args, **kwargs):
        """Wrap the request, run the checks, call the handler, finalize.

        Any exception raised by ``initial`` or by the handler is passed to
        ``handle_exception``; whatever that returns becomes the response.
        """
        self.args, self.kwargs = args, kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers

        try:
            # Step 4: entry point for authentication/permission/throttling.
            self.initial(request, *args, **kwargs)

            verb = request.method.lower()
            fallback = self.http_method_not_allowed
            if verb in self.http_method_names:
                handler = getattr(self, verb, fallback)
            else:
                handler = fallback

            response = handler(request, *args, **kwargs)
        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
点击查看代码
class View:
    """Abridged class-based-view base showing how ``as_view`` builds a view."""

    @classonlymethod
    def as_view(cls, **initkwargs):
        """Return a function that creates an instance per request and dispatches.

        Raises (from the returned function):
            AttributeError: if ``setup()`` did not leave a ``request``
                attribute on the instance.
        """
        def view(request, *args, **kwargs):
            instance = cls(**initkwargs)
            instance.setup(request, *args, **kwargs)
            if hasattr(instance, "request"):
                return instance.dispatch(request, *args, **kwargs)
            raise AttributeError(
                "%s instance has no 'request' attribute. Did you override "
                "setup() and forget to call super()?" % cls.__name__
            )

        # Step 3: hand back the closure; calling it ends up in the
        # dispatch() of the class as_view was invoked on.
        return view
点击查看代码
class LoginView(APIView):
    """Login endpoint: checks credentials and issues a fresh token."""

    # No authentication classes are configured for this view.
    authentication_classes = []
    # Requests to this view are rate-limited by MyThrottling.
    throttle_classes = [MyThrottling, ]

    def post(self, request):
        """Look up the user by username/password and return a new token.

        Returns a Response with ``status`` False and an ``error`` message
        when no matching user exists, otherwise ``status`` True and the
        newly stored token in ``data``.
        """
        # NOTE(review): the password is compared as submitted, which suggests
        # it is stored in plain text -- confirm against the UserInfo model.
        credentials = {
            'username': request.data.get('username'),
            'password': request.data.get('password'),
        }
        account = models.UserInfo.objects.filter(**credentials).first()

        if account:
            fresh_token = str(uuid.uuid4())
            account.token = fresh_token
            account.save()
            return Response({'status': True, 'data': fresh_token})

        return Response({'status': False, 'error': '用户不存在'})
点击查看代码
class SimpleRateThrottle(BaseThrottle):
    """Cache-backed throttle that allows N requests per time period.

    The rate is a string such as ``'5/m'``. Request timestamps are kept in
    the cache under a key built by ``get_cache_key``, newest first.
    """

    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        """Resolve the rate (unless already set) and parse it."""
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        # Split the rate into the allowed request count and window length.
        parsed = self.parse_rate(self.rate)
        self.num_requests, self.duration = parsed

    def get_cache_key(self, request, view):
        """Return the cache key identifying the client; subclasses must override."""
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """Look up the rate string for ``self.scope`` in ``THROTTLE_RATES``.

        Raises:
            ImproperlyConfigured: if ``scope`` is unset or has no rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        rates = self.THROTTLE_RATES
        try:
            return rates[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """Turn ``'<count>/<period>'`` into ``(count, seconds)``.

        Only the first character of the period is significant
        (s, m, h or d). ``None`` yields ``(None, None)``.
        """
        if rate is None:
            return (None, None)

        count_text, period = rate.split('/')
        seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
        return (int(count_text), seconds_per_unit[period[0]])

    def allow_request(self, request, view):
        """Decide whether the request fits within the configured rate.

        Returns True without recording anything when no rate is configured
        or when ``get_cache_key`` returns None.
        """
        if self.rate is None:
            return True

        # Unique identifier for the client being throttled.
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        # Previously recorded request timestamps for this client.
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop timestamps that have fallen out of the time window; the
        # oldest entries sit at the end of the list.
        cutoff = self.now - self.duration
        while self.history and self.history[-1] <= cutoff:
            self.history.pop()

        # Under the limit: record this request. Otherwise: refuse.
        if len(self.history) < self.num_requests:
            return self.throttle_success()
        return self.throttle_failure()

    def throttle_success(self):
        """Record the current request at the front of the history and allow it."""
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """Signal that the request must be refused."""
        return False

    def wait(self):
        """Return the suggested number of seconds to wait, or None."""
        # Time left until the oldest recorded request leaves the window.
        remaining_duration = (
            self.duration - (self.now - self.history[-1])
            if self.history
            else self.duration
        )

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
点击查看代码
class MyThrottling(SimpleRateThrottle):
    """Throttle limiting each client to 5 requests per minute.

    Clients are identified by their user primary key when one is available,
    otherwise by the value of ``get_ident(request)``.
    """

    scope = 'xxx'
    THROTTLE_RATES = {'xxx': '5/m'}
    cache = default_cache

    def get_cache_key(self, request, view):
        """Build the cache key under which this client's history is stored.

        Args:
            request: The incoming request; ``request.user`` is consulted.
            view: The view being accessed (unused).

        Returns:
            The ``cache_format`` string filled with the scope and the
            client identifier.
        """
        # A truthiness test on request.user is not enough: Django's
        # AnonymousUser is truthy but its pk is None, which would put every
        # anonymous client into the single bucket "throttle_xxx_None".
        # Use the pk only when there really is one.
        ident = getattr(request.user, 'pk', None)
        if ident is None:
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}
posted @   周亚彪  阅读(12)  评论(0)  编辑  收藏  举报
相关博文:
阅读排行:
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)
点击右上角即可分享
微信分享提示