django rest framework之节流的源码流程剖析
视图类:
class UserViewset(BaseView):
    """
    create:
    Create a user.

    retrieve:
    Return the details of the requesting user.
    """
    queryset = User.objects.all()
    # Throttle classes applied to every request handled by this viewset.
    throttle_classes = [UserRateThrottle]
    authentication_classes = (JSONWebTokenAuthentication, authentication.SessionAuthentication)

    def get_serializer_class(self):
        """
        Return the serializer class for the current action.

        Only ``create`` uses the registration serializer; ``retrieve`` and
        every other action use the detail serializer.
        """
        if self.action == "create":
            return UserRegSerializer
        return UserDetailSerializer

    def get_permissions(self):
        """
        Return the permission instances for the current action.

        Only ``retrieve`` requires an authenticated user; every other action
        (including ``create``) gets an empty permission list.
        """
        if self.action == "retrieve":
            return [permissions.IsAuthenticated()]
        return []

    def create(self, request, *args, **kwargs):
        """
        Validate the payload, save the user and answer with the serialized
        user plus a freshly issued JWT.

        The response body is the serializer data extended with ``token`` and
        ``name`` (``user.name`` when truthy, otherwise ``user.username``).
        Validation failures raise, because ``raise_exception=True`` is passed.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        user = self.perform_create(serializer)

        re_dict = serializer.data
        payload = jwt_payload_handler(user)
        re_dict["token"] = jwt_encode_handler(payload)
        re_dict["name"] = user.name if user.name else user.username

        headers = self.get_success_headers(serializer.data)
        return Response(re_dict, status=status.HTTP_201_CREATED, headers=headers)

    def get_object(self):
        """Return the user attached to the current request."""
        return self.request.user

    def perform_create(self, serializer):
        """Save the serializer and return what ``save()`` returns, so that
        ``create`` can build the token from it."""
        return serializer.save()
同权限类一样，节流检查也是在 dispatch() 调用的 initial() 中被调用:
def check_throttles(self, request):
    """
    Run every configured throttle against the request.

    Each throttle returned by `self.get_throttles()` is asked whether the
    request is allowed. For a throttle that refuses, `self.throttled()` is
    called with the request and that throttle's suggested wait time; it is
    expected to raise the appropriate exception.
    """
    for limiter in self.get_throttles():
        if limiter.allow_request(request, self):
            # This throttle lets the request through; check the next one.
            continue
        # The request is over the limit: report it with the wait hint.
        self.throttled(request, limiter.wait())
内置节流类:
class BaseThrottle(object):
    """
    Rate throttling of requests.
    """

    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        """
        Return the client IP address used to identify the requester.

        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            # Count back `num_proxies` entries from the end of the header,
            # clamped to the number of addresses actually present.
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        # No proxy count configured: use the whole header with whitespace
        # removed, or fall back to REMOTE_ADDR when the header is absent.
        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None


class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is set by a `rate` attribute on the View
    class.  The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    # Rates configured in the settings, looked up by scope in `get_rate()`.
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # An explicit, truthy `rate` attribute wins; otherwise look the rate
        # up from the configured rates via the scope.
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.

        Reads the configured rate for `self.scope` from `THROTTLE_RATES`.
        Raises `ImproperlyConfigured` when no scope is set or when the scope
        has no configured rate.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>

        For example '3/m' means 3 requests per minute and yields (3, 60).
        A rate of `None` yields (None, None).
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        # Only the first letter of the period is significant.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        # Cache key under which this requester's history is stored.
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        # Timestamps of earlier requests, newest first (see
        # `throttle_success`, which inserts at index 0).
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # Already at the allowed number of requests inside the window.
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.

        Returns `None` when the computed number of available requests is
        not positive.
        """
        if self.history:
            # Time left until the oldest recorded request leaves the window.
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
class UserRateThrottle(SimpleRateThrottle):
    """
    Throttle API calls per user.

    Authenticated requests are keyed by the user's primary key; anonymous
    requests are keyed by the identifier returned from `get_ident()`
    (the request's IP address).
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        """Build the cache key for this request from the scope and the
        requester's identity."""
        user = request.user
        # Logged-in users are tracked by id, everyone else by IP.
        ident = user.pk if user.is_authenticated else self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}