DRF之频率组件源码分析
【一】频率组件介绍
- Django Rest Framework(DRF)中的频率组件是用于限制API端点的访问频率的一种机制。
- 频率组件可以帮助你控制用户对API的请求频率,以防止滥用和DDoS攻击。
- 比如某个接口,一分钟只能访问5次,超过了就得等
- 按IP地址 限制
- 按用户id 限制
【二】内置的频率类
Throttle
类:这是频率组件的基类,定义了频率限制的核心逻辑。它包括 allow_request()
和 wait()
方法,用于检查是否允许请求以及在请求受限时应该等待多长时间。
- AnonRateThrottle:基于认证用户的请求频率进行限制,继承自
Throttle
并实现了特定的频率限制算法
- UserRateThrottle:基于匿名用户的请求频率进行限制,继承自
Throttle
并实现了特定的频率限制算法
【三】执行流程分析
- DRF视图类可以包含一个
throttle_classes
属性,该属性定义了应用于特定视图的频率组件。
- 在处理请求之前,DRF将遍历视图的
throttle_classes
列表,检查每个频率组件是否允许请求。
- 频率组件的配置:你可以在DRF的设置中配置默认的频率限制,也可以在视图类中设置
throttle_classes
来覆盖默认配置。
- 频率组件将根据这些配置来控制请求的频率。
- 在请求到达API视图时,DRF会首先检查用户的请求是否在频率组件的允许范围内。

- 如果请求超出了允许的频率限制,将返回一个HTTP 429 Too Many Requests响应。
| def check_throttles(self, request): |
| """ |
| Check if request should be throttled. |
| Raises an appropriate exception if the request is throttled. |
| """ |
| |
| throttle_durations = [] |
| |
| |
| for throttle in self.get_throttles(): |
| |
| if not throttle.allow_request(request, self): |
| |
| |
| throttle_durations.append(throttle.wait()) |
| |
| |
| if throttle_durations: |
| |
| |
| |
| |
| durations = [ |
| duration for duration in throttle_durations |
| if duration is not None |
| ] |
| |
| |
| duration = max(durations, default=None) |
| |
| |
| self.throttled(request, duration) |
| def throttled(self, request, wait): |
| """ |
| If request is throttled, determine what kind of exception to raise. |
| """ |
| |
| raise exceptions.Throttled(wait) |
【四】内置频率限制类源码分析
【0】BaseThrottle
| from rest_framework.throttling import BaseThrottle |
| class BaseThrottle: |
| """ |
| Rate throttling of requests. |
| """ |
| |
| |
| |
| |
| |
| def allow_request(self, request, view): |
| """ |
| Return `True` if the request should be allowed, `False` otherwise. |
| """ |
| raise NotImplementedError('.allow_request() must be overridden') |
| |
| |
| def get_ident(self, request): |
| """ |
| Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR |
| if present and number of proxies is > 0. If not use all of |
| HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. |
| """ |
| |
| xff = request.META.get('HTTP_X_FORWARDED_FOR') |
| |
| remote_addr = request.META.get('REMOTE_ADDR') |
| num_proxies = api_settings.NUM_PROXIES |
| |
| |
| if num_proxies is not None: |
| |
| if num_proxies == 0 or xff is None: |
| |
| return remote_addr |
| |
| addrs = xff.split(',') |
| client_addr = addrs[-min(num_proxies, len(addrs))] |
| return client_addr.strip() |
| |
| |
| return ''.join(xff.split()) if xff else remote_addr |
| |
| |
| |
| |
| def wait(self): |
| """ |
| Optionally, return a recommended number of seconds to wait before |
| the next request. |
| """ |
| return None |
【1】AnonRateThrottle
| from rest_framework.throttling import AnonRateThrottle |
| class AnonRateThrottle(SimpleRateThrottle): |
| """ |
| Limits the rate of API calls that may be made by a anonymous users. |
| |
| The IP address of the request will be used as the unique cache key. |
| """ |
| scope = 'anon' |
| |
| def get_cache_key(self, request, view): |
| |
| if request.user and request.user.is_authenticated: |
| |
| return None |
| |
| |
| return self.cache_format % { |
| 'scope': self.scope, |
| 'ident': self.get_ident(request) |
| } |
【2】UserRateThrottle
| from rest_framework.throttling import UserRateThrottle |
| class UserRateThrottle(SimpleRateThrottle): |
| """ |
| Limits the rate of API calls that may be made by a given user. |
| |
| The user id will be used as a unique cache key if the user is |
| authenticated. For anonymous requests, the IP address of the request will |
| be used. |
| """ |
| |
| scope = 'user' |
| |
| |
| def get_cache_key(self, request, view): |
| |
| if request.user and request.user.is_authenticated: |
| |
| ident = request.user.pk |
| else: |
| |
| ident = self.get_ident(request) |
| |
| |
| return self.cache_format % { |
| 'scope': self.scope, |
| 'ident': ident |
| } |
【3】SimpleRateThrottle
| from rest_framework.throttling import SimpleRateThrottle |
| class SimpleRateThrottle(BaseThrottle): |
| """ |
| A simple cache implementation, that only requires `.get_cache_key()` |
| to be overridden. |
| |
| The rate (requests / seconds) is set by a `rate` attribute on the Throttle |
| class. The attribute is a string of the form 'number_of_requests/period'. |
| |
| Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') |
| |
| Previous request information used for throttling is stored in the cache. |
| """ |
| |
| |
| |
| cache = default_cache |
| |
| |
| timer = time.time |
| |
| |
| cache_format = 'throttle_%(scope)s_%(ident)s' |
| |
| |
| |
| |
| scope = None |
| |
| |
| |
| THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES |
| |
| def __init__(self): |
| |
| |
| if not getattr(self, 'rate', None): |
| |
| self.rate = self.get_rate() |
| |
| |
| self.num_requests, self.duration = self.parse_rate(self.rate) |
| |
| |
| |
| def get_cache_key(self, request, view): |
| """ |
| Should return a unique cache-key which can be used for throttling. |
| Must be overridden. |
| |
| May return `None` if the request should not be throttled. |
| """ |
| |
| raise NotImplementedError('.get_cache_key() must be overridden') |
| |
| |
| def get_rate(self): |
| """ |
| Determine the string representation of the allowed request rate. |
| """ |
| |
| if not getattr(self, 'scope', None): |
| |
| msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % |
| self.__class__.__name__) |
| |
| raise ImproperlyConfigured(msg) |
| |
| try: |
| |
| return self.THROTTLE_RATES[self.scope] |
| except KeyError: |
| |
| msg = "No default throttle rate set for '%s' scope" % self.scope |
| raise ImproperlyConfigured(msg) |
| |
| |
| def parse_rate(self, rate): |
| """ |
| Given the request rate string, return a two tuple of: |
| <allowed number of requests>, <period of time in seconds> |
| """ |
| |
| if rate is None: |
| |
| return (None, None) |
| |
| |
| num, period = rate.split('/') |
| |
| |
| num_requests = int(num) |
| |
| |
| |
| |
| |
| |
| duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] |
| |
| return (num_requests, duration) |
| |
| |
| def allow_request(self, request, view): |
| """ |
| Implement the check to see if the request should be throttled. |
| |
| On success calls `throttle_success`. |
| On failure calls `throttle_failure`. |
| """ |
| |
| if self.rate is None: |
| |
| return True |
| |
| |
| self.key = self.get_cache_key(request, view) |
| |
| if self.key is None: |
| return True |
| |
| |
| self.history = self.cache.get(self.key, []) |
| |
| |
| self.now = self.timer() |
| |
| |
| |
| |
| |
| while self.history and self.history[-1] <= self.now - self.duration: |
| |
| self.history.pop() |
| |
| |
| if len(self.history) >= self.num_requests: |
| |
| |
| return self.throttle_failure() |
| |
| |
| return self.throttle_success() |
| |
| |
| def throttle_success(self): |
| """ |
| Inserts the current request's timestamp along with the key |
| into the cache. |
| """ |
| |
| self.history.insert(0, self.now) |
| |
| self.cache.set(self.key, self.history, self.duration) |
| return True |
| |
| |
| def throttle_failure(self): |
| """ |
| Called when a request to the API has failed due to throttling. |
| """ |
| return False |
| |
| |
| def wait(self): |
| """ |
| Returns the recommended next request time in seconds. |
| """ |
| |
| |
| if self.history: |
| |
| |
| remaining_duration = self.duration - (self.now - self.history[-1]) |
| else: |
| |
| |
| remaining_duration = self.duration |
| |
| |
| |
| |
| available_requests = self.num_requests - len(self.history) + 1 |
| |
| |
| if available_requests <= 0: |
| return None |
| |
| |
| return remaining_duration / float(available_requests) |
【4】ScopedRateThrottle
| from rest_framework.throttling import ScopedRateThrottle |
| class ScopedRateThrottle(SimpleRateThrottle): |
| """ |
| Limits the rate of API calls by different amounts for various parts of |
| the API. Any view that has the `throttle_scope` property set will be |
| throttled. The unique cache key will be generated by concatenating the |
| user id of the request, and the scope of the view being accessed. |
| """ |
| |
| |
| |
| scope_attr = 'throttle_scope' |
| |
| def __init__(self): |
| |
| |
| |
| pass |
| |
| |
| def allow_request(self, request, view): |
| |
| |
| |
| self.scope = getattr(view, self.scope_attr, None) |
| |
| |
| |
| |
| if not self.scope: |
| return True |
| |
| |
| |
| |
| |
| self.rate = self.get_rate() |
| |
| self.num_requests, self.duration = self.parse_rate(self.rate) |
| |
| |
| return super().allow_request(request, view) |
| |
| |
| def get_cache_key(self, request, view): |
| """ |
| If `view.throttle_scope` is not set, don't apply this throttle. |
| |
| Otherwise generate the unique cache key by concatenating the user id |
| with the `.throttle_scope` property of the view. |
| """ |
| |
| if request.user and request.user.is_authenticated: |
| |
| ident = request.user.pk |
| else: |
| |
| ident = self.get_ident(request) |
| |
| |
| return self.cache_format % { |
| 'scope': self.scope, |
| 'ident': ident |
| } |
【五】频率类的一般使用步骤(固定用法)
【1】创建自定义频次认证类
- 首先,需要创建一个频次认证类,并让它继承自
SimpleRateThrottle
- 这个类负责限制对某些操作或资源的访问频次。
【2】重写 get_cache_key 方法
- 在创建频次认证类后,需要重写其中的
get_cache_key
方法。
- 这个方法决定了如何从请求中获取缓存键值,以便在缓存中存储和检索频次信息。
- 可以考虑使用以下信息作为缓存键值的组成部分:
- 用户身份:可以使用用户的唯一标识符或者请求中的某些认证信息。
- 资源标识:如果需要对不同的资源进行频次限制,可以加入资源标识。
- 操作类型:如果需要对不同的操作类型进行频次限制,可以加入操作类型。
- 根据具体情况,可以将这些信息组合起来构成一个唯一的缓存键值,并在
get_cache_key
方法中返回。
- 最后,在频次认证类中,可以添加一个类属性来自定义命名。
- 这个属性可以用于在缓存中存储频次信息时使用。
- 可以考虑使用以下方式定义类属性:
- 在使用时,可以通过
MyThrottle.cache_name
来访问这个自定义属性。
| from rest_framework.throttling import SimpleRateThrottle |
| |
| class SimpleRate_Throttle(SimpleRateThrottle): |
| |
| scope = "scope_throttle" |
| |
| |
| |
| def get_cache_key(self, request, view): |
| |
| |
| return request.META.get('REMOTE_ADDR') |
| |
| |
| def wait(self): |
| """ |
| Returns the recommended next request time in seconds. |
| """ |
| if self.history: |
| remaining_duration = self.duration - (self.now - self.history[-1]) |
| else: |
| remaining_duration = self.duration |
| |
| available_requests = self.num_requests - len(self.history) + 1 |
| if available_requests <= 0: |
| return None |
| |
| |
| |
| ip = self.get_cache_key(self.request, self.view) |
| if ip == "特定IP地址": |
| |
| return remaining_duration / float(available_requests) * 0.5 |
| else: |
| |
| return remaining_duration / float(available_requests) |
【3】配置文件中配置
- 配置文件中需要对频次认证类进行相关配置
- 具体配置内容包括认证类的名称、参数设置和限制的频次等信息。
- 这里的限制关键字的变量名要和上面自定义限制类中的变量名一致
- 例如,在Django框架中,可以在settings.py或其他相关的配置文件中添加如下内容
- 'myapp.apithrottling.SimpleRate_Throttle' 是频次认证类的名称
- scope_throttle 是前述创建的频次认证类的类属性自定义命名
| REST_FRAMEWORK = { |
| 'DEFAULT_THROTTLE_CLASSES': ( |
| |
| 'myapp.apithrottling.SimpleRate_Throttle', |
| ), |
| 'DEFAULT_THROTTLE_RATES': { |
| 'scope_throttle': '5/minute', |
| } |
| } |
【4】局部使用
- 注意,只有继承了APIView 及其 子类的视图才会走三大认证
| from rest_framework.decorators import throttle_classes |
| |
| # @throttle_classes([MyThrottle]) 方式一:作为类的装饰器使用 |
| class BookView(APIView): |
| # 方式二:在视图类中应用限制类 |
| throttle_classes = [SimpleRate_Throttle] |
| |
【5】全局使用,局部禁用
| class BookView(APIView): |
| |
| # 将频率限制类的限制类列表清空即可 |
| throttle_classes = [] |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通