drf 认证、权限、频率三组件、排序、过滤
一、三大认证
要使用三大认证的功能,视图类至少要继承 APIview 视图类。因为APIview视图类去除了csrf认证、封装新的request、走三大认证和全局异常
1、支持序列化的视图类中,比如APIview,源码中view---> self.dispatch--->self.initial---> initial()
1 2 3 4 5 6 7 8 9 10 11 12 13 | def initial( self , request, * args, * * kwargs): self .format_kwarg = self .get_format_suffix( * * kwargs) neg = self .perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg version, scheme = self .determine_version(request, * args, * * kwargs) request.version, request.versioning_scheme = version, scheme # Ensure that the incoming request is permitted self .perform_authentication(request) self .check_permissions(request) self .check_throttles(request) |
二、认证组件
判断用户是否登录,数据库是否有值
1、需求:
通过认证组件去认证,没有认证通过的用户不让登录。认证方式前端发来的token值与数据库进行对比
2、models
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | from django.db import models class User(models.Model): username = models.CharField(max_length = 32 ) password = models.CharField(max_length = 32 ) # 用户类型 user_type = models.IntegerField(choices = (( 1 , '2B用户' ), ( 2 , '普通用户' ), ( 3 , '超级用户' ))) # 一对一的关系 class UserToken(models.Model): token = models.CharField(max_length = 64 ) # OneToOneField本质就是ForeignKey+unique user = models.OneToOneField(to = User, on_delete = models.CASCADE) # user = models.ForeignKey(to=User, unique=True,on_delete=models.CASCADE) |
2、写一个认证 auth.py
继承 BaseAuthentication
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | from .models import UserToken from rest_framework.authentication import BaseAuthentication from rest_framework.exceptions import AuthenticationFailed class LoginAuth(BaseAuthentication): # 重写一下authenticate方法 def authenticate( self , request): token = request.query_params.get( 'token' ) # 从url中的参数去拿 print (token) #4d1455f4-97be-4cf6-88a1-dee6e21433ad user_token = UserToken.objects. filter (token = token).first() if user_token: return user_token.user, token else : raise AuthenticationFailed( '您没有登录!' ) |
补充:
token从请求头中去拿
token = request.META.get('HTTP_TOKEN')
token从 请求体中取
token = request.data.get('token')
3、views
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | class UserView(ViewSet): # authentication_classes = [] # 局部解除认证禁用,token没有认证通过就不让登录 permission_classes = [] # 局部解除权限禁用 @action (methods = [ 'POST' ], detail = False ) # /user/login/ post 请求就会执行 def login( self , request, * args, * * kwargs): # 前端传入用户名密码 username = request.data.get( 'username' ) password = request.data.get( 'password' ) user = User.objects. filter (username = username, password = password).first() print (user) if user: # 生成一个随机字符串,返回给前端,并且要把随机字符串存到token表中 token = str (uuid.uuid4()) ##### 方式一:麻烦方式 # user_token=UserToken.objects.filter(user=user).first() # if user_token: # user_token.token=token # user_token.save() # else: # UserToken.objects.create(user=user,token=token) ## 方式二:通过user去UserToken表中查,如果能查到用defaults的更新,如果查不到,就用user和defaults新增一条记录 # 每次登录会返回一个uuid,会更新uuid UserToken.objects.update_or_create(defaults = { 'token' : token}, user = user) return Response({ 'code' : 100 , 'msg' : '登录成功' , 'token' : token, 'username' : user.username}) else : return Response({ 'code' : 101 , 'msg' : '用户名或密码错误' }) |
补充:
1、继承 ViewSetMixin,APIView,常见的5个路由自动匹配,其他的通过action装饰器进行指定
4、开启认证
在全局开启认证时需要指定app01下面的auth认证模块
1 2 3 4 5 6 7 8 9 10 11 12 | ## 全局开启 REST_FRAMEWORK = { 'DEFAULT_AUTHENTICATION_CLASSES' : [ 'app01.auth.LoginAuth' ], } ##局部开启 authentication_classes = [authenticate] ##局部禁用 authentication_classes = [] |
5、拿着数据库的token访问
1 | 127.0 . 0.1 : 8000 / user / login / ?token = 730c7ebf - 4b12 - 4f55 - a337 - 0b341d333395 |
三、权限组件
判断登录成功后的用户是否有操作权限
1、permissions 文件
继承: BasePermission
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | from rest_framework.permissions import BasePermission # 1 写一个类,继承 BasePermission # 2 重写 has_permission # 3 在方法中校验用户是否有权限,如果有,就返回True,如果没有,就返回False class UserPermission(BasePermission): def has_permission( self , request, view): # request 当次请求的request, 新的,它是在认证类之后执行的,如果认证通过了request.user 就是当前登录用户 # 拿到当前登录用户,查看它的类型,确定有没有权限 if request.user.user_type = = 3 : return True else : self .message = '您的用户类型是:%s,您没有权限操作' % (request.user.get_user_type_display()) return False |
2、views
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | from rest_framework.viewsets import ViewSetMixin, ViewSet from rest_framework.response import Response from rest_framework.decorators import action from .models import User, UserToken import uuid # class UserView(ViewSetMixin,APIView): class UserView(ViewSet): # authentication_classes = [] # 局部解除认证禁用,token没有认证通过就不让登录 # permission_classes = [] # 局部解除权限禁用 @action (methods = [ 'POST' ], detail = False ) # /user/login/ post 请求就会执行 def login( self , request, * args, * * kwargs): # 前端传入用户名密码 username = request.data.get( 'username' ) password = request.data.get( 'password' ) user = User.objects. filter (username = username, password = password).first() print (username) if user: # 生成一个随机字符串,返回给前端,并且要把随机字符串存到token表中 token = str (uuid.uuid4()) ##### 方式一:麻烦方式 # user_token=UserToken.objects.filter(user=user).first() # if user_token: # user_token.token=token # user_token.save() # else: # UserToken.objects.create(user=user,token=token) ## 方式二:通过user去UserToken表中查,如果能查到用defaults的更新,如果查不到,就用user和defaults新增一条记录 # 每次登录会返回一个uuid,会更新uuid UserToken.objects.update_or_create(defaults = { 'token' : token}, user = user) return Response({ 'code' : 100 , 'msg' : '登录成功' , 'token' : token, 'username' : user.username}) else : return Response({ 'code' : 101 , 'msg' : '用户名或密码错误' }) |
3、开启权限认证
1 2 3 4 5 6 7 8 9 10 11 12 | ## 全局认证 REST_FRAMEWORK = { 'DEFAULT_PERMISSION_CLASSES' : [ 'app01.permissions.UserPermission' , ], } ## 局部认证 permission_classes = [UserPermission] ##局部解除权限认证 permission_classes = [] |
4、访问效果
四、频率组件
限制用户ip频繁访问
1、throtting 文件
继承:SimpleRateThrottle
1 2 3 4 5 6 7 8 | from rest_framework.throttling import SimpleRateThrottle class IPRateThrottle(SimpleRateThrottle): scope = 'auth1' # scope: 范围 def get_cache_key( self , request, view): print (request.META) return request.META.get( 'REMOTE_ADDR' ) # REMOTE_ADDR 远程的ip地址 |
补充:
1 重写SimpleRateThrottle中的get_cache_key方法,这个方法返回什么就会用什么去做限制
2 以用户id做限制
1 2 3 4 5 6 7 8 9 10 | def get_cache_key( self , request, view): # 重写get_cache_key,返回什么,就以什么做限制: IP地址,用户id限制 # print(request.META) # request.META 是一个包含有关HTTP请求的元数据的字典 # return request.META.get('REMOTE_ADDR') # REMOTE_ADDR 远程的ip地址 try : if request.user: user_id = request.user.pk print ( '用户id是:' , user_id) return user_id except Exception: raise AuthenticationFailed( '您没有登录,请先登录吧' ) |
3 scope 是一个类属性,它的值用在全局配置中
1 2 3 | 'DEFAULT_THROTTLE_RATES' : { 'auth1' : '3/m' , # 一分钟访问三次 }, |
2、views
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | class UserView(ViewSet): authentication_classes = [] # 局部解除认证禁用,token没有认证通过就不让登录 permission_classes = [] # 局部解除权限禁用 @action (methods = [ 'POST' ], detail = False ) # /user/login/ post 请求就会执行 def login( self , request, * args, * * kwargs): # 前端传入用户名密码 username = request.data.get( 'username' ) password = request.data.get( 'password' ) user = User.objects. filter (username = username, password = password).first() print (username) if user: # 生成一个随机字符串,返回给前端,并且要把随机字符串存到token表中 token = str (uuid.uuid4()) ##### 方式一:麻烦方式 # user_token=UserToken.objects.filter(user=user).first() # if user_token: # user_token.token=token # user_token.save() # else: # UserToken.objects.create(user=user,token=token) ## 方式二:通过user去UserToken表中查,如果能查到用defaults的更新,如果查不到,就用user和defaults新增一条记录 # 每次登录会返回一个uuid,会更新uuid UserToken.objects.update_or_create(defaults = { 'token' : token}, user = user) return Response({ 'code' : 100 , 'msg' : '登录成功' , 'token' : token, 'username' : user.username}) else : return Response({ 'code' : 101 , 'msg' : '用户名或密码错误' }) |
3、开启频率认证
1 2 3 4 5 6 7 8 9 10 11 12 13 | ## 全局认证 REST_FRAMEWORK = { 'DEFAULT_THROTTLE_RATES' : { 'auth1' : '3/m' , # 一分钟访问三次 }, 'DEFAULT_THROTTLE_CLASSES' : [ 'app01.throtting.IPRateThrottle' ], } ## 局部认证 throttle_classes = [IPRateThrottle] ##局部解除权限认证 throttle_classes = [] |
4、效果
补充:自定义一个频率类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 | # 自定义的逻辑 #(1)取出访问者ip #(2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走 #(3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间, #(4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过 #(5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败 class MyThrottles(): VISIT_RECORD = {} def __init__( self ): self .history = None def allow_request( self ,request, view): #(1)取出访问者ip # print(request.META) ip = request.META.get( 'REMOTE_ADDR' ) import time ctime = time.time() # (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问 if ip not in self .VISIT_RECORD: self .VISIT_RECORD[ip] = [ctime,] return True self .history = self .VISIT_RECORD.get(ip) # (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间, while self .history and ctime - self .history[ - 1 ]> 60 : self .history.pop() # (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过 # (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败 if len ( self .history)< 3 : self .history.insert( 0 ,ctime) return True else : return False def wait( self ): import time ctime = time.time() return 60 - (ctime - self .history[ - 1 ]) |
全局使用和局部使用
1 2 3 4 5 6 | REST_FRAMEWORK = { 'DEFAULT_THROTTLE_CLASSES' :[ 'app01.utils.MyThrottles' ,], } #在视图类里使用 throttle_classes = [MyThrottles,] |
补充:频率的源码分析
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 | # 1 频率源码 - APIView - - - - disaptch - - - 》 self .initial(request, * args, * * kwargs) - - - 》 416 行: self .check_throttles(request) - - - - 》 352 行 check_throttles def check_throttles( self , request): # self.get_throttles()就是咱们配置在视图类上频率类的对象列表[频率类对象,] for throttle in self .get_throttles(): # 执行频率类对象的allow_request,传了2个,返回True或False if not throttle.allow_request(request, self ): # 反会给前端失败,显示还剩多长时间能再访问 throttle_durations.append(throttle.wait()) # 2 频率类要写 1 写一个类,继承,BaseThrottle 2 在类中重写:allow_request方法,传入 3 个参数 3 在allow_request写限制逻辑,如果还能访问 - - 》返回 True 4 如果超了次数,就不能访问,返回 False 5 局部配置在视图类上 6 全局配置在配置文件中 # 3 我们在drf中写的时候,不需要继承 BaseThrottle,继承了SimpleRateThrottle,重写get_cache_key - 我们猜测:一定是 SimpleRateThrottle帮咱们写了咱们需要写的 # 4 自定义频率类,实现一分钟只能访问三次的控制: # (1)取出访问者ip # (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走 # (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间, # (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过 # (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败 #5 SimpleRateThrottle 源码分析 - SimpleRateThrottle内部一定有:allow_request - - - 》 def allow_request( self , request, view): # 咱们没写,以后咱们可以在频率类中直接写 # rate='3/m' 以后不用写scope了,就会按一分钟访问3次现在 if self .rate is None : return True # 这里的self.rate是在对象初始化时候赋值的----> self.get_rate(): if not getattr ( self , 'scope' , None ): msg = ( "You must set either `.scope` or `.rate` for '%s' throttle" % self .__class__.__name__) raise ImproperlyConfigured(msg) try : return self .THROTTLE_RATES[ self .scope] # 这里去配置文件中拿频率的配置,如3/m except KeyError: msg = "No default throttle rate set for '%s' scope" % self .scope raise ImproperlyConfigured(msg) # 初始化代码: def __init__( self ): if not getattr ( self , 'rate' , None ): self .rate = self .get_rate() # 拿到3/m self .num_requests, self .duration = self .parse_rate( self .rate) # self.parse_rate解压赋值 ###self.parse_rate(self.rate): if rate is None : return ( None , None ) num, period = rate.split( '/' ) # 以/切割,num=3,period=m num_requests = int (num) duration = { 's' : 1 , 'm' : 60 , 'h' : 3600 , 'd' : 86400 }[period[ 0 ]] # 字典按照k取值d[m]--->60 return (num_requests, duration) # 取出:重写的get_cache_key返回的值,咱们返回了访问者ip self .key = self .get_cache_key(request, view) if self .key is None : # 重写get_cache_key方法,ruturn ip,这里key是ip return True # 根据当前访问者ip,取出 这个人的访问时间列表 [访问时间1,访问2,访问3,访问4] self .history = self .cache.get( self .key, []) # 取出当前时间 self .now = self .timer() # 把访问时间列表中超过 限制时间外的时间剔除 while self .history and self .history[ - 1 ] < = self .now - self .duration: self .history.pop() # 判断访问时间列表是否大于 3 if len ( self .history) > = self .num_requests: return self .throttle_failure() return self .throttle_success() |
五、 排序使用
1、导入模块 OrderingFilter
1 | from rest_framework.filters import OrderingFilter |
导入模块之后,写一个类属性,指定按哪个字段排序
1 2 | filter_backends = [OrderingFilter] ordering_fields = [ 'id' , 'user_type' ] |
2、views
这个案例中,能排序的前提是视图类至少继承 GenericViewSet + ListModelMixin, 和数据库打交道,需要序列化数据返回
解释:GenericViewSet 视图类继承了ViewSetMixin(重写了as_view, 改变了路由的新写法)+ GenericAPIView。ListModelMixin 重写了list(查询所有)方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | class UserView(GenericViewSet, ListModelMixin): throttle_classes = [] authentication_classes = [] permission_classes = [] queryset = User.objects. all () serializer_class = UserSerializer ## 排序功能 filter_backends = [OrderingFilter] ordering_fields = [ 'id' , 'user_type' ] @action (methods = [ 'POST' ], detail = False ) # /user/login/ post 请求就会执行 def login( self , request, * args, * * kwargs): # 前端传入用户名密码 username = request.data.get( 'username' ) password = request.data.get( 'password' ) user = User.objects. filter (username = username, password = password).first() if user: # 生成一个随机字符串,返回给前端,并且要把随机字符串存到token表中 # 随机字符串使用uuid生成 token = str (uuid.uuid4()) # 把随机字符串存到token表中会有两种情况(如果之前没有登录过就是新增,如果之前登录过修改) # 先去UserToken表中,根据user查,如果能查到,就修改,查不到就新增一条记录 ## 方式二:通过user去UserToken表中查,如果能查到用defaults的更新,如果查不到,就用user和defaults新增一条记录 UserToken.objects.update_or_create(defaults = { 'token' : token}, user = user) return Response({ 'code' : 100 , 'msg' : '登录成功' , 'token' : token, 'username' : user.username}) else : return Response({ 'code' : 101 , 'msg' : '用户名或密码错误' }) |
3、访问方式
get方法查询所有用户,在url地址中params传参数
先按照用户类型升序排,再按照id降序排
1 | 127.0 . 0.1 : 8000 / user / ?ordering = user_type, id |
六、过滤的三种方式
1、方式一:内置模块
1 2 3 4 5 6 | # 查询方式http://127.0.0.1:8000/books/?search=29 # 模糊匹配: 只要名字中有29或价格中有29都能搜出来 from rest_framework.filters import SearchFilter filter_backends = [SearchFilter, OrderingFilter] search_fields = [ 'name' , 'price' ] |
2、方式二 : 第三方模块
1 2 3 4 5 6 7 8 9 10 | # http://127.0.0.1:8000/books/?name=红楼梦 # http://127.0.0.1:8000/books/?price=19&name=西游记 # 安装: pip install django - filter # 导入模块 from django_filters.rest_framework import DjangoFilterBackend filter_backends = [DjangoFilterBackend] filterset_fields = [ 'name' , 'price' ] |
3、方式三:自定义一个类
自定义filter类:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | from rest_framework.filters import BaseFilterBackend from django.db.models import Q class MyFilter(BaseFilterBackend): def filter_queryset( self , request, queryset, view): # 基于queryset 进行过滤,过滤后返回即可 # http://127.0.0.1:8000/books/?name=书 # 名字中有书的就查出来 search_param = request.query_params.get( 'name' ) price = request.query_params.get( 'price' ) if search_param and price: queryset = queryset. filter (Q(name__contains = search_param) | Q(price = price)) # qs对象的filter # queryset = queryset.filter(name__contains=search_param, price=price) # qs对象的filter # name__contains 基于双下滑线的模糊查询 return queryset |
补充:
1、注意这里的Q查询,Q查询可以将多个查询条件组合成更复杂的逻辑表达式,包括 AND、OR、NOT 等。这里用的是OR关系
2、name__contains=search_param, price=price 基于双下线的模糊查询,AND 关系
使用
1 2 | from .filters import MyFilter filter_backends = [MyFilter] |
过滤源码:为什么在视图类中配置一个过滤类,就能走?
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | - filter_backends = [SearchFilter,MyFilter] - GenericAPIView:继承APIVIew的视图类,是不能这样配置的 - - - - 》自己过滤 - filter_backends = api_settings.DEFAULT_FILTER_BACKENDS - 还需要继承:ListModelMixin - - - 》它的 list 方法中,在对所有数据做过滤:queryset = self .filter_queryset( self .get_queryset()) - self .filter_queryset如何做的过滤呢?ListModelMixin类中没有这个方法,最终从GenericAPIView中找到了 - GenericAPIView的filter_queryset干了啥事? def filter_queryset( self , queryset): # 过滤类 for backend in list ( self .filter_backends): # filter_backends视图类中配置的过滤类,列表 # 过滤类加括号---》 过滤类的对象---》调用过滤类对象的filter_queryset queryset = backend().filter_queryset( self .request, queryset, self ) return queryset # 继承APIView---》写过滤---》可以复制GenericAPIView一些方法和属性,让我们少写代码 |
七、
查询list方法
1 2 3 4 5 6 7 8 9 10 11 12 13 | class BookView(ViewSetMixin, APIView): def list ( self , request, * args, * * kwargs): # 按照名字过滤 name = request.query_params.get( 'name' ) # 前端 book_list = Book.objects. all (). filter (name__contains = name, price__gt = 150 ) ser = BookSerializer(instance = book_list, many = True ) return Response(ser.data) ### 序列化类 class BookSerializer(serializers.ModelSerializer): class Meta: model = Book fields = '__all__' |
效果:127.0.0.1:8000/books/?name=国&prince>150
八、
1 2 3 4 5 6 7 8 | class BookView(GenericViewSet, ListModelMixin): queryset = Book.objects. all () serializer_class = BookSerializer # 过滤 127.0.0.1:8000/books/?search=123&name=平凡的世界2&price=123 filter_backends = [SearchFilter, MyFilter] # 两个过滤类 search_fields = [ 'name' , 'price' ] # 访问:http://127.0.0.1:8000/api/v1/books/?search=33&ordering=price |
九、五个图书接口设置权限
1、views
要求所有接口必须登录后才能访问,限制普通登录用户只能查看所有和新增一条,超级用户能查看一条,删除,修改
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | from .models import Book from .serializer import BookSerializer from rest_framework.generics import ListCreateAPIView, RetrieveUpdateDestroyAPIView from .permissions import UserPermission from .auth import LoginAuth class BookView(ViewSetMixin, ListCreateAPIView): authentication_classes = [LoginAuth] queryset = Book.objects. all () serializer_class = BookSerializer class BookDetailView(ViewSetMixin, RetrieveUpdateDestroyAPIView): queryset = Book.objects. all () serializer_class = BookSerializer permission_classes = [UserPermission] authentication_classes = [LoginAuth] |
2、url
1 2 3 4 5 6 7 8 9 10 11 12 | from django.urls import path, include from app01.views import UserView, BookView, BookDetailView from rest_framework.routers import SimpleRouter, DefaultRouter router = SimpleRouter() router.register( 'books' , BookView, 'books' ) # 查询所有的权限和新增 router.register( 'books1' , BookDetailView, 'books1' ) # 查询一条、删除、更新 urlpatterns = [ path( 'admin/' , admin.site.urls), path('', include(router.urls)), ] |