DRF之三大认证
【一】三大认证执行顺序


【二】认证
| |
| from rest_framework.authentication import BaseAuthentication |
【1】源码


【2】认证类的使用
| |
| |
| from rest_framework.authentication import BaseAuthentication |
| |
| from rest_framework.exceptions import AuthenticationFailed |
| |
| |
| class UserAuthenticate(BaseAuthentication): |
| |
| def authenticate(self, request): |
| ''' |
| 进行登录的判断,比如是否携带了token,或是否携带了证明身份信息的东西 |
| ''' |
| |
| if 'token校验失败': |
| |
| raise AuthenticationFailed('请检查token') |
| |
| return user, token |
- 局部使用和全局使用
- 可以通过
authentication_classes = []
实现局部禁用
- 查找验证类的顺序为,先查找类属性中的,再查找项目配置中的,最后去drf默认配置中查找
| |
| |
| class 视图类(ViewSet): |
| |
| authentication_classes = ['认证类'] |
| |
| |
| |
| |
| REST_FRAMEWORK = { |
| 'DEFAULT_AUTHENTICATION_CLASSES': [ |
| |
| |
| ], |
| } |
【3】实例
| |
| |
| |
| class UserViewV2(ViewSetMixin, ListCreateAPIView): |
| queryset = UserInfo.objects.all() |
| authentication_classes = [] |
| |
| @action(methods=['POST'], detail=False) |
| def login(self, request): |
| username = request.data.get('username') |
| password = request.data.get('password') |
| |
| user_obj = auth.authenticate(username=username, password=password) |
| if not user_obj: |
| return Response({'code': 101, 'msg': '登录失败!用户名或密码错误'}) |
| user_token = uuid.uuid4() |
| |
| UserToken.objects.update_or_create(defaults={'token': user_token}, user=user_obj) |
| return Response({'code': 100, 'msg': '登录成功!', 'token': user_token}) |
| |
| |
| |
| class UserEditViewV2(ViewSetMixin, RetrieveUpdateDestroyAPIView): |
| queryset = UserInfo.objects.all() |
| serializer_class = UpdatePasswordSerializerV2 |
| |
| authentication_classes = [UserAuthenticate] |
| |
| @action(methods=['PUT'], detail=False) |
| def password(self, request, *args, **kwargs): |
| ser = self.get_serializer(instance=request.user, data=request.data) |
| ser.is_valid(raise_exception=True) |
| ser.save() |
| return Response({'code': 100, 'msg': '修改成功'}) |
| |
| from rest_framework.authentication import BaseAuthentication |
| from rest_framework.exceptions import AuthenticationFailed |
| from .models import UserToken |
| |
| |
| |
| class UserAuthenticate(BaseAuthentication): |
| |
| def authenticate(self, request): |
| token = request.META.get('HTTP_TOKEN') |
| user_token_obj = UserToken.objects.filter(token=token).first() |
| if not user_token_obj: |
| raise AuthenticationFailed('请检查token') |
| user = user_token_obj.user |
| return user, token |
【三】权限
| |
| from rest_framework.permissions import BasePermission |
【1】源码

【2】权限类的使用
| |
| class CommonPermission(BasePermission): |
| def has_permission(self, request, view): |
| ''' |
| :param request: 当前请求request对象 |
| :param view: 视图类 |
| :return: 布尔值 |
| ''' |
| |
| |
| |
| self.message = '可以指定提示信息' |
| self.code = '可以指定返回的响应码' |
| return True |
| return False |
- 局部使用和全局使用
- 可以通过
permission_classes= []
实现局部禁用
| |
| |
| class 视图类(ViewSet): |
| authentication_classes = ['认证类'] |
| permission_classes = ['权限类'] |
| |
| |
| |
| |
| REST_FRAMEWORK = { |
| 'DEFAULT_AUTHENTICATION_CLASSES': [ |
| |
| 'app001.authenticate.UserAuthenticate' |
| ], |
| 'DEFAULT_PERMISSION_CLASSES': [ |
| |
| 'app001.permissions.CommonPermission' |
| ], |
| } |
【3】实例
| |
| |
| class CommonPermission(BasePermission): |
| def has_permission(self, request, view): |
| user = request.user |
| |
| if user.is_superuser: |
| return True |
| else: |
| |
| if view.basename == 'car_model' and request.method == 'DELETE': |
| return True |
| elif request.method != 'GET': |
| return False |
| else: |
| return True |
【四】频率
| |
| from rest_framework.throttling import BaseThrottle |
| |
| from rest_framework.throttling import SimpleRateThrottle |
【1】源码

【2】频率类的使用(SimpleRateThrottle)
| class CommonThrottle(SimpleRateThrottle): |
| |
| rate = '5/m' |
| |
| |
| def get_cache_key(self, request, view): |
| return '返回唯一用户访问的唯一标识 如ip,设备id号' |
【2.1】SimpleRateThrottle
实例
| |
| class CommonThrottle(SimpleRateThrottle): |
| rate = '5/d' |
| |
| def get_cache_key(self, request, view): |
| return request.META.get("REMOTE_ADDR") |
【2.2】SimpleRateThrottle
源码分析

【3】继承BaseThrottle
自定义频率类
- 重写频率类的,最关键方法就是
allow_request
方法,在其中构建限制逻辑即可
- 直接上实例,该实例大部分参考
SimpleRateThrottle
不必为仿照而瞧不起自己,读得懂源码,理清楚其中的逻辑也能帮我们很多
| |
| from rest_framework.throttling import BaseThrottle |
| |
| |
| class ExtendsThrottle(BaseThrottle): |
| |
| rate = '3/m' |
| history = {} |
| |
| def __init__(self): |
| self.count, self.duration = self.parse_rate(self.rate) |
| |
| def parse_rate(self, rate: str): |
| ''' |
| 将【'3/s'】解析成 次数和持续时间 |
| :param rate: '3/m' 每分钟限制访问3此 |
| :return: (次数,持续时间) |
| ''' |
| if not rate: |
| return None, None |
| time_dict = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400} |
| count, duration = rate.split('/') |
| return int(count), time_dict[duration] |
| |
| def get_ident(self, request): |
| |
| |
| return super().get_ident(request) |
| |
| def allow_request(self, request, view): |
| ''' |
| 允许访问的主要逻辑代码 |
| :return: 执行允许通过或不允许通过的方法 |
| ''' |
| if not self.rate: |
| |
| return True |
| ip = self.get_ident(request) |
| if ip not in self.history: |
| |
| self.history[ip] = [] |
| ip_history = self.history.get(ip) |
| |
| now = time.time() |
| |
| while len(ip_history) == self.count: |
| |
| |
| if now - ip_history[0] < self.duration: |
| return False |
| else: |
| |
| |
| ip_history.pop(0) |
| break |
| |
| ip_history.append(now) |
| self.history[ip] = ip_history |
| |
| return True |
| |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· DeepSeek 开源周回顾「GitHub 热点速览」
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了