Django Restframework

Content

Intro

 

 

Ⅰ Intro

djangorestframework模块为django提供了restful接口

安装 : pip install djangorestframework
    django settings.py   INSTALLED_APPS 中加入 "rest_framework"

 

 

Ⅱ 序列化

序列化是该模块的最基本功能,能够将表数据,转换为json字典以及json字符串

形式1

复制代码
from rest_framework import serializers 

class UserSerializer(serializers.Serializer):
    ut_title = serializers.CharField(source='ut.title')
    user = serializers.CharField(min_length=6)
    pwd = serializers.CharField(error_messages={'required': '密码不能为空'}, validators=[PasswordValidator()])
  
x1 = serializers.CharField(source='group.mu.name') # 多对一或一对一跨表查询
 
    def create(self, validated_data):
        """
        根据提供的验证过的数据创建并返回一个新的`Snippet`实例。用于保存
        """
        return Userinfo.objects.create(**validated_data)

    def update(self, instance, validated_data):
        """
        根据提供的验证过的数据更新和返回一个已经存在的`Snippet`实例。
        """
        instance.ut_title= validated_data.get('ut_title', instance.ut_title)
        instance.user= validated_data.get('user', instance.user)
        instance.pwd = validated_data.get('pwd', instance.pwd)
        instance.save()
        return instance
# 自定义序列化 # 这种形式类似于django的form组件

复制代码

形式2

复制代码
from rest_framework import serializers
class ModelUserSerializer(serializers.ModelSerializer):
    user = serializers.CharField(max_length=32)
    class Meta:
        model = models.UserInfo
        fields = "__all__"
        # fields = ['user', 'pwd', 'ut'] 和这样写相同
        depth = 2
        extra_kwargs = {'user': {'min_length': 6}, 'pwd': {'validators': [PasswordValidator(666), ]}}
        # read_only_fields = ['user']

# 这种形式更像是modelform组件
复制代码

 

复制代码
class NewhomeSerializer(serializers.ModelSerializer):
    published_time = serializers.SerializerMethodField()  # 自定义一个字段

    class Meta:
        model = models.Article
        fields = ['title', 'source', 'brief','published_time']  # 将自定义的字段放到fields中
        depth = 1  # 表示深度


    def get_published_time(self, obj):  # 定义方法,def get_自定义的字段名:
        time = (now() - obj.pub_date).days

        week, day = divmod(time, 7)
        hour, min = divmod(int((now() - obj.pub_date).seconds), 3600)
        # print(week)
        # print(day)
        if week > 1:
            return '%s周前' % week
        elif day > 1:
            return '%s天前' % day
        elif hour > 1:
            return '%s小时前' % hour
        elif min > 1:
            return '%s小时前' % min
        else:
            return '刚刚'
对多对多以及自定义字段的处理
复制代码

 

 

序列化用法:

views.py中

 保存

snippet = Snippet(ut_title="abc",user="abc",pwd= "abcd')
snippet.save()

读取:

serializer = SnippetSerializer(instance=对象,many=True) # many=True 决定instance对象是集合还是单独obj
serializer.data

转化为json字符串

from rest_framework.renderers import JSONRenderer

content = JSONRenderer().render(serializer.data) # json字符串

反序列化

from rest_framework.parsers import JSONParser
from django.utils.six import BytesIO

stream = BytesIO(content)
data = JSONParser().parse(stream)

 

多对多序列化处理有三种方式:

1. 从新定义CharField()

复制代码
class TempCharField(serializers.CharField):
    def to_representation(self, value): # 打印的是所有的数据
        data_list = []
        for row in value:
            data_list.append(row.name)
        return data_list

class UsersSerializer(serializers.Serializer):
    # a = serializers.CharField(source="roles.all") #  多对多关系的这样查出的是queryset对象
    a = TempCharField(source="roles.all") 
从新定义类
复制代码

2. 定义child

复制代码
class MyCharField(serializers.CharField):
    def to_representation(self, value):
        return {'id':value.pk, 'name':value.name}

class UsersSerializer(serializers.Serializer):
    # a = serializers.CharField(source="roles.all") # obj.mu.name
    a2 = serializers.ListField(child=MyCharField(),source="roles.all") 
定义child
复制代码

3. 定义方法(推荐)

复制代码
class UsersSerializer(serializers.Serializer):
    # a = serializers.ListField(child=MyCharField(),source="roles.all") # obj.mu.name
    a3 = serializers.SerializerMethodField()
    def get_a3(self,obj):  #get_字段名
        print(obj)   # object
        obj.roles.all()
        role_list = obj.roles.filter(id__gt=1)
        data_list = []
        for row in role_list:
            data_list.append({'pk':row.pk,'name':row.name})
        return data_list
定义方法
复制代码

 

反向生成url

复制代码
class UserSerializer(serializers.ModelSerializer):
    bbb = serializers.HyperlinkedIdentityField(view_name='detail')  #让bbb的结果为按照urls中name为detail的url反向生成url
    class Meta:
        model = models.BBB
        fields = ["b","bb","bbb"]
        depth = 1

class A(APIView):
    def get(self, request, *args, **kwargs):
        user = models.BBB.objects.all()
        print(UserSerializer(instance=user,many=True,context={'request':request}).data)
        return HttpResponse("111")



#url.py
url(r'^users5/(?P<pk>.*)', views.BBB.as_view(), name='detail')

# 结果:
[OrderedDict([('b', 'b1'), ('bb', 'b2'), ('bbb', 'http://127.0.0.1:8000/a/1')]), OrderedDict([('b', 'b2'), ('bb', 'b3'), ('bbb', 'http://127.0.0.1:8000/a/2')])]
反向生成单独的url
复制代码

 

复制代码
class UsersSerializer(serializers.HyperlinkedModelSerializer): #继承他自动生成
    class Meta:
        model = models.UserInfo
        fields = "__all__"
        exclude=["aaa"]  # 排除某个字段


class UsersView(APIView):
    def get(self,request,*args,**kwargs):
        # 方式一:
        # user_list = models.UserInfo.objects.all().values('name','pwd','group__id',"group__title")
        # return Response(user_list)

        # 方式二之多对象
        user_list = models.UserInfo.objects.all()
        ser = UsersSerializer(instance=user_list,many=True,context={'request':request})
        return Response(ser.data)


 url.py
     url(r'^aa/(?P<pk>\)$', aaa.MicroView.as_view(),name="ser_detail") # 默认name为字段名+detail
生成多个
复制代码

 

Ⅲ 验证

序列化模块类似于django的form以及modelform模块提供验证功能

复制代码
class PasswordValidator(object):
    def __init__(self, anything):
        self.anything = anything

    def __call__(self, value):
        if value != self.base:
            message = '用户输入的值必须是 %s.' % self.anything
            raise serializers.ValidationError(message)

    def set_context(self, serializer_field):
        """
        This hook is called by the serializer instance,
        prior to the validation call being made.
        """
        pass

class UsersSerializer(serializers.Serializer):
        name = serializers.CharField(min_length=6)
        pwd = serializers.CharField(error_messages={'required': '密码不能为空'}, validators=[PasswordValidator('666')])
自定义
复制代码
复制代码
class UsersSerializer(serializers.ModelSerializer):
    class Meta:
        model = models.UserInfo
        fields = "__all__"
        #自定义验证规则
        extra_kwargs = {
            'name': {'min_length': 6},
            'pwd': {'validators': [PasswordValidator("anything"), ]}
        }
基于modelserializer的
复制代码
复制代码
class UsersSerializer(serializers.ModelSerializer):
    class Meta:
        model = models.BBB
        fields = "__all__"
        extra_kwargs = {
            'name': {'min_length': 6},
            'pwd': {'validators': [PasswordValidator(666), ]}
        }
        def validate_字段(self,validated_value):   # 钩子函数进行验证
              if...:raise ValidationError(detail='xxxxxx')
              else:return validated_value
复制代码

 

# view.py 使用类似于form组件

        ser = UsersSerializer(data=request.data)
        if ser.is_valid():
            print(ser.validated_data)
        else:
            print(ser.errors)

 

Ⅳ 解析器

除了直接用url发get请求,我们还有post请求,还有url参数。restframework组件提供了更方便的功能处理各种请求数据

from rest_framework.parsers import FormParser,JSONParser,FileUploadParser,MultiPartParser 
以上四个模块分别用来处理【application/x-www-form-urlencoded】,【application/json】,【multipart/form-data】,【文件上传】,把相关内容赋值给request.data
class UsersView(APIView):
    parser_classes = [JSONParser,]  # 
    def get(request,*args,**kwargs):
        request.content_type # 数据类型
        request.data # 这里取解析器解析的数据
        request.query_string #这里取GET的数据
        request.POST #用对了Parser才有和在post请求的时候request.data一样

其他数据可以通过request._request拿到django的request对象

复制代码
REST_FRAMEWORK = {
    'DEFAULT_PARSER_CLASSES':[
        'rest_framework.parsers.JSONParser'
        'rest_framework.parsers.FormParser'
        'rest_framework.parsers.MultiPartParser'
    ]

}
全局配置settings.py
复制代码

 

 

Ⅴ 认证&权限

django restframework还在django auth模块的基础上集成了认证和权限模块,方便客户登陆,并对访问进行限制

其中认证模块主要的功能是确认客户端身份并,并将身份赋值给相关变量(BaseAuthentication)

权限模块是通过相关变量拿到客户端身份,并对该身份客户的请求进行限制()

5.1 认证

5.1.1 基本模式

复制代码
class TestAuthentication(BaseAuthentication):
    def authenticate(self, request):
        """
        用户认证,如果验证成功后返回元组: (用户,用户Token)
        :return: 
            None,表示跳过该验证;
                如果跳过了所有认证,默认用户和Token和使用配置文件进行设置
        """
        val = request.query_params.get('token')
        if val not in token_list:
            raise exceptions.AuthenticationFailed("用户认证失败")
        return ('登录用户', '用户token')    # 这里会讲第一个值赋值给request.user, 第二个值会赋值给request.auth
     def authenticate_header(self, request):
         # 验证失败时,返回的响应头WWW-Authenticate对应的值
        pass  # 如果return'Basic reala="api"' 会弹窗

复制代码
复制代码
#使用方式1:
class TestView(APIView):
    # 认证的动作是由request.user触发
    authentication_classes = [TestAuthentication, ]
    def get(self,request....): pass


#使用方式2:
class TestAuth():
    authentication_classes = [TestAuthentication, ]

class TestView(TestAuth,APIView):
    def get......
使用
复制代码
复制代码
# 全局配置: 
   "DEFAULT_AUTHENTICATION_CLASSES": [
        "web.utils.TestAuthentication",
    ],

不需要经过认证的视图类需要
class NoAuthView(APIView):
    authentication_classes = [ ] # 这里需要写一个空列表
全局配置
复制代码

 

5.2 权限

复制代码
class TestPermission(BasePermission):
    message ="fail"

    def has_permission(self, request, view):
        """
        Return `True` if permission is granted, `False` otherwise.
        """
        if request.user == "管理员":
            return True

    # GenericAPIView中get_object时调用
    def has_object_permission(self, request, view, obj):
        """
        视图继承GenericAPIView,并在其中使用get_object时获取对象时,触发单独对象权限验证
        if request.user == "管理员":
            return True

class TestView(APIView):
authentication_classes = [TestAuthentication, ]
def get(self,request):pass
复制代码
    "DEFAULT_PERMISSION_CLASSES": [
        "web.utils.TestPermission",
    ],
全局配置

 

复制代码
# APIView中有check_permission方法   
 def check_permissions(self, request):
        """
        Check if the request should be permitted.
        Raises an appropriate exception if the request is not permitted.
        """


# 还有permission_denied方法用来处理无权访问的情况,可以在视图中重写,达到友好的效果
def permission_denied(self, request, message=None):
"""
If request is not permitted, determine what kind of exception to raise.
"""
    if request.authenticators and not request.successful_authenticator:
        raise exceptions.NotAuthenticated()
    raise exceptions.PermissionDenied(detail=message)
more!
复制代码

 

 

Ⅵ 版本控制

6.1 URL方法

REST_FRAMEWORK = {
    'DEFAULT_VERSION': 'v1',            # 默认版本
    'ALLOWED_VERSIONS': ['v1', 'v2'],   # 允许的版本
    'VERSION_PARAM': 'version'          # URL中获取值的key
# 'DEFAULT_VERSIONING_CLASS':"rest_framework.versioning.URLPathVersioning" 有这个配置就不需要在类中写verioning_class了
}

 

复制代码
from django.conf.urls import url, include
from web.views import TestView

urlpatterns = [
    url(r'^(?P<version>[v1|v2]+)/test/', TestView.as_view(), name='test'),
]
url.py
复制代码
复制代码
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning

class TestView(APIView):
    versioning_class = URLPathVersioning
    def get(self, request, *args, **kwargs):
        # 获取版本
        request.version
        # 获取版本管理的类
        request.versioning_scheme
        # 反向生成URL
        reverse_url = request.versioning_scheme.reverse('test', request=request)
        return Response('GET请求')
复制代码

 

6.2 URL参数方法

复制代码
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.versioning import QueryParameterVersioning

class TestView(APIView):
    versioning_class = QueryParameterVersioning
    def get(self, request, *args, **kwargs):
        # 获取版本
        request.version
        # 获取版本管理的类
        request.versioning_scheme
        # 反向生成URL
        reverse_url = request.versioning_scheme.reverse('test', request=request)
        return Response('GET请求,响应内容')
复制代码

 

6.3 请求头方法

复制代码
from rest_framework.versioning import AcceptHeaderVersioning

class TestView(APIView):
    versioning_class = AcceptHeaderVersioning
    def get(self, request, *args, **kwargs):
        # 获取版本 HTTP_ACCEPT头
        request.version
        # 获取版本管理的类
        request.versioning_scheme
复制代码

 

6.4 主机名方法

复制代码
from rest_framework.versioning import HostNameVersioning

class TestView(APIView):
    versioning_class = HostNameVersioning
    def get(self, request, *args, **kwargs):
        # 获取版本
        request.version
        # 获取版本管理的类
        request.versioning_scheme
复制代码

 

复制代码
ALLOWED_HOSTS = ['*']  # 注意在原有REST_FRAMEWORK 的version settings的基础上加入允许的主机名

REST_FRAMEWORK = {
                'VERSION_PARAM':'version',
                'DEFAULT_VERSION':'v1',
                'ALLOWED_VERSIONS':['v1','v2'],                    'DEFAULT_VERSIONING_CLASS':"rest_framework.versioning.HostNameVersioning"
            }    
settings.py
复制代码

 

6.5 Namespace方法

复制代码
from django.conf.urls import url, include
from web.views import TestView

urlpatterns = [
    url(r'^v1/', ([
                      url(r'test/', TestView.as_view(), name='test'),
                  ], None, 'v1')),
    url(r'^v2/', ([
                      url(r'test/', TestView.as_view(), name='test'),
                  ], None, 'v2')),

]
url.py
复制代码
复制代码
from rest_framework.versioning import NamespaceVersioning

class TestView(APIView):
    versioning_class = NamespaceVersioning

    def get(self, request, *args, **kwargs):
        # 获取版本
        request.version
        # 获取版本管理的类
        request.versioning_scheme
复制代码
from rest_framework.reverse import reverse
url = request.versioning_scheme.reverse(viewname='users-list',request=request)
p.s. 反向解析

 

 

 

 高级部分

  基于以上部分,小伙伴们已经基本可以写出restful 的后端。但是restframework为我们解决的不仅仅只有这些!

 Ⅶ 分页

 7.1 LimitOffsetPagination

from rest_framework.pagination import LimitOffsetPagination
class LP(LimitOffsetPagination):
    max_limit = 40  # 最大每页显示的条数
    default_limit =20  # 默认每页显示的条数
    limit_query_param = 'limit'  # 往后取几条,get请求参数名
    offset_query_param = 'offset'  # 当前所在的位置,get请求参数名
复制代码
class UserViewSet(APIView):
    def get(self, request, *args, **kwargs):
        user_list = models.UserInfo.objects.all()
        # 实例化分页对象,获取数据库中的分页数据
        paginator = LP()
        page_user_list = paginator.paginate_queryset(user_list, self.request, view=self)

        # 序列化对象
        serializer = UserSerializer(page_user_list, many=True)

        # 生成分页和数据
        response = paginator.get_paginated_response(serializer.data)
        return response
用法
复制代码

 

 

7.2 PageNumberPagination

from rest_framework.pagination import PageNumberPagination

复制代码
from rest_framework.pagination import PageNumberPagination

class StandardResultsSetPagination(PageNumberPagination):
    # 默认每页显示的数据条数
    page_size = 1
    # 获取URL参数中设置的每页显示数据条数
    page_size_query_param = 'page_size'
    # 获取URL参数中传入的页码key
    page_query_param = 'page'
    # 最大支持的每页显示的数据条数
    max_page_size = 1
复制代码
复制代码
class UserViewSet(APIView):
    def get(self, request, *args, **kwargs):
        user_list = models.UserInfo.objects.all().order_by('-id')
        # 实例化分页对象,获取数据库中的分页数据
        paginator = StandardResultsSetPagination()
        page_user_list = paginator.paginate_queryset(user_list, self.request, view=self)

        # 序列化对象
        serializer = UserSerializer(page_user_list, many=True)

        # 生成分页和数据
        response = paginator.get_paginated_response(serializer.data)
        return response
用法
复制代码

 

 7.3 CursorPagination

from rest_framework.pagination import CursorPagination

from rest_framework.pagination import CursorPagination
class
CP(CursorPagination): cursor_query_param = 'cursor' # URL传入的游标参数 page_size = 2 # 默认每页显示的数据条数 page_size_query_param = 'page_size' # URL传入的每页显示条数的参数 max_page_size = 1000 # 每页显示数据最大条数 ordering = "id" # 根据ID从大到小排列
复制代码
class UserViewSet(APIView):
    def get(self, request, *args, **kwargs):
        user_list = models.UserInfo.objects.all().order_by('-id')

        # 实例化分页对象,获取数据库中的分页数据
        paginator = CP()
        page_user_list = paginator.paginate_queryset(user_list, self.request, view=self)

        # 序列化对象
        serializer = UserSerializer(page_user_list, many=True)
        # 生成分页和数据
        response = paginator.get_paginated_response(serializer.data)
        return response
用法
复制代码

 

Ⅷ 视图 & 路由

除了最基本的APIView, restframework还为我们提供了GenericViewSet,ModelViewSet,ModelViewSet三个视图实现方案

8.1 GenericViewSet

from django.conf.urls import url, include

urlpatterns = [
    url(r'test/', TestView.as_view({'get':'list'}), name='test'),
    url(r'detail/(?P<pk>\d+)/', TestView.as_view({'get':'list'}), name='xxxx'),  
]  
# 这里会为每一种请求设置一个视图函数,例如为get设置一个叫list的视图函数
from rest_framework import viewsets
class TestView(viewsets.GenericViewSet):
    def list(self, request, *args, **kwargs):  # get请求会找这个方法
        return Response('...')

 

8.2 ModelViewSet

复制代码
from django.conf.urls import url, include
from web.views import s10_generic

urlpatterns = [
    url(r'^test/$', s10_generic.UserViewSet.as_view({'get': 'list', 'post': 'create'})),
    url(r'^test/(?P<pk>\d+)/$', s10_generic.UserViewSet.as_view(
        {'get': 'retrieve', 'put': 'update', 'patch': 'partial_update', 'delete': 'destroy'})),
]
url.py
复制代码
复制代码
from rest_framework.viewsets import ModelViewSet
from rest_framework import serializers
from .. import models


class UserSerializer(serializers.ModelSerializer):
    class Meta:
        model = models.UserInfo
        fields = "__all__"


class UserViewSet(ModelViewSet):
    queryset = models.UserInfo.objects.all()
    serializer_class = UserSerializer
用法:对于简单请求可以省去写视图了
复制代码

 

8.3 路由系统

配合ModelViewSet,路由系统可以帮我们省去写路由的繁琐,可以自动实现增删改查

复制代码
from django.conf.urls import url, include
from rest_framework import routers
from app01 import views

router = routers.DefaultRouter()
router.register(r'users', views.UserViewSet)
router.register(r'groups', views.GroupViewSet)

# Wire up our API using automatic URL routing.
# Additionally, we include login URLs for the browsable API.
urlpatterns = [
    url(r'^', include(router.urls)),
]
复制代码

 

 

Ⅹ 访问频率

10.1 手动设置访问频率

复制代码
import time
from rest_framework.views import APIView
from rest_framework.response import Response

from rest_framework import exceptions
from rest_framework.throttling import BaseThrottle
from rest_framework.settings import api_settings

# 保存访问记录
RECORD = {
    '用户IP': [12312139, 12312135, 12312133, ]
}


class TestThrottle(BaseThrottle):
    ctime = time.time

    def get_ident(self, request):
        """
        根据用户IP和代理IP,当做请求者的唯一IP
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')  
# X-Forwarded-For(XFF)是用来识别通过HTTP代理或负载均衡方式连接到Web服务器的客户端最原始的IP地址的HTTP请求头字段。
        remote_addr = request.META.get('REMOTE_ADDR')
#REMOTE_ADDR不可以伪造的,就在curl 中也无法伪造 相对是比较安全的服务端ip获取方法,当然,也有可能被路由伪造 这个不好说,因为REMOTE_ADDR 是底层的回话ip地址,路由是可以发起伪造。
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr

    def allow_request(self, request, view):
        """
        是否仍然在允许范围内
        Return `True` if the request should be allowed, `False` otherwise.
        :param request: 
        :param view: 
        :return: True,表示可以通过;False表示已超过限制,不允许访问
        """
        # 获取用户唯一标识(如:IP)

        # 允许一分钟访问10次
        num_request = 10
        time_request = 60

        now = self.ctime()
        ident = self.get_ident(request)
        self.ident = ident
        if ident not in RECORD:
            RECORD[ident] = [now, ]
            return True
        history = RECORD[ident]
        while history and history[-1] <= now - time_request:
            history.pop()
        if len(history) < num_request:
            history.insert(0, now)
            return True

    def wait(self):
        """
        多少秒后可以允许继续访问
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        last_time = RECORD[self.ident][0]
        now = self.ctime()
        return int(60 + last_time - now)


class TestView(APIView):
    throttle_classes = [TestThrottle, ]

    def get(self, request, *args, **kwargs):
        # self.dispatch
        print(request.user)
        print(request.auth)
        return Response('GET请求,响应内容')

    def post(self, request, *args, **kwargs):
        return Response('POST请求,响应内容')

    def put(self, request, *args, **kwargs):
        return Response('PUT请求,响应内容')

    def throttled(self, request, wait):
        """
        访问次数被限制时,定制错误信息
        """

        class Throttled(exceptions.Throttled):
            default_detail = '请求被限制.'
            extra_detail_singular = '请 {wait} 秒之后再重试.'
            extra_detail_plural = '请 {wait} 秒之后再重试.'

        raise Throttled(wait)
手动
复制代码

 

10.2 利用rest模块实现

复制代码
from rest_framework.views import APIView
from rest_framework.response import Response

from rest_framework import exceptions
from rest_framework.throttling import SimpleRateThrottle


class TestThrottle(SimpleRateThrottle):

    # 配置文件定义的显示频率的Key
    scope = "test_scope"

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        if not request.user:
            ident = self.get_ident(request)
        else:
            ident = request.user

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }


class TestView(APIView):
    throttle_classes = [TestThrottle, ]

    def get(self, request, *args, **kwargs):
        # self.dispatch
        print(request.user)
        print(request.auth)
        return Response('GET请求,响应内容')

    def post(self, request, *args, **kwargs):
        return Response('POST请求,响应内容')

    def put(self, request, *args, **kwargs):
        return Response('PUT请求,响应内容')

    def throttled(self, request, wait):
        """
        访问次数被限制时,定制错误信息
        """

        class Throttled(exceptions.Throttled):
            default_detail = '请求被限制.'
            extra_detail_singular = '请 {wait} 秒之后再重试.'
            extra_detail_plural = '请 {wait} 秒之后再重试.'

        raise Throttled(wait)
使用
复制代码
复制代码
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_RATES': {
        'test_scope': '10/m',
    },
}
配套setting
复制代码

10.3 匿名用ip限制,登陆用user限制

复制代码
REST_FRAMEWORK = {
    'UNAUTHENTICATED_USER': None,
    'UNAUTHENTICATED_TOKEN': None,
    'DEFAULT_THROTTLE_RATES': {
        'luffy_anon': '10/m',
        'luffy_user': '20/m',
    },
}
setting
复制代码
复制代码
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.throttling import SimpleRateThrottle


classAnonRateThrottle(SimpleRateThrottle):
    """
    匿名用户,根据IP进行限制
    """
    scope = "anon"

    def get_cache_key(self, request, view):
        # 用户已登录,则跳过 匿名频率限制
        if request.user:
            return None

        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }


class UserRateThrottle(SimpleRateThrottle):
    """
    登录用户,根据用户token限制
    """
    scope = "user"


    def get_cache_key(self, request, view):
        """
        获取缓存key
        :param request: 
        :param view: 
        :return: 
        """
        # 未登录用户,则跳过 Token限制
        if not request.user:
            return None

        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }


class TestView(APIView):
    throttle_classes = [UserRateThrottle, AnonRateThrottle, ]

    def get(self, request, *args, **kwargs):
        print(request.user)
        print(request.auth)
        return Response('GET请求,响应内容')

    def post(self, request, *args, **kwargs):
        return Response('POST请求,响应内容')

    def put(self, request, *args, **kwargs):
        return Response('PUT请求,响应内容')
用法
复制代码