Django REST framework 自定义(认证、权限、访问频率)组件
本篇随笔在 "Django REST framework 初识" 基础上扩展
一、认证组件
# models.py
class Account(models.Model):
    """User account table."""

    username = models.CharField(verbose_name="用户名", max_length=64, unique=True)
    # NOTE(review): the login view looks accounts up with
    # get(username=..., password=...), so this column is compared against the
    # submitted password as-is. Consider storing a password hash instead.
    password = models.CharField(verbose_name="密码", max_length=64)


class UserToken(models.Model):
    """Token table: at most one token row per account."""

    # on_delete is a required argument from Django 2.0 onwards; CASCADE is what
    # older versions applied implicitly, so the token is removed with its account.
    user = models.OneToOneField(to="Account", on_delete=models.CASCADE)
    token = models.CharField(max_length=64, unique=True)
当然也可以使用 Django 自带的 auth_user 表来保存用户信息:Token 表可以一对一关联这张表,或者直接继承 AbstractUser 来扩展它(两种写法二选一):
# Option 1: keep the token in its own table, linked one-to-one to Django's
# built-in user model.
from django.contrib.auth.models import User


class Token(models.Model):
    # on_delete is a required argument from Django 2.0 onwards; CASCADE is what
    # older versions applied implicitly.
    user = models.OneToOneField(User, on_delete=models.CASCADE)
    token = models.CharField(max_length=64)


# Option 2 (alternative to option 1, not in addition to it): extend the user
# model itself with a token column.
# NOTE(review): a model derived from AbstractUser presumably also has to be
# registered through the AUTH_USER_MODEL setting - confirm before using.
from django.contrib.auth.models import AbstractUser


class Token(AbstractUser):
    token = models.CharField(max_length=64)
auth.py
from rest_framework import authentication
from rest_framework import exceptions

from api import models


class UserTokenAuth(authentication.BaseAuthentication):
    """Authenticate a request by the ``token`` query-string parameter."""

    def authenticate(self, request):
        """Look the supplied token up in the UserToken table.

        Returns a ``(username, token_row)`` pair when a matching row exists
        and raises ``AuthenticationFailed`` otherwise.
        """
        supplied = request.query_params.get("token")
        record = models.UserToken.objects.filter(token=supplied).first()
        if record:
            # First element is the username string, second the UserToken row.
            return (record.user.username, record)
        raise exceptions.AuthenticationFailed({"code": 200, "error": "用户身份认证失败!"})
views.py
import time
import hashlib

from rest_framework import viewsets
from rest_framework.views import APIView
from rest_framework.response import Response
from django.core.exceptions import ObjectDoesNotExist

from api import models
from appxx import serializers
from appxx.auth.auth import UserTokenAuth


class LoginView(APIView):
    """Login endpoint: checks the credentials and issues a token."""

    def post(self, request, *args, **kwargs):
        """Validate ``username``/``password`` from the request body.

        Always answers with a dict carrying ``code``: 1000 plus ``token`` on
        success, 1001 for wrong credentials, 1002 for any other failure.
        """
        rep = {"code": 1000}
        username = request.data.get("username")
        password = request.data.get("password")
        try:
            user = models.Account.objects.get(username=username, password=password)
            token = self.get_token(user.password)
            models.UserToken.objects.update_or_create(user=user, defaults={"token": token})
            # Only hand the token out once it has actually been stored, so a
            # failed save never returns a token the server does not know.
            rep["token"] = token
        except ObjectDoesNotExist:
            rep["code"] = 1001
            rep["error"] = "用户名或密码错误"
        except Exception:
            # Deliberate catch-all at the view boundary: report a generic error.
            rep["code"] = 1002
            rep["error"] = "发生错误,请重试"
        return Response(rep)

    @staticmethod
    def get_token(password):
        """Return the hex MD5 digest of ``password`` followed by the current time.

        NOTE(review): both inputs are guessable, so the token is not a strong
        secret; ``secrets.token_hex`` would be a safer source. Flagged only.
        """
        timestamp = str(time.time())
        md5 = hashlib.md5(password.encode("utf-8"))
        md5.update(timestamp.encode("utf-8"))
        return md5.hexdigest()


class BookViewSet(viewsets.ModelViewSet):
    # Per-view authentication: use the class imported above. The previous
    # `utils.AuthToken` referred to a module that is never imported here.
    authentication_classes = [UserTokenAuth]
    queryset = models.Book.objects.all()
    serializer_class = serializers.BookSerializer
urls.py
from django.conf.urls import url, include
from rest_framework.routers import DefaultRouter

from appxx import views

# The router generates the routes for both viewsets.
router = DefaultRouter()
router.register(r"books", views.BookViewSet)
router.register(r"publishers", views.PublisherViewSet)

# Login route first, then everything the router produced.
urlpatterns = [url(r"^login/$", views.LoginView.as_view(), name="login")]
urlpatterns.append(url(r"", include(router.urls)))
局部认证(哪个视图类需要认证就在哪加上)
如果需要每条URL都加上身份认证,那么是不是views.py中每个对应的类视图都加上authentication_classes呢?那多麻烦,有没有更简便的方法?请看下面如何设置全局的认证。
全局认证
在settings.py中设置:
REST_FRAMEWORK = {
    # Dotted path of the authentication class applied to every view by default.
    # It points at the UserTokenAuth class from auth.py, using the same module
    # path that views.py imports it from.
    "DEFAULT_AUTHENTICATION_CLASSES": ["appxx.auth.auth.UserTokenAuth"],
    # Optional overrides for anonymous requests (left disabled):
    # "UNAUTHENTICATED_USER": None,   # anonymous: request.user = None
    # "UNAUTHENTICATED_TOKEN": None,  # anonymous: request.auth = None
}
可以看到,这里配置的就是前面 auth.py 中定义的认证类(UserTokenAuth),也就是 BookViewSet 的 authentication_classes 所使用的那个类。这样 views.py 中的每个类视图都不需要再单独加 authentication_classes 了,每条URL都必须经过此认证才能访问。
class BookViewSet(viewsets.ModelViewSet):
    # No authentication_classes declared on the view itself.
    serializer_class = serializers.BookSerializer
    queryset = models.Book.objects.all()


class PublisherViewSet(viewsets.ModelViewSet):
    # No authentication_classes declared on the view itself.
    serializer_class = serializers.PublisherSerializer
    queryset = models.Publisher.objects.all()
二、权限组件
修改模型表,给用户加上用户类型字段:
class UserProfile(models.Model):
    """User table that carries a user-type column."""

    username = models.CharField(verbose_name="用户名", max_length=16)
    password = models.CharField(verbose_name="密码", max_length=64)
    # (stored value, display label) pairs offered for user_type:
    # 1 = administrator, 2 = regular user, 3 = VIP.
    user_type_choices = ((1, "管理员"), (2, "普通用户"), (3, "VIP"))
    # New rows default to 2 (regular user).
    user_type = models.SmallIntegerField(choices=user_type_choices, default=2)
class UserTypePermission(permissions.BasePermission):
    """Permission check that only lets administrators (user_type 1) through."""

    message = "只有管理员才能访问"

    def has_permission(self, request, view):
        """Return True only when ``request.user`` names a user of type 1.

        ``request.user`` is matched against ``UserProfile.username``; when no
        such profile exists the request is rejected.
        """
        profile = models.UserProfile.objects.filter(username=request.user).first()
        # Explicit None check instead of catching AttributeError, so unrelated
        # attribute errors inside the lookup are no longer silently swallowed.
        if profile is None:
            return False
        return profile.user_type == 1
局部权限
class BookViewSet(viewsets.ModelViewSet):
    # Data source and serializer for the book endpoints.
    serializer_class = serializers.BookSerializer
    queryset = models.Book.objects.all()
    # Per-view permission: declared on this view only.
    permission_classes = [utils.UserTypePermission]
全局权限
# Global permission setting: dotted path of the default permission class.
REST_FRAMEWORK = dict(
    DEFAULT_PERMISSION_CLASSES=["appxx.utils.UserTypePermission"],
)
三、访问频率组件
import time

# Maps client IP -> list of request timestamps, newest first.
# In-process only; could be kept in Redis instead.
visit_record = {}


class IpRateThrottle(object):
    """Allow at most ``limit`` requests per ``window`` seconds for each IP.

    The defaults reproduce the original rule: 3 requests within 60 seconds.
    Subclasses may override ``limit`` and ``window``.
    """

    limit = 3    # maximum number of requests inside one window
    window = 60  # window length in seconds

    def __init__(self):
        # Timestamp list of the IP seen by the latest allow_request() call.
        self.history = None

    def allow_request(self, request, view):
        """Return True if the request may proceed, False if it is throttled."""
        ip = request.META.get("REMOTE_ADDR")  # client IP as seen by the server
        current_time = time.time()

        history = visit_record.setdefault(ip, [])
        self.history = history

        # Oldest entries sit at the end of the list; drop expired ones.
        cutoff = current_time - self.window
        while history and history[-1] < cutoff:
            history.pop()

        if len(history) < self.limit:
            history.insert(0, current_time)
            return True
        # Explicit False rather than falling off the end and returning None.
        return False

    def wait(self):
        """Return the seconds until the oldest recorded request expires.

        Returns None when there is no recorded request to wait for.
        """
        if not self.history:
            return None
        current_time = time.time()
        return self.window - (current_time - self.history[-1])
局部节流
class BookViewSet(viewsets.ModelViewSet):
    # Data source and serializer for the book endpoints.
    serializer_class = serializers.BookSerializer
    queryset = models.Book.objects.all()
    # Per-view throttling: declared on this view only.
    throttle_classes = [IpRateThrottle]
全局节流
# Global throttle setting: dotted path of the default throttle class.
REST_FRAMEWORK = dict(
    DEFAULT_THROTTLE_CLASSES=["appxx.utils.IpRateThrottle"],
)
PS:
匿名用户:无法控制,因为用户可以换代理IP
登录用户:如果有很多账号,也无法限制