drf 认证、权限、频率三组件、排序、过滤
一、三大认证
要使用三大认证的功能,视图类至少要继承 APIview 视图类。因为APIview视图类去除了csrf认证、封装新的request、走三大认证和全局异常
1、支持序列化的视图类中,比如APIview,源码中view---> self.dispatch--->self.initial---> initial()
def initial(self, request, *args, **kwargs): self.format_kwarg = self.get_format_suffix(**kwargs) neg = self.perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme # Ensure that the incoming request is permitted self.perform_authentication(request) self.check_permissions(request) self.check_throttles(request)
二、认证组件
判断用户是否登录,数据库是否有值
1、需求:
通过认证组件去认证,没有认证通过的用户不让登录。认证方式前端发来的token值与数据库进行对比
2、models
from django.db import models class User(models.Model): username = models.CharField(max_length=32) password = models.CharField(max_length=32) # 用户类型 user_type = models.IntegerField(choices=((1, '2B用户'), (2, '普通用户'), (3, '超级用户'))) # 一对一的关系 class UserToken(models.Model): token = models.CharField(max_length=64) # OneToOneField本质就是ForeignKey+unique user = models.OneToOneField(to=User, on_delete=models.CASCADE) # user = models.ForeignKey(to=User, unique=True,on_delete=models.CASCADE)
2、写一个认证 auth.py
继承 BaseAuthentication
from .models import UserToken from rest_framework.authentication import BaseAuthentication from rest_framework.exceptions import AuthenticationFailed class LoginAuth(BaseAuthentication): # 重写一下authenticate方法 def authenticate(self, request): token = request.query_params.get('token') # 从url中的参数去拿 print(token) #4d1455f4-97be-4cf6-88a1-dee6e21433ad user_token = UserToken.objects.filter(token=token).first() if user_token: return user_token.user, token else: raise AuthenticationFailed('您没有登录!')
补充:
token从请求头中去拿
token = request.META.get('HTTP_TOKEN')
token从 请求体中取
token = request.data.get('token')
3、views
class UserView(ViewSet): # authentication_classes = [] # 局部解除认证禁用,token没有认证通过就不让登录 permission_classes = [] # 局部解除权限禁用 @action(methods=['POST'], detail=False) # /user/login/ post 请求就会执行 def login(self, request, *args, **kwargs): # 前端传入用户名密码 username = request.data.get('username') password = request.data.get('password') user = User.objects.filter(username=username, password=password).first() print(user) if user: # 生成一个随机字符串,返回给前端,并且要把随机字符串存到token表中 token = str(uuid.uuid4()) ##### 方式一:麻烦方式 # user_token=UserToken.objects.filter(user=user).first() # if user_token: # user_token.token=token # user_token.save() # else: # UserToken.objects.create(user=user,token=token) ## 方式二:通过user去UserToken表中查,如果能查到用defaults的更新,如果查不到,就用user和defaults新增一条记录 # 每次登录会返回一个uuid,会更新uuid UserToken.objects.update_or_create(defaults={'token': token}, user=user) return Response({'code': 100, 'msg': '登录成功', 'token': token, 'username': user.username}) else: return Response({'code': 101, 'msg': '用户名或密码错误'})
补充:
1、继承 ViewSetMixin,APIView,常见的5个路由自动匹配,其他的通过action装饰器进行指定
4、开启认证
在全局开启认证时需要指定app01下面的auth认证模块
## 全局开启 REST_FRAMEWORK = { 'DEFAULT_AUTHENTICATION_CLASSES': [ 'app01.auth.LoginAuth' ], } ##局部开启 authentication_classes = [authenticate] ##局部禁用 authentication_classes = []
5、拿着数据库的token访问
127.0.0.1:8000/user/login/?token=730c7ebf-4b12-4f55-a337-0b341d333395
三、权限组件
判断登录成功后的用户是否有操作权限
1、permissions 文件
继承: BasePermission
from rest_framework.permissions import BasePermission # 1 写一个类,继承 BasePermission # 2 重写 has_permission # 3 在方法中校验用户是否有权限,如果有,就返回True,如果没有,就返回False class UserPermission(BasePermission): def has_permission(self, request, view): # request 当次请求的request, 新的,它是在认证类之后执行的,如果认证通过了request.user 就是当前登录用户 # 拿到当前登录用户,查看它的类型,确定有没有权限 if request.user.user_type == 3: return True else: self.message = '您的用户类型是:%s,您没有权限操作' % (request.user.get_user_type_display()) return False
2、views
from rest_framework.viewsets import ViewSetMixin, ViewSet from rest_framework.response import Response from rest_framework.decorators import action from .models import User, UserToken import uuid # class UserView(ViewSetMixin,APIView): class UserView(ViewSet): # authentication_classes = [] # 局部解除认证禁用,token没有认证通过就不让登录 # permission_classes = [] # 局部解除权限禁用 @action(methods=['POST'], detail=False) # /user/login/ post 请求就会执行 def login(self, request, *args, **kwargs): # 前端传入用户名密码 username = request.data.get('username') password = request.data.get('password') user = User.objects.filter(username=username, password=password).first() print(username) if user: # 生成一个随机字符串,返回给前端,并且要把随机字符串存到token表中 token = str(uuid.uuid4()) ##### 方式一:麻烦方式 # user_token=UserToken.objects.filter(user=user).first() # if user_token: # user_token.token=token # user_token.save() # else: # UserToken.objects.create(user=user,token=token) ## 方式二:通过user去UserToken表中查,如果能查到用defaults的更新,如果查不到,就用user和defaults新增一条记录 # 每次登录会返回一个uuid,会更新uuid UserToken.objects.update_or_create(defaults={'token': token}, user=user) return Response({'code': 100, 'msg': '登录成功', 'token': token, 'username': user.username}) else: return Response({'code': 101, 'msg': '用户名或密码错误'})
3、开启权限认证
## 全局认证 REST_FRAMEWORK = { 'DEFAULT_PERMISSION_CLASSES': [ 'app01.permissions.UserPermission', ], } ## 局部认证 permission_classes = [UserPermission] ##局部解除权限认证 permission_classes = []
4、访问效果
四、频率组件
限制用户ip频繁访问
1、throtting 文件
继承:SimpleRateThrottle
from rest_framework.throttling import SimpleRateThrottle class IPRateThrottle(SimpleRateThrottle): scope = 'auth1' # scope: 范围 def get_cache_key(self, request, view): print(request.META) return request.META.get('REMOTE_ADDR') # REMOTE_ADDR 远程的ip地址
补充:
1 重写SimpleRateThrottle中的get_cache_key方法,这个方法返回什么就会用什么去做限制
2 以用户id做限制
def get_cache_key(self, request, view): # 重写get_cache_key,返回什么,就以什么做限制: IP地址,用户id限制 # print(request.META) # request.META 是一个包含有关HTTP请求的元数据的字典 # return request.META.get('REMOTE_ADDR') # REMOTE_ADDR 远程的ip地址 try: if request.user: user_id = request.user.pk print('用户id是:', user_id) return user_id except Exception: raise AuthenticationFailed('您没有登录,请先登录吧')
3 scope 是一个类属性,它的值用在全局配置中
'DEFAULT_THROTTLE_RATES': { 'auth1': '3/m', # 一分钟访问三次 },
2、views
class UserView(ViewSet): authentication_classes = [] # 局部解除认证禁用,token没有认证通过就不让登录 permission_classes = [] # 局部解除权限禁用 @action(methods=['POST'], detail=False) # /user/login/ post 请求就会执行 def login(self, request, *args, **kwargs): # 前端传入用户名密码 username = request.data.get('username') password = request.data.get('password') user = User.objects.filter(username=username, password=password).first() print(username) if user: # 生成一个随机字符串,返回给前端,并且要把随机字符串存到token表中 token = str(uuid.uuid4()) ##### 方式一:麻烦方式 # user_token=UserToken.objects.filter(user=user).first() # if user_token: # user_token.token=token # user_token.save() # else: # UserToken.objects.create(user=user,token=token) ## 方式二:通过user去UserToken表中查,如果能查到用defaults的更新,如果查不到,就用user和defaults新增一条记录 # 每次登录会返回一个uuid,会更新uuid UserToken.objects.update_or_create(defaults={'token': token}, user=user) return Response({'code': 100, 'msg': '登录成功', 'token': token, 'username': user.username}) else: return Response({'code': 101, 'msg': '用户名或密码错误'})
3、开启频率认证
## 全局认证 REST_FRAMEWORK = { 'DEFAULT_THROTTLE_RATES': { 'auth1': '3/m', # 一分钟访问三次 }, 'DEFAULT_THROTTLE_CLASSES': ['app01.throtting.IPRateThrottle'], } ## 局部认证 throttle_classes = [IPRateThrottle] ##局部解除权限认证 throttle_classes = []
4、效果
补充:自定义一个频率类
# 自定义的逻辑 #(1)取出访问者ip #(2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走 #(3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间, #(4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过 #(5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败 class MyThrottles(): VISIT_RECORD = {} def __init__(self): self.history=None def allow_request(self,request, view): #(1)取出访问者ip # print(request.META) ip=request.META.get('REMOTE_ADDR') import time ctime=time.time() # (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问 if ip not in self.VISIT_RECORD: self.VISIT_RECORD[ip]=[ctime,] return True self.history=self.VISIT_RECORD.get(ip) # (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间, while self.history and ctime-self.history[-1]>60: self.history.pop() # (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过 # (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败 if len(self.history)<3: self.history.insert(0,ctime) return True else: return False def wait(self): import time ctime=time.time() return 60-(ctime-self.history[-1])
全局使用和局部使用
REST_FRAMEWORK = { 'DEFAULT_THROTTLE_CLASSES':['app01.utils.MyThrottles',], } #在视图类里使用 throttle_classes = [MyThrottles,]
补充:频率的源码分析
# 1 频率源码 -APIView----disaptch---》self.initial(request, *args, **kwargs)---》416行:self.check_throttles(request)----》352行 check_throttles def check_throttles(self, request): # self.get_throttles()就是咱们配置在视图类上频率类的对象列表[频率类对象,] for throttle in self.get_throttles(): # 执行频率类对象的allow_request,传了2个,返回True或False if not throttle.allow_request(request, self): # 反会给前端失败,显示还剩多长时间能再访问 throttle_durations.append(throttle.wait()) # 2 频率类要写 1 写一个类,继承,BaseThrottle 2 在类中重写:allow_request方法,传入 3个参数 3 在allow_request写限制逻辑,如果还能访问--》返回True 4 如果超了次数,就不能访问,返回False 5 局部配置在视图类上 6 全局配置在配置文件中 # 3 我们在drf中写的时候,不需要继承 BaseThrottle,继承了SimpleRateThrottle,重写get_cache_key -我们猜测:一定是 SimpleRateThrottle帮咱们写了咱们需要写的 # 4 自定义频率类,实现一分钟只能访问三次的控制: # (1)取出访问者ip # (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走 # (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间, # (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过 # (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败 #5 SimpleRateThrottle 源码分析 - SimpleRateThrottle内部一定有:allow_request---》 def allow_request(self, request, view): # 咱们没写,以后咱们可以在频率类中直接写 # rate='3/m' 以后不用写scope了,就会按一分钟访问3次现在 if self.rate is None: return True # 这里的self.rate是在对象初始化时候赋值的----> self.get_rate(): if not getattr(self, 'scope', None): msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) raise ImproperlyConfigured(msg) try: return self.THROTTLE_RATES[self.scope] # 这里去配置文件中拿频率的配置,如3/m except KeyError: msg = "No default throttle rate set for '%s' scope" % self.scope raise ImproperlyConfigured(msg) # 初始化代码: def __init__(self): if not getattr(self, 'rate', None): self.rate = self.get_rate() # 拿到3/m self.num_requests, self.duration = self.parse_rate(self.rate) # self.parse_rate解压赋值 ###self.parse_rate(self.rate): if rate is None: return (None, None) num, period = rate.split('/') # 以/切割,num=3,period=m num_requests = int(num) duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] # 字典按照k取值d[m]--->60 return (num_requests, duration) # 取出:重写的get_cache_key返回的值,咱们返回了访问者ip self.key = self.get_cache_key(request, view) if self.key is None: # 重写get_cache_key方法,ruturn ip,这里key是ip return True # 根据当前访问者ip,取出 这个人的访问时间列表 [访问时间1,访问2,访问3,访问4] self.history = self.cache.get(self.key, []) # 取出当前时间 self.now = self.timer() # 把访问时间列表中超过 限制时间外的时间剔除 while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() # 判断访问时间列表是否大于 3 if len(self.history) >= self.num_requests: return self.throttle_failure() return self.throttle_success()
五、 排序使用
1、导入模块 OrderingFilter
from rest_framework.filters import OrderingFilter
导入模块之后,写一个类属性,指定按哪个字段排序
filter_backends = [OrderingFilter] ordering_fields = ['id', 'user_type']
2、views
这个案例中,能排序的前提是视图类至少继承 GenericViewSet + ListModelMixin, 和数据库打交道,需要序列化数据返回
解释:GenericViewSet 视图类继承了ViewSetMixin(重写了as_view, 改变了路由的新写法)+ GenericAPIView。ListModelMixin 重写了list(查询所有)方法
class UserView(GenericViewSet, ListModelMixin): throttle_classes = [] authentication_classes = [] permission_classes = [] queryset = User.objects.all() serializer_class = UserSerializer ## 排序功能 filter_backends = [OrderingFilter] ordering_fields = ['id', 'user_type'] @action(methods=['POST'], detail=False) # /user/login/ post 请求就会执行 def login(self, request, *args, **kwargs): # 前端传入用户名密码 username = request.data.get('username') password = request.data.get('password') user = User.objects.filter(username=username, password=password).first() if user: # 生成一个随机字符串,返回给前端,并且要把随机字符串存到token表中 # 随机字符串使用uuid生成 token = str(uuid.uuid4()) # 把随机字符串存到token表中会有两种情况(如果之前没有登录过就是新增,如果之前登录过修改) # 先去UserToken表中,根据user查,如果能查到,就修改,查不到就新增一条记录 ## 方式二:通过user去UserToken表中查,如果能查到用defaults的更新,如果查不到,就用user和defaults新增一条记录 UserToken.objects.update_or_create(defaults={'token': token}, user=user) return Response({'code': 100, 'msg': '登录成功', 'token': token, 'username': user.username}) else: return Response({'code': 101, 'msg': '用户名或密码错误'})
3、访问方式
get方法查询所有用户,在url地址中params传参数
先按照用户类型升序排,再按照id降序排
127.0.0.1:8000/user/?ordering=user_type,id
六、过滤的三种方式
1、方式一:内置模块
# 查询方式http://127.0.0.1:8000/books/?search=29 # 模糊匹配: 只要名字中有29或价格中有29都能搜出来 from rest_framework.filters import SearchFilter filter_backends = [SearchFilter, OrderingFilter] search_fields = ['name', 'price']
2、方式二 : 第三方模块
# http://127.0.0.1:8000/books/?name=红楼梦 # http://127.0.0.1:8000/books/?price=19&name=西游记 # 安装: pip install django-filter # 导入模块 from django_filters.rest_framework import DjangoFilterBackend filter_backends = [DjangoFilterBackend] filterset_fields=['name','price']
3、方式三:自定义一个类
自定义filter类:
from rest_framework.filters import BaseFilterBackend from django.db.models import Q class MyFilter(BaseFilterBackend): def filter_queryset(self, request, queryset, view): # 基于queryset 进行过滤,过滤后返回即可 # http://127.0.0.1:8000/books/?name=书 # 名字中有书的就查出来 search_param = request.query_params.get('name') price = request.query_params.get('price') if search_param and price: queryset = queryset.filter(Q(name__contains=search_param) | Q(price=price)) # qs对象的filter # queryset = queryset.filter(name__contains=search_param, price=price) # qs对象的filter # name__contains 基于双下滑线的模糊查询 return queryset
补充:
1、注意这里的Q查询,Q查询可以将多个查询条件组合成更复杂的逻辑表达式,包括 AND、OR、NOT 等。这里用的是OR关系
2、name__contains=search_param, price=price 基于双下线的模糊查询,AND 关系
使用
from .filters import MyFilter filter_backends = [MyFilter]
过滤源码:为什么在视图类中配置一个过滤类,就能走?
-filter_backends = [SearchFilter,MyFilter] -GenericAPIView:继承APIVIew的视图类,是不能这样配置的----》自己过滤 -filter_backends = api_settings.DEFAULT_FILTER_BACKENDS -还需要继承:ListModelMixin---》它的list方法中,在对所有数据做过滤:queryset = self.filter_queryset(self.get_queryset()) -self.filter_queryset如何做的过滤呢?ListModelMixin类中没有这个方法,最终从GenericAPIView中找到了 -GenericAPIView的filter_queryset干了啥事? def filter_queryset(self, queryset): # 过滤类 for backend in list(self.filter_backends): # filter_backends视图类中配置的过滤类,列表 # 过滤类加括号---》 过滤类的对象---》调用过滤类对象的filter_queryset queryset = backend().filter_queryset(self.request, queryset, self) return queryset # 继承APIView---》写过滤---》可以复制GenericAPIView一些方法和属性,让我们少写代码
七、
查询list方法
class BookView(ViewSetMixin, APIView): def list(self, request, *args, **kwargs): # 按照名字过滤 name = request.query_params.get('name') # 前端 book_list = Book.objects.all().filter(name__contains=name, price__gt=150) ser = BookSerializer(instance=book_list, many=True) return Response(ser.data) ### 序列化类 class BookSerializer(serializers.ModelSerializer): class Meta: model = Book fields = '__all__'
效果:127.0.0.1:8000/books/?name=国&prince>150
八、
class BookView(GenericViewSet, ListModelMixin): queryset = Book.objects.all() serializer_class = BookSerializer # 过滤 127.0.0.1:8000/books/?search=123&name=平凡的世界2&price=123 filter_backends = [SearchFilter, MyFilter] # 两个过滤类 search_fields = ['name', 'price'] # 访问:http://127.0.0.1:8000/api/v1/books/?search=33&ordering=price
九、五个图书接口设置权限
1、views
要求所有接口必须登录后才能访问,限制普通登录用户只能查看所有和新增一条,超级用户能查看一条,删除,修改
from .models import Book from .serializer import BookSerializer from rest_framework.generics import ListCreateAPIView, RetrieveUpdateDestroyAPIView from .permissions import UserPermission from .auth import LoginAuth class BookView(ViewSetMixin, ListCreateAPIView): authentication_classes = [LoginAuth] queryset = Book.objects.all() serializer_class = BookSerializer class BookDetailView(ViewSetMixin, RetrieveUpdateDestroyAPIView): queryset = Book.objects.all() serializer_class = BookSerializer permission_classes = [UserPermission] authentication_classes = [LoginAuth]
2、url
from django.urls import path, include from app01.views import UserView, BookView, BookDetailView from rest_framework.routers import SimpleRouter, DefaultRouter router = SimpleRouter() router.register('books', BookView, 'books') # 查询所有的权限和新增 router.register('books1', BookDetailView, 'books1') # 查询一条、删除、更新 urlpatterns = [ path('admin/', admin.site.urls), path('', include(router.urls)), ]