rest_framework的频率限制和简单操作

models.py

from django.db import models

# Create your models here.


class User(models.Model):
    """Application user who logs in via LoginView and is issued a Token."""
    # NOTE(review): name is not unique, so lookups by name alone may match
    # several users.
    name = models.CharField(max_length=32)
    # NOTE(review): the password is stored and compared in plain text
    # (LoginView filters on name=..., pwd=...); consider Django's password
    # hashers instead.
    pwd = models.CharField(max_length=32)
    # (1, regular user), (2, VIP), (3, SVIP); SVIPPermission only admits 3.
    type_choices = ((1, "普通用户"), (2, "VIP"), (3, "SVIP"))
    user_type = models.IntegerField(choices=type_choices, default=1)


class Token(models.Model):
    """API token issued to a User at login (at most one token per user)."""
    # on_delete is mandatory from Django 2.0 on. CASCADE is what Django 1.x
    # applied implicitly, so a user's token is still deleted with the user.
    user = models.OneToOneField("User", on_delete=models.CASCADE)
    # NOTE(review): token is looked up on every authenticated request but has
    # no index/unique constraint; consider unique=True (needs a migration).
    token = models.CharField(max_length=128)

    def __str__(self):
        return self.token


class Book(models.Model):
    """A book with exactly one publisher and any number of authors."""
    title = models.CharField(max_length=32)
    price = models.IntegerField()
    pub_date = models.DateField()
    # on_delete is mandatory from Django 2.0 on. CASCADE is what Django 1.x
    # applied implicitly, so deleting a Publish still deletes its books.
    publish = models.ForeignKey("Publish", on_delete=models.CASCADE)
    authors = models.ManyToManyField("Author")

    def __str__(self):
        return self.title


class Publish(models.Model):
    """A publisher; each Book points at one via Book.publish."""
    name = models.CharField(max_length=32)
    email = models.EmailField()

    def __str__(self):
        # Display a publisher by its name.
        return self.name


class Author(models.Model):
    """An author; linked to books through Book.authors (many-to-many)."""
    name = models.CharField(max_length=32)
    age = models.IntegerField()

    def __str__(self):
        # Display an author by their name.
        return self.name
View Code
urls.py
from django.conf.urls import url, include
from django.contrib import admin
from app01 import views
# 封装url
from rest_framework import routers
# The DefaultRouter generates the list/detail routes for AuthorModelView.
# The instance is bound to ``router`` so it no longer shadows the imported
# ``rest_framework.routers`` module.
router = routers.DefaultRouter()
router.register("authors", views.AuthorModelView)

urlpatterns = [
    url(r'^admin/', admin.site.urls),
    url(r'^', include(router.urls)),  # routes generated by the router
    # Hand-written alternative to the router, kept for reference.
    # (?P<pk>...) is a named group passed to the view as ``pk``;
    # name=... sets the URL name used for reversing.
    # url(r'^authors/$', views.AuthorModelView.as_view({"get":"list","post":"create"}), name="author"),
    # url(r'^authors/(?P<pk>\d+)/$', views.AuthorModelView.as_view({"get":"retrieve","put":"update","delete":"destroy"}), name="detailauthor"),

    url(r'^login/$', views.LoginView.as_view(), name="login"),
]
View Code
views.py
from rest_framework.views import APIView
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
from rest_framework import viewsets
from app01.serilizer import *
from .models import *
import json


# 封装逻辑1
# class AuthorView(mixins.ListModelMixin, mixins.CreateModelMixin, generics.GenericAPIView):
#     queryset = Author.objects.all()
#     serializer_class = AuthorModelSerializers
#
#     def get(self, request, *args, **kwargs):
#         return self.list(request, *args, **kwargs)
#
#     def post(self, request, *args, **kwargs):
#         return self.create(request, *args, **kwargs)
#
#
# class AuthorDetailView(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, mixins.UpdateModelMixin, generics.GenericAPIView):
#     queryset = Author.objects.all()
#     serializer_class = AuthorModelSerializers
#
#     def get(self, request, *args, **kwargs):
#         return self.retrieve(request, *args, **kwargs)
#
#     def delete(self, request, *args, **kwargs):
#         return self.destroy(request, *args, **kwargs)
#
#     def put(self, request, *args, **kwargs):
#         return self.update(request, *args, **kwargs)


# 封装逻辑2
# class AuthorView(generics.ListCreateAPIView):
#     queryset = Author.objects.all()
#     serializer_class = AuthorModelSerializers
#
#
# class AuthorDetailView(generics.RetrieveUpdateDestroyAPIView):
#     queryset = Author.objects.all()
#     serializer_class = AuthorModelSerializers


class MyPageNumberPagination(PageNumberPagination):
    """Page-number pagination driven by ``?page=N`` and ``?size=M``.

    One item per page by default. Per DRF's PageNumberPagination, a
    client-requested ``size`` is capped at ``max_page_size`` (2).
    """

    # Query-string parameter names.
    page_query_param = "page"
    page_size_query_param = "size"

    # Page-size limits.
    page_size = 1
    max_page_size = 2


# Encapsulation approach 3: one ModelViewSet serves every Author endpoint.
class AuthorModelView(viewsets.ModelViewSet):
    # Comments rather than a docstring: DRF shows view docstrings to API
    # clients (OPTIONS / browsable API), so adding one would change output.
    serializer_class = AuthorModelSerializers
    pagination_class = MyPageNumberPagination  # page-number pagination
    queryset = Author.objects.all()


# Generate a random API token string.
def get_random_str(user):
    """Return a fresh, unpredictable 32-character lowercase hex token.

    The previous implementation returned md5(user + str(time.time())). That
    value is guessable by anyone who knows the username and roughly when the
    login happened, and two calls within one clock tick produced the same
    token. The bytes now come from the operating system's CSPRNG instead.

    Args:
        user: the username. Kept for call-site compatibility; it is no
            longer mixed into the token.

    Returns:
        A 32-character hex string, the same length as the old MD5 hexdigest.
    """
    import binascii
    import os

    # 16 random bytes -> 32 hex characters.
    return binascii.hexlify(os.urandom(16)).decode("ascii")


class LoginView(APIView):
    # Issues an API token for a valid name/password pair. Logging in must
    # work without a token, so the project-wide authentication and
    # permission classes are switched off for this view.
    # (Comments rather than a docstring: DRF shows view docstrings to clients.)
    authentication_classes = []
    permission_classes = []

    def post(self, request):
        name = request.data.get("name")
        pwd = request.data.get("pwd")
        # NOTE(review): plain-text password comparison, because User.pwd
        # stores the raw password; consider Django's password hashers.
        user = User.objects.filter(name=name, pwd=pwd).first()
        res = {"state_code": 1000, "msg": None}
        if user:
            random_str = get_random_str(user.name)
            # Token.user is one-to-one, so each login replaces the user's token.
            Token.objects.update_or_create(user=user, defaults={"token": random_str})
            res["token"] = random_str
        else:
            res["state_code"] = 1001  # error status code
            res["msg"] = "用户名或者密码错误"
        # Pass the dict itself and let DRF's renderer encode it. Passing
        # json.dumps(...) made the response body a JSON *string* containing
        # JSON (double-encoded).
        return Response(res)
View Code
serilizer.py
from rest_framework import serializers
from .models import *


# Serializers
class AuthorModelSerializers(serializers.ModelSerializer):
    """ModelSerializer for Author; AuthorModelView uses it for all endpoints."""
    # fields = "__all__" serializes every field of the Author model.
    class Meta:
        model = Author
        fields = "__all__"
View Code
utils.py
from rest_framework.authentication import BaseAuthentication
from rest_framework.throttling import BaseThrottle
from rest_framework import exceptions
from .models import *
import time


class TokenAuth(BaseAuthentication):
    """Authenticate a request by its ``?token=...`` query-string parameter.

    On success returns ``(username, token)``; DRF exposes these as
    ``request.user`` and ``request.auth``. A missing, empty or unknown
    token raises ``AuthenticationFailed``.
    """

    def authenticate(self, request):
        # query_params is DRF's documented name for the query-string dict.
        token = request.query_params.get("token")
        if not token:
            # Reject immediately instead of querying for an empty/absent token.
            raise exceptions.AuthenticationFailed("验证失败!")
        # select_related loads the owning User in the same query, because
        # its name is read right below.
        token_obj = Token.objects.select_related("user").filter(token=token).first()
        if token_obj is None:
            raise exceptions.AuthenticationFailed("验证失败!")
        return token_obj.user.name, token_obj.token


class SVIPPermission(object):
    """Permission class admitting only users whose ``user_type`` is 3 (SVIP)."""

    # Message DRF reports when has_permission() returns False.
    message = "只有超级用户才能访问"

    def has_permission(self, request, view):
        """Return True only for an SVIP user; False if no user can be resolved."""
        user = self._resolve_user(request)
        return user is not None and user.user_type == 3

    @staticmethod
    def _resolve_user(request):
        """Return the User behind this request, or None if none matches.

        TokenAuth puts the token string in ``request.auth``. Resolving the
        user through that token is exact, whereas User.name is not unique:
        looking up by name could pick a different user who happens to share
        the name.
        """
        token = getattr(request, "auth", None)
        if isinstance(token, str) and token:
            token_obj = Token.objects.select_related("user").filter(token=token).first()
            if token_obj is not None:
                return token_obj.user
        # Fallback for other authenticators: the original name-based lookup.
        # NOTE(review): still ambiguous if several users share a name.
        return User.objects.filter(name=request.user).first()


# In-memory visit log used by VisitRateThrottle: client IP -> time.time()
# stamps of its accepted requests, newest first. It lives only in this
# process, and an IP's entries are trimmed only when that IP calls again.
VISIT_RECORD = {}


class VisitRateThrottle(BaseThrottle):
    """Allow each client IP at most ``num_requests`` requests per ``duration`` seconds.

    The defaults (3 requests per 60 seconds) are the original hard-coded
    policy; subclasses may override the two class attributes. The client is
    identified by ``REMOTE_ADDR`` only, not by ``X-Forwarded-For``, which
    clients can forge.
    NOTE(review): behind a reverse proxy every client shares the proxy's
    REMOTE_ADDR, so all clients would share one limit.
    """

    num_requests = 3
    duration = 60

    def __init__(self):
        # Visit list of the IP seen by the last allow_request(); read by wait().
        self.history = None

    def allow_request(self, request, view):
        """Return True to let the request through, False to throttle it."""
        remote_addr = request.META.get('REMOTE_ADDR')
        now = time.time()
        # A first-time IP gets an empty list, which this request then fills.
        history = VISIT_RECORD.setdefault(remote_addr, [])
        self.history = history
        # The list is newest-first, so expired stamps sit at the end. The
        # ``history and`` guard stops the loop once the list is empty
        # (history[-1] would otherwise raise IndexError).
        while history and history[-1] < now - self.duration:
            history.pop()
        if len(history) < self.num_requests:
            history.insert(0, now)
            return True
        # Explicit False; the original fell off the end and returned None.
        return False

    def wait(self):
        """Seconds until the oldest recorded request leaves the window, or None if unknown."""
        if not self.history:
            return None
        return self.duration - (time.time() - self.history[-1])
View Code
settings.py
REST_FRAMEWORK = dict(
    # Authentication: app01.utils.TokenAuth
    DEFAULT_AUTHENTICATION_CLASSES=["app01.utils.TokenAuth"],
    # Permissions: app01.utils.SVIPPermission
    DEFAULT_PERMISSION_CLASSES=["app01.utils.SVIPPermission"],
    # Throttling (rate limiting): app01.utils.VisitRateThrottle
    DEFAULT_THROTTLE_CLASSES=["app01.utils.VisitRateThrottle"],
)
View Code

内容参考自：https://www.cnblogs.com/yuanchenqi/articles/8719520.html

posted @ 2018-09-20 07:40  知你几分  阅读(313)  评论(0编辑  收藏  举报