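DRF Source Code Analysis
I. Development Models
1. Front end and back end not separated
- Front-end and back-end code are written together in one project
2. Front end and back end separated
- 2.1 Front-end development
- 2.2 Back-end development
  - Provides APIs for the front end
  - Always returns an HttpResponse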
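II. Django and WSGI
1. WSGI
Name | Notes |
---|---|
WSGI | The protocol (Web Server Gateway Interface) between a web server and a Python application |
wsgiref | A module that implements the WSGI protocol; a socket server (used by Django) |
werkzeug | A module that implements the WSGI protocol; a socket server (used by Flask) |
tornado | A module that implements the WSGI protocol; a socket server (Tornado) |
uwsgi | A module that implements the WSGI protocol; a socket server |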
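To make the table concrete, here is a minimal sketch of what "implementing the WSGI protocol" means from the application side, using only the standard-library `wsgiref` module. The names `application` and `serve` are illustrative choices, not something Django defines.

```python
from wsgiref.simple_server import make_server


def application(environ, start_response):
    """A minimal WSGI callable: takes the request environ, returns an iterable of bytes."""
    body = f"hello from {environ.get('PATH_INFO', '/')}".encode("utf-8")
    headers = [
        ("Content-Type", "text/plain; charset=utf-8"),
        ("Content-Length", str(len(body))),
    ]
    start_response("200 OK", headers)
    return [body]


def serve(host="127.0.0.1", port=8000):
    """Serve `application` with the reference socket server (blocks until interrupted)."""
    with make_server(host, port, application) as httpd:
        httpd.serve_forever()
```

Calling `serve()` starts the socket server; any WSGI server from the table could host the same `application` callable unchanged.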
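III. Django FBV and CBV
1. FBV (function-based views)
    # Route
    path('user/', views.users)

    # View
    import json
    from django.http import HttpResponse

    def users(request):
        if request.method == 'GET':
            user_list = [{'name': 'yiwen'}]
            return HttpResponse(json.dumps(user_list))
        # Other methods: answer 405 so the view never returns None
        return HttpResponse(status=405)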
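2. CBV (class-based views)
- 2.1 In Django, a CBV uses reflection to run a different method depending on the request method
- 2.2 How it works:
  - url -> view() -> dispatch() (calls the matching handler by reflection: GET/POST/DELETE/PUT/...)
- 2.3 What is reflection?
Reflection maps a string to an attribute or method of an instance, which can then be called, modified, and so on. Four built-in functions matter here:
Function | Notes |
---|---|
getattr | Get the attribute with the given name from an object |
setattr | Set the attribute with the given name on an object |
hasattr | Check whether an object has an attribute with the given name |
delattr | Delete the attribute with the given name |
- 2.4 Code
    # Route
    path('student/', views.StudentsView.as_view(), name='student'),

    # View
    from django.http import HttpResponse
    from django.views import View

    class BaseView(View):
        def dispatch(self, request, *args, **kwargs):
            # Runs before every handler; View.dispatch() then picks get/post/... by reflection
            print('hello')
            ret = super(BaseView, self).dispatch(request, *args, **kwargs)
            return ret

    class StudentsView(BaseView):
        def get(self, request):
            return HttpResponse('get')

        def post(self, request):
            return HttpResponse('post')

        def put(self, request):
            return HttpResponse('put')

        def delete(self, request):
            return HttpResponse('delete')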
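The reflection step can be tried without Django. The sketch below is a simplified stand-in for `View.dispatch()`: `MiniView` and its plain-string return values are hypothetical, and only the `getattr` lookup mirrors what the notes describe.

```python
class MiniView:
    """Hypothetical stand-in for django.views.View, showing only the reflection step."""

    http_method_names = ["get", "post", "put", "delete"]

    def dispatch(self, method, *args, **kwargs):
        name = method.lower()
        if name in self.http_method_names:
            # getattr maps the string "get"/"post"/... to a bound method
            handler = getattr(self, name, self.http_method_not_allowed)
        else:
            handler = self.http_method_not_allowed
        return handler(*args, **kwargs)

    def http_method_not_allowed(self, *args, **kwargs):
        return "405 method not allowed"


class StudentsView(MiniView):
    def get(self):
        return "get"

    def post(self):
        return "post"
```

A method that is allowed but not defined on the subclass (here `delete`) falls back to `http_method_not_allowed`, which is why the third argument to `getattr` matters.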
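IV. Django Middleware
Middleware is suited to operations that apply to every request, for example:
- Role-based access control
- User authentication
- CSRF
- Session
- Logging
1. Middleware hooks
- process_request
- process_view
- process_response
- process_exception
- process_template_response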
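The order in which two of these hooks run can be sketched in plain Python. This is a simplified imitation of hook-style middleware, not Django's own class: `HookMiddleware`, `LoggingMiddleware`, and the dict-based request and response are assumptions made for illustration.

```python
class HookMiddleware:
    """Simplified, hypothetical imitation of how hook-style middleware is invoked."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        response = None
        if hasattr(self, "process_request"):
            # Returning a response here short-circuits the view
            response = self.process_request(request)
        response = response or self.get_response(request)
        if hasattr(self, "process_response"):
            response = self.process_response(request, response)
        return response


class LoggingMiddleware(HookMiddleware):
    def __init__(self, get_response, log):
        super().__init__(get_response)
        self.log = log

    def process_request(self, request):
        self.log.append(f"request {request['path']}")
        return None  # continue to the view

    def process_response(self, request, response):
        self.log.append(f"response {response['status']}")
        return response


def view(request):
    return {"status": 200, "body": "ok"}
```

Wrapping `view` as `LoggingMiddleware(view, log)` and calling the result with a request runs `process_request`, then the view, then `process_response`.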
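2. How Django implements CSRF protection
- The check runs in the middleware's process_view method
- The token is read from the request body (or a request header) and validated against the token stored in the cookie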
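The idea of comparing a submitted token with the one in the cookie can be sketched as follows. This is a deliberately simplified model under stated assumptions: `issue_csrf_token` and `csrf_check` are hypothetical helpers, and Django's real middleware does more (for example, it masks tokens and checks the Referer on HTTPS).

```python
import secrets


def issue_csrf_token():
    """Create a random token; the server would store it in a cookie."""
    return secrets.token_urlsafe(32)


def csrf_check(method, cookie_token, submitted_token):
    """Return True when the request may proceed."""
    if method in ("GET", "HEAD", "OPTIONS", "TRACE"):
        return True  # safe methods are not checked
    if not cookie_token or not submitted_token:
        return False
    # Constant-time comparison avoids leaking how many characters matched
    return secrets.compare_digest(cookie_token, submitted_token)
```

A forged cross-site POST cannot read the victim's cookie, so it cannot supply a matching `submitted_token`.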
3. Using the csrf_exempt decorator on a CBV
    from django.http import HttpResponse
    from django.utils.decorators import method_decorator
    from django.views import View
    from django.views.decorators.csrf import csrf_exempt, csrf_protect

    # Option 1: decorate the class and name the method to wrap
    @method_decorator(csrf_exempt, name='dispatch')
    class StudentsView(View):
        def get(self, request):
            return HttpResponse('get')

        def post(self, request):
            return HttpResponse('post')

        def put(self, request):
            return HttpResponse('put')

        def delete(self, request):
            return HttpResponse('delete')

    # Option 2: decorate dispatch() directly
    class StudentsView(View):
        @method_decorator(csrf_exempt)
        def dispatch(self, request, *args, **kwargs):
            ret = super(StudentsView, self).dispatch(request, *args, **kwargs)
            return ret

        def get(self, request):
            return HttpResponse('get')

        def post(self, request):
            return HttpResponse('post')

        def put(self, request):
            return HttpResponse('put')

        def delete(self, request):
            return HttpResponse('delete')
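V. RESTful Conventions
Ten RESTful conventions
Convention | Example | Notes |
---|---|---|
Protocol | https://www.baidu.com/ | The protocol between the API and its users (usually HTTPS) |
Domain | https://api.example.com or https://example.org/api/ | Deploy the API under a dedicated domain, or under a path of the main domain |
Version | https://api.example.com/v1/ | Put the API version number in the URL |
Path | https://api.example.com/v1/user | The path, also called the endpoint, is the concrete address of the API |
HTTP verbs | GET/POST/PUT/DELETE | The type of operation on a resource is expressed by the HTTP verb |
Filtering | ?limit=10&offset=10&page=2&per_page=100 | The API should accept parameters that filter the returned results |
Status codes | 200/201/204/401/404/403/500 | The status code and message the server returns to the user |
Error handling | | If the status code is 4xx, return an error message to the user. In general, use error as the key and the error message as the value. |
Results | GET /collection: returns a list of resource objects | For each kind of operation, the result the server returns follows a convention |
Hypermedia API | {"link": { "rel": "collection https://www.example.com/zoos", "href": "https://api.example.com/zoos", "title": "List of zoos", "type": "application/vnd.yourformat+json" }} | A RESTful API should ideally be hypermedia-driven: results contain links to other API methods, so users know what to do next without reading the documentation |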
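The error-handling row can be illustrated with a small helper. This is only a sketch: `build_body` is a hypothetical function, not part of Django or DRF, and it shows just the "error as key, message as value" convention for 4xx responses.

```python
import json


def build_body(status, payload):
    """Return (status, JSON text); 4xx bodies carry the message under "error"."""
    if 400 <= status < 500:
        return status, json.dumps({"error": str(payload)})
    return status, json.dumps(payload)
```

For example, `build_body(404, "user not found")` yields a body whose only key is `error`, while a 200 response serializes the payload unchanged.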
1. Different operations depending on the HTTP method
    # Route
    path('user/', views.UserView.as_view())

    # View
    class UserView(View):
        def get(self, request):
            return HttpResponse('retrieve')

        def post(self, request):
            return HttpResponse('create')

        def put(self, request):
            return HttpResponse('update')

        def delete(self, request):
            return HttpResponse('delete')
VI. Django REST framework
1. Authentication
Some APIs can only be accessed after authentication
- 1.1 How the authentication flow works
- ① dispatch() wraps the native request
    class APIView(View):
        authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
        permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES

        def dispatch(self, request, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs
            # Wrap the native request (adds extra features)
            request = self.initialize_request(request, *args, **kwargs)
            self.request = request
            self.headers = self.default_response_headers  # deprecate?

            try:
                # Authentication
                self.initial(request, *args, **kwargs)

                # Reflection
                if request.method.lower() in self.http_method_names:
                    handler = getattr(self, request.method.lower(),
                                      self.http_method_not_allowed)
                else:
                    handler = self.http_method_not_allowed

                response = handler(request, *args, **kwargs)

            except Exception as exc:
                # A failed authentication raises an exception, which is handled here
                response = self.handle_exception(exc)

            self.response = self.finalize_response(request, response, *args, **kwargs)
            return self.response

        def initialize_request(self, request, *args, **kwargs):
            parser_context = self.get_parser_context(request)

            return Request(
                request,
                parsers=self.get_parsers(),
                # Authentication classes
                authenticators=self.get_authenticators(),
                negotiator=self.get_content_negotiator(),
                parser_context=parser_context
            )

        def get_authenticators(self):
            return [auth() for auth in self.authentication_classes]
- ② initial() triggers authentication
    def initial(self, request, *args, **kwargs):
        self.perform_authentication(request)
- ③ perform_authentication() accesses request.user
    def perform_authentication(self, request):
        request.user
- ④ request.py
    class Request:
        @property
        def user(self):
            if not hasattr(self, '_user'):
                with wrap_attributeerrors():
                    # Get the authenticator objects and run the actual authentication
                    self._authenticate()
            return self._user
- ⑤ Loop over all authenticator objects
    def _authenticate(self):
        for authenticator in self.authenticators:
            try:
                # Call the authenticator's authenticate() method:
                # 1. If it raises an exception, self._not_authenticated() runs and the exception propagates
                # 2. On success, the return value must be a tuple (request.user, request.auth)
                # 3. If it returns None, the next authenticator is tried
                user_auth_tuple = authenticator.authenticate(self)
            except exceptions.APIException:
                self._not_authenticated()
                raise

            if user_auth_tuple is not None:
                self._authenticator = authenticator
                self.user, self.auth = user_auth_tuple
                return

        self._not_authenticated()
- ⑥ The Authentication class
    class Authentication:
        def authenticate(self, request):
            token = request._request.GET.get('token')
            token_obj = models.UserToken.objects.filter(token=token).first()
            if not token_obj:
                raise exceptions.AuthenticationFailed('User authentication failed')
            # Assigned to request.user and request.auth respectively
            return token_obj.user, token_obj

        def authenticate_header(self, request):
            pass
- ⑦ The view method runs
    class UserView(APIView):
        def get(self, request):
            pass

        def post(self, request):
            pass
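Steps ⑤ and ⑥ can be reproduced without DRF installed. The sketch below is a simplified model: `AuthenticationFailed`, `TokenAuth`, and `run_authenticators` are hypothetical stand-ins, the request is a plain dict, and tokens live in an in-memory dict rather than a `UserToken` table.

```python
class AuthenticationFailed(Exception):
    """Hypothetical stand-in for rest_framework.exceptions.AuthenticationFailed."""


class TokenAuth:
    """Looks the token up in an in-memory dict instead of a database table."""

    def __init__(self, tokens):
        self.tokens = tokens  # maps token -> username

    def authenticate(self, request):
        token = request.get("token")
        if token is None:
            return None  # no credentials: let the next authenticator try
        if token not in self.tokens:
            raise AuthenticationFailed("User authentication failed")
        return self.tokens[token], token  # (user, auth)


def run_authenticators(request, authenticators):
    """Return (user, auth) from the first authenticator that recognises the request."""
    for authenticator in authenticators:
        result = authenticator.authenticate(request)  # may raise
        if result is not None:
            return result
    return None, None  # anonymous: nobody claimed the request
```

The three outcomes match the comments in `_authenticate`: a tuple ends the loop, `None` moves on to the next authenticator, and an exception aborts the request.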
- 1.2 Configuring authentication classes
- ① On a CBV
    class UserView(APIView):
        authentication_classes = [Xxx, Xzz]

        def get(self, request):
            pass

        def post(self, request):
            pass
- ② Globally
    # settings.py
    REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": ['<dotted path to the class>']
    }
- 1.3 Anonymous users
    # settings.py
    REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": ['<dotted path to the class>'],
        # For unauthenticated requests:
        # request.user = None
        # request.auth = None
        "UNAUTHENTICATED_USER": None,
        "UNAUTHENTICATED_TOKEN": None
    }
- 1.4 Built-in authentication classes
- ① Base class
rest_framework.authentication.BaseAuthentication
Source:
    class BaseAuthentication:
        """
        All authentication classes should extend BaseAuthentication.
        """

        def authenticate(self, request):
            """
            Authenticate the request and return a two-tuple of (user, token).
            """
            raise NotImplementedError(".authenticate() must be overridden.")

        def authenticate_header(self, request):
            """
            Return a string to be used as the value of the `WWW-Authenticate`
            header in a `401 Unauthenticated` response, or `None` if the
            authentication scheme should return `403 Permission Denied` responses.
            """
            pass
- ② Other authentication classes
Source:
    class BasicAuthentication(BaseAuthentication):
        """
        HTTP Basic authentication against username/password.
        """
        ...

    class SessionAuthentication(BaseAuthentication):
        """
        Use Django's session framework for authentication.
        """
        ...

    class TokenAuthentication(BaseAuthentication):
        """
        Simple token based authentication.

        Clients should authenticate by passing the token key in the "Authorization"
        HTTP header, prepended with the string "Token ".  For example:

            Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
        """
        ...

    class RemoteUserAuthentication(BaseAuthentication):
        """
        REMOTE_USER authentication.

        To use this, set up your web server to perform authentication, which will
        set the REMOTE_USER environment variable. You will need to have
        'django.contrib.auth.backends.RemoteUserBackend in your
        AUTHENTICATION_BACKENDS setting
        """
        ...
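- 1.5 Summary
- ① Usage
  - Create a class that inherits from BaseAuthentication
  - Implement the authenticate method
Return value | Notes |
---|---|
None | The next authenticator is tried |
raise exceptions.AuthenticationFailed('User authentication failed') | The exception comes from rest_framework.exceptions |
(item one, item two) | Item one is assigned to request.user, item two to request.auth |
- ② Source flow
  - dispatch
    - Wrap the request
      - Collect the configured authentication classes (global/local) and instantiate them with a list comprehension
    - initial
      - perform_authentication
        - request.user (loops over the authenticators internally)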
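2. Permissions
Different users have different permissions when accessing an endpoint
- 2.1 Source flow
Similar to the authentication flow:
- dispatch
  - Wrap the request
  - initial
    - check_permissions
      - get_permissions collects the configured permission classes (global/local) and instantiates them with a list comprehension
      - permission.has_permission (decides whether access is allowed)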
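The permission flow above can be sketched in plain Python. Everything here is a simplified assumption for illustration: `PermissionDenied`, `IsVip`, and the free function `check_permissions` are hypothetical stand-ins, and the request is a plain dict.

```python
class PermissionDenied(Exception):
    """Hypothetical stand-in for rest_framework.exceptions.PermissionDenied."""


class IsVip:
    """Hypothetical permission: only users flagged as VIP may pass."""

    message = "VIP users only"

    def has_permission(self, request, view):
        user = request.get("user")
        return bool(user and user.get("vip"))


def check_permissions(request, view, permission_classes):
    """Instantiate each class and stop at the first one that refuses."""
    for permission in [cls() for cls in permission_classes]:
        if not permission.has_permission(request, view):
            raise PermissionDenied(getattr(permission, "message", "permission denied"))
```

Because the loop raises on the first refusal, all listed permission classes must agree before the view handler runs.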
- 2.2 Built-in permission classes
rest_framework.permissions.BasePermission
- ① Base class
Source:
    class BasePermission(metaclass=BasePermissionMetaclass):
        """
        A base class from which all permission classes should inherit.
        """

        def has_permission(self, request, view):
            """
            Return `True` if permission is granted, `False` otherwise.
            """
            return True

        def has_object_permission(self, request, view, obj):
            """
            Return `True` if permission is granted, `False` otherwise.
            """
            return True
- ② Other permission classes
Source:
    class AllowAny(BasePermission):
        """
        Allow any access.
        This isn't strictly required, since you could use an empty
        permission_classes list, but it's useful because it makes the intention
        more explicit.
        """

        def has_permission(self, request, view):
            return True

    class IsAuthenticated(BasePermission):
        """
        Allows access only to authenticated users.
        """

        def has_permission(self, request, view):
            return bool(request.user and request.user.is_authenticated)

    class IsAdminUser(BasePermission):
        """
        Allows access only to admin users.
        """

        def has_permission(self, request, view):
            return bool(request.user and request.user.is_staff)

    class IsAuthenticatedOrReadOnly(BasePermission):
        """
        The request is authenticated as a user, or is a read-only request.
        """

        def has_permission(self, request, view):
            return bool(
                request.method in SAFE_METHODS or
                request.user and
                request.user.is_authenticated
            )
- 2.3 Usage
  - Create a class that inherits from BasePermission (rest_framework.permissions.BasePermission)
  - Implement the has_permission method
Return value | Notes |
---|---|
True | Access allowed |
False | Access denied |
  - Global use
    # settings.py
    REST_FRAMEWORK = {
        "DEFAULT_PERMISSION_CLASSES": ["<dotted path to the class>"]
    }
  - Local use
    class UserView(APIView):
        permission_classes = [Xxx, Xzz]

        def get(self, request):
            pass

        def post(self, request):
            pass
3. Throttling
Controls how often a client may call the API
- 3.1 Source flow
- dispatch
  - initial
    - check_throttles
      - get_throttles
      - throttle.allow_request
- 3.2 Built-in classes
rest_framework.throttling.BaseThrottle
- ① Base class source
    class BaseThrottle:
        """
        Rate throttling of requests.
        """

        def allow_request(self, request, view):
            """
            Return `True` if the request should be allowed, `False` otherwise.
            """
            raise NotImplementedError('.allow_request() must be overridden')

        def get_ident(self, request):
            """
            Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
            if present and number of proxies is > 0. If not use all of
            HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
            """
            xff = request.META.get('HTTP_X_FORWARDED_FOR')
            remote_addr = request.META.get('REMOTE_ADDR')
            num_proxies = api_settings.NUM_PROXIES

            if num_proxies is not None:
                if num_proxies == 0 or xff is None:
                    return remote_addr
                addrs = xff.split(',')
                client_addr = addrs[-min(num_proxies, len(addrs))]
                return client_addr.strip()

            return ''.join(xff.split()) if xff else remote_addr

        def wait(self):
            """
            Optionally, return a recommended number of seconds to wait before
            the next request.
            """
            return None

    class SimpleRateThrottle(BaseThrottle):
        """
        A simple cache implementation, that only requires `.get_cache_key()`
        to be overridden.

        The rate (requests / seconds) is set by a `rate` attribute on the View
        class.  The attribute is a string of the form 'number_of_requests/period'.

        Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

        Previous request information used for throttling is stored in the cache.
        """
        cache = default_cache
        timer = time.time
        cache_format = 'throttle_%(scope)s_%(ident)s'
        scope = None
        THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

        def __init__(self):
            if not getattr(self, 'rate', None):
                self.rate = self.get_rate()
            self.num_requests, self.duration = self.parse_rate(self.rate)

        def get_cache_key(self, request, view):
            """
            Should return a unique cache-key which can be used for throttling.
            Must be overridden.

            May return `None` if the request should not be throttled.
            """
            raise NotImplementedError('.get_cache_key() must be overridden')

        def get_rate(self):
            """
            Determine the string representation of the allowed request rate.
            """
            if not getattr(self, 'scope', None):
                msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                       self.__class__.__name__)
                raise ImproperlyConfigured(msg)

            try:
                return self.THROTTLE_RATES[self.scope]
            except KeyError:
                msg = "No default throttle rate set for '%s' scope" % self.scope
                raise ImproperlyConfigured(msg)

        def parse_rate(self, rate):
            """
            Given the request rate string, return a two tuple of:
            <allowed number of requests>, <period of time in seconds>
            """
            if rate is None:
                return (None, None)
            num, period = rate.split('/')
            num_requests = int(num)
            duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
            return (num_requests, duration)

        def allow_request(self, request, view):
            """
            Implement the check to see if the request should be throttled.

            On success calls `throttle_success`.
            On failure calls `throttle_failure`.
            """
            if self.rate is None:
                return True

            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True

            self.history = self.cache.get(self.key, [])
            self.now = self.timer()

            # Drop any requests from the history which have now passed the
            # throttle duration
            while self.history and self.history[-1] <= self.now - self.duration:
                self.history.pop()
            if len(self.history) >= self.num_requests:
                return self.throttle_failure()
            return self.throttle_success()

        def throttle_success(self):
            """
            Inserts the current request's timestamp along with the key
            into the cache.
            """
            self.history.insert(0, self.now)
            self.cache.set(self.key, self.history, self.duration)
            return True

        def throttle_failure(self):
            """
            Called when a request to the API has failed due to throttling.
            """
            return False

        def wait(self):
            """
            Returns the recommended next request time in seconds.
            """
            if self.history:
                remaining_duration = self.duration - (self.now - self.history[-1])
            else:
                remaining_duration = self.duration

            available_requests = self.num_requests - len(self.history) + 1
            if available_requests <= 0:
                return None

            return remaining_duration / float(available_requests)
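The sliding-window logic in `allow_request` is easier to follow in isolation. The sketch below is a hypothetical in-memory version: `SlidingWindowThrottle` keeps timestamps in a dict instead of Django's cache and takes an injected `timer` so the behaviour can be checked deterministically.

```python
class SlidingWindowThrottle:
    """Hypothetical in-memory throttle: at most `num_requests` per `duration` seconds."""

    def __init__(self, rate, timer):
        num, period = rate.split("/")
        self.num_requests = int(num)
        # Only the first letter of the period matters, as in parse_rate
        self.duration = {"s": 1, "m": 60, "h": 3600, "d": 86400}[period[0]]
        self.timer = timer
        self.history = {}  # ident -> timestamps, newest first

    def allow_request(self, ident):
        now = self.timer()
        stamps = self.history.setdefault(ident, [])
        # Drop timestamps that have left the window
        while stamps and stamps[-1] <= now - self.duration:
            stamps.pop()
        if len(stamps) >= self.num_requests:
            return False
        stamps.insert(0, now)
        return True
```

With a rate of `"3/min"`, the fourth call from the same `ident` inside one minute is refused, and calls are accepted again once the oldest timestamps fall out of the window.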
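- 3.3 Usage
  - Create a class that inherits from BaseThrottle and implement allow_request and wait
  - Or create a class that inherits from SimpleRateThrottle, implement get_cache_key, and set scope (the key looked up in the settings file)
  - Global use
    # settings.py
    REST_FRAMEWORK = {
        "DEFAULT_THROTTLE_CLASSES": ['<dotted path to the class>']
    }
  - Local use
    # On the view class
    throttle_classes = [ThrottleClass]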
4. Versioning
- 4.1 Global use
    # URL routing
    path('meituan/<str:version>/', include('meituanapp.urls')),

    # settings.py
    REST_FRAMEWORK = {
        # Versioning class
        'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.URLPathVersioning',
        # Default version
        'DEFAULT_VERSION': 'v1',
        # Allowed versions
        'ALLOWED_VERSIONS': ['v1', 'v2']
    }
- 4.2 Source flow
  - dispatch
    - initial
      - determine_version
Return value | Attribute on request | Notes |
---|---|---|
scheme.determine_version(request, *args, **kwargs) | request.version | The current version |
scheme | request.versioning_scheme | The versioning class instance |
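The pair of values stored on the request can be modelled without DRF. In this sketch `NotFound`, `UrlPathVersioning`, and `determine` are hypothetical miniatures: the URL keyword arguments are passed in as a dict and the allowed versions are hard-coded rather than read from settings.

```python
class NotFound(Exception):
    """Hypothetical stand-in for rest_framework.exceptions.NotFound."""


class UrlPathVersioning:
    """Hypothetical miniature of URL-path versioning."""

    default_version = "v1"
    allowed_versions = ["v1", "v2"]
    version_param = "version"

    def determine_version(self, url_kwargs):
        version = url_kwargs.get(self.version_param) or self.default_version
        if version not in self.allowed_versions:
            raise NotFound("Invalid version in URL path.")
        return version


def determine(url_kwargs, scheme_class=UrlPathVersioning):
    """Return (version, scheme): the values kept as request.version and request.versioning_scheme."""
    scheme = scheme_class()
    return scheme.determine_version(url_kwargs), scheme
```

A URL such as `meituan/v2/` would arrive here as `{"version": "v2"}`; a missing version falls back to the default, and an unknown one is rejected.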
- 4.3 request.versioning_scheme.reverse (URL reversing)
    # Route
    path('meituan/<str:version>/', include('serializeapp.urls', namespace='meituan')),

    # serializeapp/urls.py
    app_name = 'serializeapp'  # required when include() is given a namespace
    path('good/', views.GoodView.as_view(), name='good'),

    # View
    class GoodView(APIView):
        def get(self, request, **kwargs):
            # The current request.version is filled in automatically
            print(request.versioning_scheme.reverse(viewname='meituan:good', request=request))
            # http://127.0.0.1:8000/meituan/v1/good/
            print(reverse('meituan:good', kwargs={'version': 'v2'}))
            # /meituan/v2/good/
            return Response({'good': '111'})
- 4.4 Built-in versioning classes
- ① Base class
    class BaseVersioning:
        default_version = api_settings.DEFAULT_VERSION
        allowed_versions = api_settings.ALLOWED_VERSIONS
        version_param = api_settings.VERSION_PARAM

        def determine_version(self, request, *args, **kwargs):
            msg = '{cls}.determine_version() must be implemented.'
            raise NotImplementedError(msg.format(
                cls=self.__class__.__name__
            ))

        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            return _reverse(viewname, args, kwargs, request, format, **extra)

        def is_allowed_version(self, version):
            if not self.allowed_versions:
                return True
            return ((version is not None and version == self.default_version) or
                    (version in self.allowed_versions))
- ② Other classes
    class AcceptHeaderVersioning(BaseVersioning):
        """
        GET /something/ HTTP/1.1
        Host: example.com
        Accept: application/json; version=1.0
        """
        invalid_version_message = _('Invalid version in "Accept" header.')

        def determine_version(self, request, *args, **kwargs):
            media_type = _MediaType(request.accepted_media_type)
            version = media_type.params.get(self.version_param, self.default_version)
            version = unicode_http_header(version)
            if not self.is_allowed_version(version):
                raise exceptions.NotAcceptable(self.invalid_version_message)
            return version

        # We don't need to implement `reverse`, as the versioning is based
        # on the `Accept` header, not on the request URL.

    class URLPathVersioning(BaseVersioning):
        """
        To the client this is the same style as `NamespaceVersioning`.
        The difference is in the backend - this implementation uses
        Django's URL keyword arguments to determine the version.

        An example URL conf for two views that accept two different versions.

        urlpatterns = [
            re_path(r'^(?P<version>[v1|v2]+)/users/$', users_list, name='users-list'),
            re_path(r'^(?P<version>[v1|v2]+)/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
        ]

        GET /1.0/something/ HTTP/1.1
        Host: example.com
        Accept: application/json
        """
        invalid_version_message = _('Invalid version in URL path.')

        def determine_version(self, request, *args, **kwargs):
            version = kwargs.get(self.version_param, self.default_version)
            if version is None:
                version = self.default_version

            if not self.is_allowed_version(version):
                raise exceptions.NotFound(self.invalid_version_message)
            return version

        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            if request.version is not None:
                kwargs = {} if (kwargs is None) else kwargs
                kwargs[self.version_param] = request.version

            return super().reverse(
                viewname, args, kwargs, request, format, **extra
            )

    class NamespaceVersioning(BaseVersioning):
        """
        To the client this is the same style as `URLPathVersioning`.
        The difference is in the backend - this implementation uses
        Django's URL namespaces to determine the version.

        An example URL conf that is namespaced into two separate versions

        # users/urls.py
        urlpatterns = [
            path('/users/', users_list, name='users-list'),
            path('/users/<int:pk>/', users_detail, name='users-detail')
        ]

        # urls.py
        urlpatterns = [
            path('v1/', include('users.urls', namespace='v1')),
            path('v2/', include('users.urls', namespace='v2'))
        ]

        GET /1.0/something/ HTTP/1.1
        Host: example.com
        Accept: application/json
        """
        invalid_version_message = _('Invalid version in URL path. Does not match any version namespace.')

        def determine_version(self, request, *args, **kwargs):
            resolver_match = getattr(request, 'resolver_match', None)
            if resolver_match is None or not resolver_match.namespace:
                return self.default_version

            # Allow for possibly nested namespaces.
            possible_versions = resolver_match.namespace.split(':')
            for version in possible_versions:
                if self.is_allowed_version(version):
                    return version
            raise exceptions.NotFound(self.invalid_version_message)

        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            if request.version is not None:
                viewname = self.get_versioned_viewname(viewname, request)
            return super().reverse(
                viewname, args, kwargs, request, format, **extra
            )

        def get_versioned_viewname(self, viewname, request):
            return request.version + ':' + viewname

    class HostNameVersioning(BaseVersioning):
        """
        GET /something/ HTTP/1.1
        Host: v1.example.com
        Accept: application/json
        """
        hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
        invalid_version_message = _('Invalid version in hostname.')

        def determine_version(self, request, *args, **kwargs):
            hostname, separator, port = request.get_host().partition(':')
            match = self.hostname_regex.match(hostname)
            if not match:
                return self.default_version
            version = match.group(1)
            if not self.is_allowed_version(version):
                raise exceptions.NotFound(self.invalid_version_message)
            return version

        # We don't need to implement `reverse`, as the hostname will already be
        # preserved as part of the REST framework `reverse` implementation.

    class QueryParameterVersioning(BaseVersioning):
        """
        GET /something/?version=0.1 HTTP/1.1
        Host: example.com
        Accept: application/json
        """
        invalid_version_message = _('Invalid version in query parameter.')

        def determine_version(self, request, *args, **kwargs):
            version = request.query_params.get(self.version_param, self.default_version)
            if not self.is_allowed_version(version):
                raise exceptions.NotFound(self.invalid_version_message)
            return version

        def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
            url = super().reverse(
                viewname, args, kwargs, request, format, **extra
            )
            if request.version is not None:
                return replace_query_param(url, self.version_param, request.version)
            return url
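`QueryParameterVersioning.reverse` ends by rewriting the query string of the reversed URL. A rough standard-library equivalent of that last step might look like the sketch below; `with_version` is a hypothetical helper, not DRF's `replace_query_param`, and it assumes each query key appears at most once.

```python
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


def with_version(url, version, param="version"):
    """Return `url` with the query parameter `param` set to `version`."""
    parts = urlsplit(url)
    # dict() keeps one value per key, which is the assumption stated above
    query = dict(parse_qsl(parts.query, keep_blank_values=True))
    query[param] = version  # overwrite an existing value or add a new one
    return urlunsplit(parts._replace(query=urlencode(sorted(query.items()))))
```

Sorting the parameters keeps the output stable, so the same inputs always produce the same URL regardless of the original parameter order.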