Django REST framework基础:版本、认证、权限、限制

DRF的版本

版本控制是做什么用的, 我们为什么要用

首先我们要知道我们的版本是干嘛用的呢~~大家都知道我们开发项目是有多个版本的~~

随着我们项目的更新~版本就越来越多~~我们不可能新的版本出了~以前旧的版本就不进行维护了~~~

那我们就需要对版本进行控制~~这个DRF也给我们提供了一些封装好的版本控制方法~~

版本控制怎么用

之前我们学视图的时候知道APIView,也知道APIView返回View中的view函数,然后调用的dispatch方法~

那我们现在看下dispatch方法~~看下它都做了什么~~

执行self.initial方法之前是各种赋值,包括request的重新封装赋值,下面是路由的分发,那我们看下这个方法都做了什么~~

我们可以看到,我们的version版本信息赋值给了 request.version  版本控制方案赋值给了 request.versioning_scheme~~

其实这个版本控制方案~就是我们配置的版本控制的类~~

也就是说,APIView通过这个方法初始化自己提供的组件~~

我们接下来看看框架提供了哪些版本的控制方法~~在rest_framework.versioning里~~

框架一共给我们提供了这几个版本控制的方法~~我们在这里只演示一个~~因为基本配置都是一样的~~

详细用法

全局配置

这里我们以 URLPathVersioning 为例,还是在项目的settings.py中REST_FRAMEWORK配置项下配置:

复制代码
REST_FRAMEWORK = {
    ...
    'DEFAULT_VERSIONING_CLASS': 'rest_framework.versioning.URLPathVersioning',
    'DEFAULT_VERSION': 'v1',  # 默认的版本
    'ALLOWED_VERSIONS': ['v1', 'v2'],  # 有效的版本
    'VERSION_PARAM': 'version',  # 版本的参数名与URL conf中一致
}
复制代码

urls.py中

复制代码
urlpatterns = [
    ...
    url(r'^(?P<version>[v1|v2]+)/publishers/$', views.PublisherViewSet.as_view({'get': 'list', 'post': 'create'})),
    url(r'^(?P<version>[v1|v2]+)/publishers/(?P<pk>\d+)/$', views.PublisherViewSet.as_view({'get': 'retrieve', 'put': 'update', 'delete': 'destroy'})),

]
复制代码

我们在视图中可以通过访问 request.version 来获取当前请求的具体版本,然后根据不同的版本来返回不同的内容:

我们可以在视图中自定义具体的行为,下面以不同的版本返回不同的序列化类为例

复制代码
class PublisherViewSet(ModelViewSet):

    def get_serializer_class(self):
        """不同的版本使用不同的序列化类"""
        if self.request.version == 'v1':
            return PublisherModelSerializerVersion1
        else:
            return PublisherModelSerializer
    queryset = models.Publisher.objects.all()
复制代码

局部配置

注意,通常我们是不会单独给某个视图设置版本控制的,如果你确实需要给单独的视图设置版本控制,你可以在视图中设置versioning_class属性,如下:

class PublisherViewSet(ModelViewSet):

    ...
    versioning_class = URLPathVersioning

认证、权限和限制

身份验证是将传入请求与一组标识凭据(例如请求来自的用户或其签名的令牌)相关联的机制。然后 权限 和 限制 组件决定是否拒绝这个请求。

简单来说就是:

认证确定了你是谁

权限确定你能不能访问某个接口

限制确定你访问某个接口的频率

认证

可以settings中全局设置,然后某些视图函数中(例如:登陆注册,主页)设置局部认证为空 

class Reg(APIView):
    authentication_classes = [ ]  # 局部配置认证

    def post(self, request):
        user = request.data.get('user')
        pwd = request.data.get('pwd')
        re_pwd = request.data.get('pwd')

        if user and pwd:
            if re_pwd == pwd:
                User_info.objects.create(user=user, pwd=pwd)
                return Response('创建成功!')
            else:
                return Response('两次密码不一致')

        else:
            return Response('参数不合法')
authentication_classes = [ ] # 局部配置认证

 

REST framework 提供了一些开箱即用的身份验证方案,并且还允许你实现自定义方案。

 

接下类我们就自己动手实现一个基于Token的认证方案:

自定义Token认证

定义一个用户表和一个保存用户Token的表:

复制代码
class UserInfo(models.Model):
    username = models.CharField(max_length=16)
    password = models.CharField(max_length=32)
    type = models.SmallIntegerField(
        choices=((1, '普通用户'), (2, 'VIP用户')),
        default=1
    )


class Token(models.Model):
    user = models.OneToOneField(to='UserInfo')
    token_code = models.CharField(max_length=128)
复制代码

定义一个登录视图:

复制代码
def get_random_token(username):
    """
    根据用户名和时间戳生成随机token
    :param username:
    :return:
    """
    import hashlib, time
    timestamp = str(time.time())
    m = hashlib.md5(bytes(username, encoding="utf8"))
    m.update(bytes(timestamp, encoding="utf8"))
    return m.hexdigest()


class LoginView(APIView):
    """
    校验用户名密码是否正确从而生成token的视图
    """
    def post(self, request):
        res = {"code": 0}
        print(request.data)
        username = request.data.get("username")
        password = request.data.get("password")

        user = models.UserInfo.objects.filter(username=username, password=password).first()
        if user:
            # 如果用户名密码正确
            token = get_random_token(username)
            models.Token.objects.update_or_create(defaults={"token_code": token}, user=user)
            res["token"] = token
        else:
            res["code"] = 1
            res["error"] = "用户名或密码错误"
        return Response(res)
复制代码

定义一个认证类

复制代码
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed


class MyAuth(BaseAuthentication):
    def authenticate(self, request):
        if request.method in ["POST", "PUT", "DELETE"]:
            request_token = request.data.get("token", None)
            if not request_token:
                raise AuthenticationFailed('缺少token')
            token_obj = models.Token.objects.filter(token_code=request_token).first()
            if not token_obj:
                raise AuthenticationFailed('无效的token')
            return token_obj.user.username, None
        else:
            return None, None
复制代码

视图级别认证

class CommentViewSet(ModelViewSet):

    queryset = models.Comment.objects.all()
    serializer_class = app01_serializers.CommentSerializer
    authentication_classes = [MyAuth, ]

全局级别认证

# 在settings.py中配置
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ]
}

权限

只有VIP用户才能看的内容。

自定义一个权限类

复制代码
# 自定义权限
class MyPermission(BasePermission):
    message = 'VIP用户才能访问'

    def has_permission(self, request, view):
        """
        自定义权限只有VIP用户才能访问
        """
        # 因为在进行权限判断之前已经做了认证判断,所以这里可以直接拿到request.user
        if request.user and request.user.type == 2:  # 如果是VIP用户
            return True
        else:
            return False
复制代码

视图级别配置

class CommentViewSet(ModelViewSet):

    queryset = models.Comment.objects.all()
    serializer_class = app01_serializers.CommentSerializer
    authentication_classes = [MyAuth, ]
    permission_classes = [MyPermission, ]

全局级别设置

# 在settings.py中设置rest framework相关配置项
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ],
    "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ]
}

限制

DRF内置了基本的限制类,首先我们自己动手写一个限制类,熟悉下限制组件的执行过程。

自定义限制类

复制代码
VISIT_RECORD = {}
# 自定义限制
class MyThrottle(object):

    def __init__(self):
        self.history = None

    def allow_request(self, request, view):
        """
        自定义频率限制60秒内只能访问三次
        """
        # 获取用户IP
        ip = request.META.get("REMOTE_ADDR")
        timestamp = time.time()
        if ip not in VISIT_RECORD:
            VISIT_RECORD[ip] = [timestamp, ]
            return True
        history = VISIT_RECORD[ip]
        self.history = history
        history.insert(0, timestamp)
        while history and history[-1] < timestamp - 60:
            history.pop()
        if len(history) > 3:
            return False
        else:
            return True

    def wait(self):
        """
        限制时间还剩多少
        """
        timestamp = time.time()
        return 60 - (timestamp - self.history[-1])
复制代码

 

视图使用

class CommentViewSet(ModelViewSet):

    queryset = models.Comment.objects.all()
    serializer_class = app01_serializers.CommentSerializer
    throttle_classes = [MyThrottle, ]

全局使用

# 在settings.py中设置rest framework相关配置项
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ],
    "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ]
    "DEFAULT_THROTTLE_CLASSES": ["app01.utils.MyThrottle", ]
}

使用内置限制类

复制代码
from rest_framework.throttling import SimpleRateThrottle


class VisitThrottle(SimpleRateThrottle):

    scope = "xxx"

    def get_cache_key(self, request, view):
        return self.get_ident(request)
复制代码

全局配置

复制代码
# 在settings.py中设置rest framework相关配置项
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": ["app01.utils.MyAuth", ],
    # "DEFAULT_PERMISSION_CLASSES": ["app01.utils.MyPermission", ]
    "DEFAULT_THROTTLE_CLASSES": ["app01.utils.VisitThrottle", ],
    "DEFAULT_THROTTLE_RATES": {
        "xxx": "5/m",
    }
}
复制代码

 

posted @ 2019-01-15 17:38  Niuli'blog  阅读(257)  评论(0编辑  收藏  举报