Django restframework用户访问频率控制组件增加及源码分析

用户访问频率控制源码剖析,和用户登录验证有点相似,但是为了增加记忆,有必要再一次添加,

注意:一定要跟着博主的解说再看代码的中文注释及其下面的一行代码!!!

1、准备一个路由和视图类,全局路由配置暂时忽略,当流程执行到下面的url:groupsSelectAll——> GroupsView的视图类下的as_view()方法

from django.conf.urls import url

from . import views


app_name = '[words]'
urlpatterns = [
    url(r'groupsSelectAll/', views.GroupsView.as_view(), name="groupsSelectAll"),   # 词组信息查询所有

]
class GroupsView(APIView):

    def get(self, request):
        conditions = {
            "id": request.query_params.get("wid"),
            "name": request.query_params.get("name"),
            "start_time": request.query_params.get("start_time"),
            "end_time": request.query_params.get("end_time"),
        }
        res = DataManager.select_by_conditions("words_groups", None, **conditions)
        return Response(data={"code": 200, "result": res})

2、但是GroupsView类下没有as_view方法,这时就要去它的父类APIView查看(点进去看as_view方法),这里博主只复制方法源代码,大家只需要看中文注释及其下的代码语句。在这个方法中值得一提的是super关键字,如果请求视图类(就是GroupsView类,如果继承了多个父类)还有另一个父类,它先会查看这个父类是否有as_view方法。在这里它是会执行APIView的父类View中的as_view方法,然后我们再次查看父类View的as_view方法。第一个as_view方法是APIView类的,第二个as_view方法是View类的。

@classmethod
    def as_view(cls, **initkwargs):
        """
        Store the original class on the view function.

        This allows us to discover information about the view when we do URL
        reverse lookups.  Used for breadcrumb generation.
        """
        if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
            def force_evaluation():
                raise RuntimeError(
                    'Do not evaluate the `.queryset` attribute directly, '
                    'as the result will be cached and reused between requests. '
                    'Use `.all()` or call `.get_queryset()` instead.'
                )
            cls.queryset._fetch_all = force_evaluation

        # 执行父类的as_view方法
        view = super(APIView, cls).as_view(**initkwargs)
        view.cls = cls
        view.initkwargs = initkwargs

        # Note: session based authentication is explicitly CSRF validated,
        # all other authentication is CSRF exempt.
        return csrf_exempt(view)
APIView.as_view()
    @classonlymethod
    def as_view(cls, **initkwargs):
        """Main entry point for a request-response process."""
        for key in initkwargs:
            if key in cls.http_method_names:
                raise TypeError("You tried to pass in the %s method name as a "
                                "keyword argument to %s(). Don't do that."
                                % (key, cls.__name__))
            if not hasattr(cls, key):
                raise TypeError("%s() received an invalid keyword %r. as_view "
                                "only accepts arguments that are already "
                                "attributes of the class." % (cls.__name__, key))

        # 执行view方法
        def view(request, *args, **kwargs):
            # 这里的cls就是我们的请求视图类,显而易见,这个self是请求试图类的对象
            self = cls(**initkwargs)
            if hasattr(self, 'get') and not hasattr(self, 'head'):
                self.head = self.get
            self.request = request
            self.args = args
            self.kwargs = kwargs
            # 然后这里就是执行dispatch方法
            return self.dispatch(request, *args, **kwargs)
        view.view_class = cls
        view.view_initkwargs = initkwargs

        # take name and docstring from class
        update_wrapper(view, cls, updated=())

        # and possible attributes set by decorators
        # like csrf_exempt from dispatch
        update_wrapper(view, cls.dispatch, assigned=())
        return view
View.as_view()

3、我们在第二个as_view方法中可以知道self是我们的请求视图类的对象,通过这个self调用dispatch方法,请求视图类中没有dispatch方法,是不是又去APIView类中执行dispatch方法。

def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.
        """
        self.args = args
        self.kwargs = kwargs
        # 这里是对原生的request加工处理,返回一个新的request对象
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            # 初始化(用户登录认证,权限验证,访问频率限制)
            self.initial(request, *args, **kwargs)

            # Get the appropriate handler method
            if request.method.lower() in self.http_method_names:
                # 通过python的反射机制反射到请求视图类的方法名称
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed
            
            # 最后就是执行请求视图类的方法
            response = handler(request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
APIView.dispatch()

4、其他代码不用看,我们直接看initial方法,因为这个initial方法有访问频率控制的功能。

def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted
        # 用户验证的方法,这个request 是加工之后的request
        self.perform_authentication(request)
        # 用户权限验证
        self.check_permissions(request)
        # 用户访问频率限制
        self.check_throttles(request)
APIView.initial()

5、这就到了我们的用户访问频率控制的戏码了。博主添加APIView部分代码,即check_throttles方法用到的代码。我们可以查看代码中的self.check_chrottles(request),点进去查看check_chrottles()方法,可以看到有get_throttles方法,这个方法有self.throttles_classes变量,即self.throttles_classes = api_settings.DEFAULT_THROTTLES_CLASSES,然后这里也和【上一篇的用户权限验证】很相似,就是请求视图类中如果没有这个变量名及值(值是一个列表),就会使用全局配置文件中的REST_FRAMEWORK={"DEFAULT_THROTTLES_CLASSES":["访问频率控制类的全路径"]},或者我们在请求视图类中添加这个变量及值

class APIView(View):

    # 如果请求视图类中没有这个变量和值,就会使用全局配置文件的值
    throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        # 循环访问频率控制类的对象
        for throttle in self.get_throttles():
            # 执行对象下的方法allow_request(),如果为True就不做处理,如果为False,就抛出异常信息
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())

    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        # 返回访问频率控制类的对象列表
        return [throttle() for throttle in self.throttle_classes]

    def throttled(self, request, wait):
        """
        If request is throttled, determine what kind of exception to raise.
        """
        # 抛出访问频率控制类的异常
        raise exceptions.Throttled(wait)
APIView

6、在上面的APIView类中会执行到if not throttle.allow_throttles(request, self),我们可以直接点进去查看allow_request方法,进入有关于Throttles的类:BaseThrottle类主要提供了三个方法,但一般不会继承这个类,而是继承SimpleRateThrottle类,我们看看SimpleRateThrottle类,以下我添加了中文注释便于理解

class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is set by a `rate` attribute on the View
    class.  The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    # Django默认的缓存
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    # 看这据行也可以在全局配置文件中配置
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # 寻找继承该类的方法是否含有rate方法或者变量名,如果没有,就执行get_rate函数
        if not getattr(self, 'rate', None):
            # 拿到的值是"3/m"
            self.rate = self.get_rate()
        # 接着执行这个方法
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        # 通过反射机制调用子类或者当前类的scope函数或者变量名,如果有的话就执行try里面的语句
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            # 获取配置类型的scope变量或者方法,这里应该返回一个比例
            # 在配置文件中假设一个值为 "3/m"
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        # rate = "3/m"
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        # num_request = 3
        num_requests = int(num)
        # duration = 60
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        # (3, 60)
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        # 这个方法专门供缓存适用,这个我们方法一般在实际开发中就需要重写这个方法了
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        # 访问限制的主要实现
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        # 这个方法就是提示还有多少时间才能访问
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
SimpleRateThrottle

这时我们自定义访问频率控制类时,继承这个类,编写就很简单了

class MyThrottle(SimpleRateThrottle):
    # 这个就是在全局配置文件上的键
    scope = "my_rate"

    def get_cache_key(self, request, view):
        return self.get_ident(request)


# 在配置文件上的配置
REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_CLASSES":["MyThrottle的全路径"],
    "DEFAULT_THROTTLE_RATE:{
        "my_rate": "3/m"
    }
}

 

posted @ 2020-06-04 18:35  xsha_h  阅读(238)  评论(0编辑  收藏  举报