认证、权限、频率
数据准备
在前面说的 APIView 中封装了三大认证,分别为认证、权限、频率。认证即登录认证,权限表示该用户是否有权限访问接口,频率表示用户指定时间内能访问接口的次数
为了方便举例说明,事先定义好模型表
from django.db import models # 图书跟作者:多对多,需要建立中间表,但是我们可以通过ManyToManyField自动生成,写在哪里都行 # 图书跟出版社:一对多,一个出版社,出版多本书,关联字段写在多的一方,写在Book class Book(models.Model): name = models.CharField(max_length=32, default='xx') price = models.IntegerField() publish = models.ForeignKey(to='Publish', on_delete=models.CASCADE) # 留住,还有很多 authors = models.ManyToManyField(to='Author') def publish_detail(self): return {'name': self.publish.name, 'addr': self.publish.addr} def author_list(self): l = [] for author in self.authors.all(): l.append({'name': author.name, 'phone': author.phone}) return l class Publish(models.Model): name = models.CharField(max_length=32) addr = models.CharField(max_length=32) class Author(models.Model): name = models.CharField(max_length=32) phone = models.CharField(max_length=11) # User与UserToken一对一关联,UserToken表用于储存用户登录的Token,进行相关操作需要带着token过来才能登录,不带就不能登录 class User(models.Model): username = models.CharField(max_length=32) password = models.CharField(max_length=32) user_type = models.IntegerField(choices=((1,'超级管理员'),(2,'普通管理员'),(3,'游客'))) class UserToken(models.Model): token = models.CharField(max_length=64) user = models.OneToOneField(to='User',on_delete=models.CASCADE, null=True)
认证
登录接口
自定义登录认证可以生成随机字符串,并添加进 UserToken 表中,每一次登录都会生成,若该字符串已存在则更新,不存在则新建。
# views.py
from .models import User,UserToken from .serializer import Bookserializer from rest_framework.response import Response from rest_framework.viewsets import ViewSet from rest_framework.decorators import action import uuid class UserView(ViewSet): @action(methods=['POST'], detail=False) def login(self,request): username = request.data.get('username') password = request.data.get('password') user = User.objects.filter(username=username,password=password).first() if user: token=str(uuid.uuid4()) # 生成一个随机字符串 # 在userToken表中存储一下:
# 如果查找不到相关数据,则插入一条,能查到则按照defaults修改记录 UserToken.objects.update_or_create(user_id=user.pk, defaults={'token':token}) return Response({'code': '100', 'msg': '登录成功'}) else: return Response({'code':'101','msg':'用户名或密码错误'})
# urls.py
from app01 import views from rest_framework.routers import SimpleRouter router = SimpleRouter() router.register('user',views.UserView,'user') urlpatterns = [ path(r'api/v1/',include(router.urls)) ]
自定义认证
自定义认证表需要创建认证类,首先继承拓展 BaseAuthentication
导入语句:from rest_framework.authentication import BaseAuthentication
自定义认证类
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed # 异常
class MyAuthentication(BaseAuthentication):
# 重写 BaseAuthentication 中的 authenticate 方法
def authenticate(self, request):
# 在请求头中获取用户登录的 token 字符串
token = request.query_params.get('token')
# 判断该字符串是否存在
user_token = UserToken.objects.filter(token=token).first()
if user_token:
# 返回的第一个值是当前登录用户,第二个值是 token
return user_token.user, token
else:
# 若不存在则报 AuthenticationFailed 异常
raise AuthenticationFailed('请先登录')
在配置文件中添加配置,DEFAULT_AUTHENTICATION_CLASSES 的值为列表,其包含认证类路径
REST_FRAMEWORK={
"DEFAULT_AUTHENTICATION_CLASSES":["app01.auth.MyAuthentication",]
}
局部使用
局部使用,只需要在视图类里加入:
authentication_classes = [MyAuthentication, ]
局部禁用
可以选择某一些方法可以认证,在视图类中添加 get_authenticators
authentication_classes = [ ]
代码示例
- 查询所有不需要登录就能访问
- 查询单个,需要登录才能访问
# 查询所有(不需要登录) class BookView(ViewSetMixin,ListAPIView): queryset = Book.objects.all() serializer_class = Bookserializer authentication_classes = [] # 查询单个(需要登录) class BookDetail(ViewSetMixin,RetrieveAPIView): queryset = Book.objects.all() serializer_class = Bookserializer authentication_classes = [MyAuthentication]
总结流程
- 写一个认证类,继承 BaseAuthentication 类
- 重写 authenticate 方法,方法中编写逻辑,实现登录认证
- 认证成功需要返回两个值,为登录用户和 token,
- 认证不通过,抛 AuthenticationFailed 认证异常
- 编写好认证类可以在局部使用,也可以在全局使用
内置认证类
-SessionAuthentication 之前老的 session 认证登录方式用,后期不用
-BasicAuthentication 基本认证方式
-TokenAuthentication 使用 token 认证方式,也可以自己写
可以在配置文件中配置全局默认的认证方案
REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.SessionAuthentication', # session认证
'rest_framework.authentication.BasicAuthentication', # 基本认证
)
}
也可以在每个视图中通过设置authentication_classess属性来设置
from rest_framework.authentication import SessionAuthentication, BasicAuthentication
from rest_framework.views import APIView
class ExampleView(APIView):
authentication_classes = [SessionAuthentication, BasicAuthentication]
...
权限
自定义权限
登录认证成功后,还需要认证权限,有一些接口需要指定权限才能访问。所以权限需要和登录认证相关联。每个人的权限在表中默认设为普通用户。
自定义权限需要继承 BasePermission 编写权限类
导入语句:from rest_framework.permissions import BasePermission
自定义权限类
from rest_framework.permissions import BasePermission
from .models import UserToken, MyUser
class MyPermission(BasePermission):
# message 为认证失败提示信息
message = ''
# 需要重写 has_permission 方法,原方法默认返回 True
def has_permission(self, request, view):
# 获取当前登录用户
user = request.user
# 获取当前用户权限类型
user_type = user.user_type
if user_type == 1:
# 权限符合返回 True
return True
else:
# 权限不符合,添加提示信息,并返回 False
self.message = '你是: %s,权限不够' % user.get_user_type_display()
return False
全局使用
全局使用也是在配置文件中添加
REST_FRAMEWORK={
"DEFAULT_AUTHENTICATION_CLASSES":["app01.auth.MyAuthentication",],
"DEFAULT_PERMISSION_CLASSES":["app01.auth.MyPermission",]
}
局部使用
局部使用,只需要在视图类里加入该权限类即可:
permission_classes = [MyPermission,]
选择使用
在视图类中添加 get_permissions 判断如果请求方式符合就去认证
permission_classes = [ ]
总结流程
- 写一个权限类,继承 BasePermission 类
- 重写 has_permission 方法,方法中编写逻辑,实现权限认证,在这方法中,request.user就是当前登录用户
- 有权限返回 True,没有权限可以添加 message 并返回 False
- 编写好权限类可以在局部使用,也可以在全局使用
内置权限类
from rest_framework.permissions import AllowAny,IsAuthenticated,IsAdminUser,IsAuthenticatedOrReadOnly
-AllowAny 允许所有用户
-IsAdminUser 校验是不是 auth 的超级管理员权限
-IsAuthenticated 后面用,验证用户是否登录,登录后才有权限,没登录就没有权限
-IsAuthenticatedOrReadOnly 了解即可
全局使用
可以在配置文件中全局设置默认的权限管理类,如下
REST_FRAMEWORK = {
....
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.IsAuthenticated',
)
}
如果未指明,则采用如下默认配置
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny',
)
局部使用
也可以在具体的视图中通过 permission_classes 属性来设置,如下
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
class ExampleView(APIView):
permission_classes = (IsAuthenticated,)
...
频率
SimpleRateThrottle 内置频率类
from rest_framework.throttling import SimpleRateThrottle
class MyThrottle(SimpleRateThrottle):
# 该属性作为键名在 setting 配置文件中使用
scope = 'count_time'
# 重写 get_cache_key 方法,该方法返回的值会被作为限制的依据
def get_cache_key(self, request, view):
return request.META.get('REMOTE_ADDR') # 获取请求头中的IP地址
setting 配置
REST_FRAMEWORK = {
"DEFAULT_THROTTLE_RATES": {
# 频率类中scope对应的值
'count_time': '3/m', # 数字/s m h d
},
}
全局使用
REST_FRAMEWORK = {
"DEFAULT_THROTTLE_RATES": {
# 频率类中scope对应的值
'count_time': '3/m', # 数字/s m h d
},
'DEFAULT_THROTTLE_CLASSES': ['app01.auth.MyThrottle', ]
}
局部使用
在视图类中添加,同样,局部禁用只要赋值空列表即可
throttle_classes = [MyThrottle, ]
总结流程
- 继承 SimpleRateThrottle 类
- 重写 get_cache_key 方法,该方法返回的值是限制的依据,例如按照 IP 地址来作为限制条件,设置每个 IP 地址访问的次数。需要注意的是,IP 地址在
request.META
中获取'REMOTE_ADDR'
- 接下来需要配置 setting ,需要设置访问的频率。首先需要设置属性
scope
,该属性的值会作为频率的键名,在 setting 配置文件 REST_FRAMEWORK 中的 DEFAULT_THROTTLE_RATES 配置,键名是 scope,键值是字符串,格式为'x/y'
,x 表示访问的次数,y 表示访问的时间区间(可以为 s(秒)、m(份)、h(时)、d(天)) - 编写好频率类可以在局部使用,也可以在全局使用
AnonRateThrottle 内置频率类
AnonRateThrottle 内置频率类的功能:对于登录用户不限制次数,只未登录用户限制次数,限制的次数需要在配置文件中配置。使用也支持全局和局部
配置文件
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': (
'rest_framework.throttling.AnonRateThrottle',
),
'DEFAULT_THROTTLE_RATES': {
'anon': '3/m',
}
}
UserRateThrottle 内置频率类
UserRateThrottle 内置频率类的功能:限制登录用户的频率,限制的次数需要在配置文件中配置。也支持全局和局部使用
配置文件
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': (
'rest_framework.throttling.UserRateThrottle'
),
'DEFAULT_THROTTLE_RATES': {
'user': '10/m'
}
}
自定义频率类
# 自定义的逻辑
#(1)取出访问者ip
#(2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问,在字典里,继续往下走
#(3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
#(4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
#(5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
class MyThrottles():
VISIT_RECORD = {}
def __init__(self):
self.history=None
def allow_request(self,request, view):
#(1)取出访问者ip
# print(request.META)
ip=request.META.get('REMOTE_ADDR')
import time
ctime=time.time()
# (2)判断当前ip不在访问字典里,添加进去,并且直接返回True,表示第一次访问
if ip not in self.VISIT_RECORD:
self.VISIT_RECORD[ip]=[ctime,]
return True
self.history=self.VISIT_RECORD.get(ip)
# (3)循环判断当前ip的列表,有值,并且当前时间减去列表的最后一个时间大于60s,把这种数据pop掉,这样列表中只有60s以内的访问时间,
while self.history and ctime-self.history[-1]>60:
self.history.pop()
# (4)判断,当列表小于3,说明一分钟以内访问不足三次,把当前时间插入到列表第一个位置,返回True,顺利通过
# (5)当大于等于3,说明一分钟内访问超过三次,返回False验证失败
if len(self.history)<3:
self.history.insert(0,ctime)
return True
else:
return False
def wait(self):
import time
ctime=time.time()
return 60-(ctime-self.history[-1])
其他
-
AnonRateThrottle
限制所有匿名未认证用户,使用IP
区分用户。
使用 DEFAULT_THROTTLE_RATES[‘anon’] 来设置频次 -
UserRateThrottle
限制认证用户,使用User id
来区分。
使用 DEFAULT_THROTTLE_RATES[‘user’] 来设置频次 -
ScopedRateThrottle
限制用户对于每个视图的访问频次,使用 ip 或 user id
认证源码分析
认证顺序源码
三大认证的顺序是:登录 ----> 权限 ----> 频率 。
已知的是三大认证是在 APIView 中封装,其源码如下
dispatch
def dispatch(self, request, *args, **kwargs):
...
try:
# 三大认证
self.initial(request, *args, **kwargs)
...
except Exception as exc:
...
...
initial
def initial(self, request, *args, **kwargs):
...
# 认证
self.perform_authentication(request)
# 权限
self.check_permissions(request)
# 频率
self.check_throttles(request)
在代码中由上而下执行,因此有了三大认证的顺序
局部认证分析
局部使用三大认证是在视图类中使用,例如 authentication_classes、permission_classes、throttle_classes,要想知道如何在视图类中配置就可以进行认证的可以分析其源码。
- self.perform_authentication(request)
def perform_authentication(self, request):
request.user
- 其返回的值是 request.user,这里的 request 是 APIView 重新分装过的新 request
- 从
rest_framework.request Request
中查看 Request 源码,发现其有一个名为 user 的方法,如下所示
- user
class Request:
...
@property
def user(self):
if not hasattr(self, '_user'):
with wrap_attributeerrors():
self._authenticate()
return self._user
- 该方法使用了 property 装饰器,将 user 方法装饰成一个属性。
- 方法中进行了 if 判断,由于开始并没有 _user,所以会执行
_authenticate
方法 - 方法中的 self 表示的是 Request 的对象,执行的是 Request 的 _authenticate 方法
- _authenticate
class Request:
def _authenticate(self):
for authenticator in self.authenticators:
try:
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
if user_auth_tuple is not None:
self._authenticator = authenticator
self.user, self.auth = user_auth_tuple
return
self._not_authenticated()
for authenticator in self.authenticators
语句中的 self.authenticators 是我们在视图类上添加的认证类,但是是以: [认证类(), 认证类2()] 的形式存在。所以 authenticator 表示的是 认证类()user_auth_tuple = authenticator.authenticate(self)
调用认证类中的 authenticate 方法,返回的元组(一个是登录用户、一个是 token)用 user_auth_tuple 接收。- 调用 authenticate 方法了但是需要注意的是总共传了俩个参数,该方法时绑定给对象的,会自动传入认证类本身,还有一个是 self,也就是 Request类的对象
except exceptions.APIException:
捕获的是 APIException,我们抛出的是 AuthenticationFailed,但是由于其继承了 APIException,相当于也捕获了。self.user, self.auth = user_auth_tuple
如果返回了两个值,第一个值给了 request.user ,第二个值给了 request.auth。因此认证过后 request.user 会有当前用户。- 留了一个点是为什么 self.authenticators 就是我们写的认证类并且还加了括号。
- _init_
class Request:
def __init__(self, request, parsers=None, authenticators=None,
negotiator=None, parser_context=None):
...
self.authenticators = authenticators or ()
- self.authenticators 是 Reqeust 在类实例化的时候传入的,如果不传就是空元组
- Request 的实例化也就是 dispatch 方法中包装了新的 Request
- dispatch
request = self.initialize_request(request, *args, **kwargs)
- initialize_request
class APIView(View):
...
def initialize_request(self, request, *args, **kwargs):
"""
Returns the initial request object.
"""
parser_context = self.get_parser_context(request)
return Request(
request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(),
parser_context=parser_context
)
authenticators=self.get_authenticators()
由 get_authenticators() 赋值
- get_authenticators()
class APIView(View):
...
def get_permissions(self):
return [permission() for permission in self.permission_classes]
- 在这里通过列表生成式,从 self.permission_classes 也就是我们在类中编写的认证类列表(查找顺序)中获取认证类并加括号调用。
权限源码分析
同样的查看在 dispatch 中 initial 方法里的 check_permissions 源码
check_permissions
class APIView(View):
...
def check_permissions(self, request):
for permission in self.get_permissions():
if not permission.has_permission(request, self):
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
for permission in self.get_permissions():
permission 表示的是权限类对象if not permission.has_permission(request, self):
判断 has_permission 方法返回的是 True 还是 False,如果为 False 则执行 permission_denied 方法报异常message=getattr(permission, 'message', None),
获取在视图类中编写的 message
get_permissions
def get_permissions(self):
return [permission() for permission in self.permission_classes]
permission_denied
def permission_denied(self, request, message=None, code=None):
if request.authenticators and not request.successful_authenticator:
raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied(detail=message, code=code)
频率源码分析
查看 check_throttles 源码
check_throttles
def check_throttles(self, request):
throttle_durations = []
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
if throttle_durations:
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
self.throttled(request, duration)
for throttle in self.get_throttles():
throttle 表示的是频率类对象,get_throttles() 表示的是频率类对象列表。if not throttle.allow_request(request, self):
获取对象的 allow_request 方法,返回 True 就是没有频率限制住,返回 False 就是被频率限制了
我们可以查看 SimpleRateThrottle 的 allow_request 方法
allow_request
def allow_request(self, request, view):
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
if self.rate is None:
对 rate 进行判断,rate 属性在其双下 _init_ 方法中实现
_init_
def __init__(self):
if not getattr(self, 'rate', None):
# 没有该属性调用 get_rate
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
- 查看 get_rate 源码
get_rate
def get_rate(self):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
- 在 get_rate 方法中获取 scope 若是没有写则直接抛出异常。
- 若有则去 THROTTLE_RATES 中取出 scope,THROTTLE_RATES 源码如下
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
- 也就是说 rate 的值就是 scope 对应的值,类似于 5/m 表示频率。
parse_rate
def parse_rate(self, rate):
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
num, period = rate.split('/')
对 rate 进行切割,这里我们假设频率为 5/m,那么切割后的 num 为 5、period 为 mduration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
[period[0]] 只取出了第一个值然后去字典中取值,也就是说,定义 scope 的时候不一定要写 m,只要以 m 开头即可。- 最后将 (num_requests, duration) 返回出去,也就是说,赋值给了 _init_ 中的 self.num_requests, self.duration
回到 allow_request 方法
allow_request
def allow_request(self, request, view):
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
self.key = self.get_cache_key(request, view)
调用 get_cache_key 方法,该方法需要自己重写的,用于指定判断的依据。self.history = self.cache.get(self.key, [])
从缓存中拿数据,取出的数据是时间的列表,类似于 [时间2, 时间1],没有则赋值空列表self.now = self.timer()
timer() 是类属性,加括号调用了,获取时间。源码如下所示
timer = time.time
self.history.pop()
把所有超过时间的数据都剔除,self.history 只剩限定时间内的访问时间if len(self.history) >= self.num_requests:
大于等于配置的次数执行 throttle_failure 返回 False,否则执行 throttle_success 把当前时间插入,并返回 True