RestFramework——API基本实现及dispatch基本源码剖析
基于Django实现
在使用RestFramework之前我们先用Django自己实现以下API。
API完全可以有我们基于Django自己开发,原理是给出一个接口(URL),前端向URL发送请求以获取数据。这样能实现前后端分离的效果。但Django实现的API许多功能都需要我们自己写。
URL
from django.contrib import admin from django.conf.urls import url, include from app01 import views from app02 import views urlpatterns = [ url('admin/', admin.site.urls), url('app02/', include('app02.urls'))#路由分发 ]
from app02 import views from django.conf.urls import url urlpatterns = [ url('^users/', views.users), url('^user/(\d+)', views.user), #######FBV与CBV的分界线######## url('^users/', views.UsersView.as_view()), url('^user/', views.UserView.as_view()), ]
views.py
FBV
from django.shortcuts import render,HttpResponse import json def users(request): response = {'code':1000,'data':None} #code用来表示状态,比如1000代表成功,1001代表 response['data'] = [ {'name':'Damon','age':22}, {'name':'Stefan','age':10}, {'name':'Elena','age':11}, ] return HttpResponse(json.dumps(response)) #返回就送字符串,前端解析 def user(request,pk): if request.method =='GET': return HttpResponse(json.dumps({'name':'Stefan','age':11})) #返回一条数据 elif request.method =='POST': return HttpResponse(json.dumps({'code':1111})) #返回一条数据 elif request.method =='PUT': pass elif request.method =='DELETE': pass
CBV
from django.views import View class UsersView(View): def get(self,request): response = {'code':1000,'data':None} response['data'] = [ {'name': 'Damon', 'age': 22}, {'name': 'Stefan', 'age': 10}, {'name': 'Elena', 'age': 11}, ] return HttpResponse(json.dumps(response),stutas=200) class UserView(View): def get(self,request,pk): return HttpResponse(json.dumps({'name':'haiyan','age':11})) #返回一条数据 def post(self,request,pk): return HttpResponse(json.dumps({'code':1111})) #返回一条数据 def put(self,request,pk): pass def delete(self,request,pk): pass
注:通常我们在前后端分离进行编程时会推崇使用CBV的形式,CBV的代码可读性较高。
基于RestFramework实现
安装:
pip3 install djangorestframework -i http://pypi.douban.com/simple/ --trusted-host=pypi.douban.com
RestFramework可以直接在Django中使用,安装完RestFramework后在Django中可以当做模块一般导入即可使用。(记得在settings.py中进行注册,如app)
URL与基于Django实现相同,这里选用CBV的形式
from app02 import views from django.conf.urls import url urlpatterns = [ url('^users/', views.UsersView.as_view()),#CBV必须要有as_view() url('^user/', views.UserView.as_view()), ]
views.py
CBV
#导入rest_framework,自定义视图的类需继承APIView from rest_framework.views import APIView from rest_framework.response import Response class TestView(APIView): def dispatch(self, request, *args, **kwargs): """ 请求到来之后,在url中执行as_view()就会执行dispatch方法,dispatch方法是APIView类中内置的,根据请求方式不同触发 get/post/put等方法。可自定制~ 注意:APIView中的dispatch方法有好多好多的功能 """ return super().dispatch(request, *args, **kwargs) def get(self, request, *args, **kwargs): return Response('GET请求,响应内容') def post(self, request, *args, **kwargs): return Response('POST请求,响应内容') def put(self, request, *args, **kwargs): return Response('PUT请求,响应内容')
注:重要的功都在APIView的dispatch中触发。要掌握RestFramework,必须弄懂dispatch方法做了些什么,这样我们才可以根据自己的需求进行自定制。
dispatch基本源码剖析
我们在继承了APIView之后就可以重写里面的方法进行自定制。此时我们需要先弄懂APIView里到底封装了哪些方法。在APIView中,最重要的就是dispatch方法。
请求在url中执行as_view()时就会触发dispatch,进入源码我们可以看到dispatch主要做了四件事:
#在APIView类中: def dispatch(self, request, *args, **kwargs): """ `.dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. """ self.args = args self.kwargs = kwargs # 1.封装Django原始的request,使得用了framework以后的request都具有更多功能 """ 增加的功能 parsers=self.get_parsers(), authenticators=self.get_authenticators(),#获取认证相关的所有类并实例化,传入request对象 negotiator=self.get_content_negotiator(), parser_context=parser_context """ request = self.initialize_request(request, *args, **kwargs) self.request = request#将封装后的request赋值给原始request self.headers = self.default_response_headers # deprecate? try: """ 2.版本处理、用户认证、权限、访问频率限制 """ self.initial(request, *args, **kwargs) # Get the appropriate handler method if request.method.lower() in self.http_method_names: handler = getattr(self, request.method.lower(), self.http_method_not_allowed) else: handler = self.http_method_not_allowed #3.执行函数get/post/put/delete response = handler(request, *args, **kwargs) except Exception as exc: response = self.handle_exception(exc) #4.对返回结果进行再次加工 self.response = self.finalize_response(request, response, *args, **kwargs) return self.response
接下来我们对每一步进行具体的分析
第一步:封装request
request = self.initialize_request(request, *args, **kwargs) #查看initialize_request做了什么 def initialize_request(self, request, *args, **kwargs): """ Returns the initial request object. """ parser_context = self.get_parser_context(request) return Request( request, parsers=self.get_parsers(), authenticators=self.get_authenticators(),#获取认证相关的所有类并实例化,传入request对象。优先找自己的,没有就找父类的 negotiator=self.get_content_negotiator(), parser_context=parser_context )
1.1、我们看到request封装了一个认证的功能——获取认证相关的所有的类并实例化,看看get_authenticators()做了什么
def get_authenticators(self): """ Instantiates and returns the list of authenticators that this view can use. """ #返回的是对象列表[SessionAuthentication,BaseAuthentication] return [auth() for auth in self.authentication_classes] #self.authentication_classes是封装有认证功能的类的列表 authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES #默认的,如果自己有会优先执行直接的
1.2、去settings.py查看默认的配置是什么
DEFAULTS = { # Base API policies 'DEFAULT_AUTHENTICATION_CLASSES': ( 'rest_framework.authentication.SessionAuthentication', #这时候就找到了他默认认证的类了,可以导入看看 'rest_framework.authentication.BasicAuthentication' ),
1.3、导入SessionAutentication和BasicAuthentication查看这两个雷兜风装了什么功能
from rest_framework.authentication import SessionAuthentication from rest_framework.authentication import BaseAuthentication
class BaseAuthentication(object): """ All authentication classes should extend BaseAuthentication. """ def authenticate(self, request): """ Authenticate the request and return a two-tuple of (user, token). """ raise NotImplementedError(".authenticate() must be overridden.") def authenticate_header(self, request): """ Return a string to be used as the value of the `WWW-Authenticate` header in a `401 Unauthenticated` response, or `None` if the authentication scheme should return `403 Permission Denied` responses. """ pass
class BasicAuthentication(BaseAuthentication): """ HTTP Basic authentication against username/password. """ www_authenticate_realm = 'api' def authenticate(self, request): """ Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ auth = get_authorization_header(request).split() if not auth or auth[0].lower() != b'basic': return None #返回none不处理。让下一个处理 if len(auth) == 1: msg = _('Invalid basic header. No credentials provided.') raise exceptions.AuthenticationFailed(msg) elif len(auth) > 2: msg = _('Invalid basic header. Credentials string should not contain spaces.') raise exceptions.AuthenticationFailed(msg) try: auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':') #用partition切割冒号也包括 except (TypeError, UnicodeDecodeError, binascii.Error): msg = _('Invalid basic header. Credentials not correctly base64 encoded.') raise exceptions.AuthenticationFailed(msg) userid, password = auth_parts[0], auth_parts[2] # 返回用户和密码 return self.authenticate_credentials(userid, password, request) def authenticate_credentials(self, userid, password, request=None): """ Authenticate the userid and password against username and password with optional request for context. """ credentials = { get_user_model().USERNAME_FIELD: userid, 'password': password } user = authenticate(request=request, **credentials) if user is None: raise exceptions.AuthenticationFailed(_('Invalid username/password.')) if not user.is_active: raise exceptions.AuthenticationFailed(_('User inactive or deleted.')) return (user, None) def authenticate_header(self, request): return 'Basic realm="%s"' % self.www_authenticate_realm
1.4、简单自定制认证的类
我们可以看到源码中最重要的就是BasicAuthentication的authenticate方法,所以要自定制认证的类只需重写该方法即可
class MyAuthentication(BaseAuthentication): def authenticate(self, request): token = request.query_params.get('token')#登录用户有tocken字段 obj = models.UserInfo.objects.filter(token=token).first() if obj: return (obj.username,obj) return None def authenticate_header(self, request): pass
第二步、版本处理、认证、权限、访问频率限制
self.initial(request, *args, **kwargs) #查看initial方法做了什么 def initial(self, request, *args, **kwargs): """ Runs anything that needs to occur prior to calling the method handler. """ self.format_kwarg = self.get_format_suffix(**kwargs) # Perform content negotiation and store the accepted info on the request neg = self.perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg # Determine the API version, if versioning is in use. #2.1 处理版本信息,获取版本必须用version version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme # Ensure that the incoming request is permitted #2.2认证,将user封装到request对象中 self.perform_authentication(request) #2.3 权限 self.check_permissions(request) #2.4 对请求用户进行访问频率的限制 self.check_throttles(request)
2.2.1、认证:查看perform_authentication方法,发现只是将user封装到了request中
def perform_authentication(self, request): request.user
2.2.2、查看request.user中都封装了什么
class Request(object): @property def user(self): """ Returns the user associated with the current request, as authenticated by the authentication classes provided to the request. """ if not hasattr(self, '_user'): self._authenticate()#执行用户认证 return self._user
2.2.3、执行self._authenticate() 开始用户认证,如果验证成功后返回元组: (用户,用户Token)
def _authenticate(self): """ Attempt to authenticate the request using each authentication instance in turn. """ #循环对象列表 for authenticator in self.authenticators: try: # 执行每一个对象的authenticate 方法 user_auth_tuple = authenticator.authenticate(self) except exceptions.APIException: self._not_authenticated() raise if user_auth_tuple is not None: self._authenticator = authenticator self.user, self.auth = user_auth_tuple#返回一个元组,赋给了self,就可以request.user,request.auth了 return self._not_authenticated()
2.2.4、在user_auth_tuple = authenticator.authenticate(self) 进行验证,如果验证成功,执行类里的authenticatie方法。如果用户没有认证成功:self._not_authenticated()
def _not_authenticated(self): """ Set authenticator, user & authtoken representing an unauthenticated request. Defaults are None, AnonymousUser & None. """ # 如果跳过了所有认证,默认用户和Token和使用配置文件进行设置 self._authenticator = None if api_settings.UNAUTHENTICATED_USER: self.user = api_settings.UNAUTHENTICATED_USER()# 默认值为:匿名用户AnonymousUser else: self.user = None# None 表示跳过该认证 if api_settings.UNAUTHENTICATED_TOKEN: self.auth = api_settings.UNAUTHENTICATED_TOKEN()# 默认值为:None else: self.auth = None #默认值都可以在settings.py中进行自定制配置 REST_FRAMEWORK = { 'UNAUTHENTICATED_USER': None, 'UNAUTHENTICATED_TOKEN': None, }
2.3.、权限控制
######check_permissions方法####### def check_permissions(self, request): """ Check if the request should be permitted. Raises an appropriate exception if the request is not permitted. """ for permission in self.get_permissions():#寻找类中的get_permissions()方法 if not permission.has_permission(request, self):#无权限则抛出异常 self.permission_denied( request, message=getattr(permission, 'message', None) ) ######get_permissions方法####### def get_permissions(self): """ Instantiates and returns the list of permissions that this view requires. """ return [permission() for permission in self.permission_classes]#去settings.py中查是否有权限配置
2.4、对访问频率进行限制(以下简称限流)
限流的主要目的还是为了防爬。一个网站的最终目的是为了让人去访问的,但是有时候会有一些人工智能做一些对网站有伤害的事,这时候我们就需要进行相应的限制了。权限的分配是一种对网站的保护的限制,但有些功能(比如看新闻、看动态等)是不需要任何权限只需要进入网站就可以查看的,这时我们就需要进行相应的限流操作,区分出非人类的用户访问予以限制。
def check_throttles(self, request): """ Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ """ 循环每一个throttle对象,执行allow_request方法 allow_request: 返回False,说明限制访问频率 返回True,说明不限制,通行 可自定制 """ for throttle in self.get_throttles(): if not throttle.allow_request(request, self): self.throttled(request, throttle.wait())#throttle.wait()表示多少秒后可再次访问
from __future__ import unicode_literals import time from django.core.cache import cache as default_cache from django.core.exceptions import ImproperlyConfigured from rest_framework.settings import api_settings class BaseThrottle(object): """ Rate throttling of requests. """ def allow_request(self, request, view): """ Return `True` if the request should be allowed, `False` otherwise. """ raise NotImplementedError('.allow_request() must be overridden') def get_ident(self, request):#唯一标识 """ Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR if present and number of proxies is > 0. If not use all of HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. """ xff = request.META.get('HTTP_X_FORWARDED_FOR') remote_addr = request.META.get('REMOTE_ADDR')#获取IP等 num_proxies = api_settings.NUM_PROXIES#获取代理信息 if num_proxies is not None: if num_proxies == 0 or xff is None: return remote_addr addrs = xff.split(',') client_addr = addrs[-min(num_proxies, len(addrs))] return client_addr.strip() return ''.join(xff.split()) if xff else remote_addr def wait(self): """ Optionally, return a recommended number of seconds to wait before the next request. """ return None#等待时长,可重写
class SimpleRateThrottle(BaseThrottle): """ 一个简单的缓存实现,只需要` get_cache_key() `。被覆盖。 速率(请求/秒)是由视图上的“速率”属性设置的。类。该属性是一个字符串的形式number_of_requests /期。 周期应该是:(的),“秒”,“M”,“min”,“h”,“小时”,“D”,“一天”。 以前用于节流的请求信息存储在高速缓存中。 A simple cache implementation, that only requires `.get_cache_key()` to be overridden. The rate (requests / seconds) is set by a `throttle` attribute on the View class. The attribute is a string of the form 'number_of_requests/period'. Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') Previous request information used for throttling is stored in the cache. """ cache = default_cache timer = time.time cache_format = 'throttle_%(scope)s_%(ident)s' scope = None THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): if not getattr(self, 'rate', None): self.rate = self.get_rate()#点进去看到需要些一个scope ,2/m self.num_requests, self.duration = self.parse_rate(self.rate) def get_cache_key(self, request, view):#这个相当于是一个半成品,我们可以来补充它 """ Should return a unique cache-key which can be used for throttling. Must be overridden. May return `None` if the request should not be throttled. """ raise NotImplementedError('.get_cache_key() must be overridden') def get_rate(self): """ Determine the string representation of the allowed request rate. """ if not getattr(self, 'scope', None):#检测必须有scope,没有就报错了 msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) raise ImproperlyConfigured(msg) try: return self.THROTTLE_RATES[self.scope] except KeyError: msg = "No default throttle rate set for '%s' scope" % self.scope raise ImproperlyConfigured(msg) def parse_rate(self, rate): """ Given the request rate string, return a two tuple of: <allowed number of requests>, <period of time in seconds> """ if rate is None: return (None, None) num, period = rate.split('/') num_requests = int(num) duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] return (num_requests, duration) # 1、一进来首先执行, def allow_request(self, request, view): """ Implement the check to see if the request should be throttled. On success calls `throttle_success`. On failure calls `throttle_failure`. """ if self.rate is None: return True self.key = self.get_cache_key(request, view)#2、执行get_cache_key if self.key is None: return True#不限制 self.history = self.cache.get(self.key, [])#3、得到的key,默认是一个列表,赋值给了self.history, # self.history可以理解为每一个ip对应的访问记录 self.now = self.timer() # Drop any requests from the history which have now passed the # throttle duration while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() if len(self.history) >= self.num_requests: return self.throttle_failure() return self.throttle_success() def throttle_success(self): """ Inserts the current request's timestamp along with the key into the cache. """ self.history.insert(0, self.now) self.cache.set(self.key, self.history, self.duration) return True def throttle_failure(self): """ Called when a request to the API has failed due to throttling. """ return False def wait(self): """ Returns the recommended next request time in seconds. """ if self.history: remaining_duration = self.duration - (self.now - self.history[-1]) else: remaining_duration = self.duration available_requests = self.num_requests - len(self.history) + 1 if available_requests <= 0: return None return remaining_duration / float(available_requests)
第三步、执行函数get/post/put/delete
if request.method.lower() in self.http_method_names:#http_method_names = ['get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace'] handler = getattr(self, request.method.lower(),self.http_method_not_allowed)#反射 else: handler = self.http_method_not_allowed#抛出异常
第四步、对返回结果进行再次加工