三大认证源码分析
drf的APIView在执行视图类的方法之前在dispatch中执行了三大认证
self.initial(request, *args, **kwargs)
initial的源码如下:
def initial(self, request, *args, **kwargs):
self.format_kwarg = self.get_format_suffix(**kwargs)
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme
self.perform_authentication(request)
self.check_permissions(request)
self.check_throttles(request)
认证类执行源码分析
认证组件源码:
def perform_authentication(self, request):
request.user
Request类的user方法:
@property
def user(self):
if not hasattr(self, '_user'):
with wrap_attributeerrors():
self._authenticate()
return self._user
def _authenticate(self):
for authenticator in self.authenticators:
try:
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
if user_auth_tuple is not None:
self._authenticator = authenticator
self.user, self.auth = user_auth_tuple
return
self._not_authenticated()
总结:
1.配置在视图类上的认证类会在执行视图类方法之前执行,且在权限认证之前执行
2.自己写的认证类可以返回两个值或者返回一个None
3.后续可以从request.user中取出当前登录用户(前提是要在认证类中返回)
权限类执行源码分析
权限组件源码:
def check_permissions(self, request):
for permission in self.get_permissions():
if not permission.has_permission(request, self):
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
get_permissions源码:
def get_permissions(self):
return [permission() for permission in self.permission_classes]
总结:
1.权限组件的执行就是取出配置在视图类上的权限类,实例化得到对象,然后一个一个执行对象的has_perminnion方法,如果返回False,就直接结束,不再往下继续执行,权限认证不通过
2.如果视图类上没配置权限类,那么会使用配置文件中的配置
3.优先使用项目配置文件,其次使用drf内置配置文件
4.配置在视图类上的权限类会在认证类之后,频率类之前执行
频率组件源码分析
频率组件源码:
def check_throttles(self, request):
throttle_durations = []
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
if throttle_durations:
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
self.throttled(request, duration)
总结:
我们自己写的频率类继承了BaseThrottle,重写了allow_request,在内部判断如果超频就返回False,如果没超频就返回True
自定义频率类
from rest_framework.throttling import BaseThrottle
import time
class MyThrottle(BaseThrottle):
USER_THROTTLE = {}
def __init__(self):
self.history = None
def allow_request(self, request, view):
ip = request.META.get('REMOTE_ADDR')
ctime = time.time()
if ip not in self.USER_THROTTLE:
self.USER_THROTTLE[ip] = [ctime]
return True
self.history = self.USER_THROTTLE.get(ip)
while self.history and ctime - self.history[-1] > 60:
self.history.pop()
if len(self.history) < 3:
self.history.insert(0, ctime)
return True
else:
return False
def wait(self):
import time
ctime = time.time()
return 60 - (ctime - self.history[-1])
SinpleRateThrottle
SimpleRateThrottle的allow_request方法
def allow_request(self, request, view):
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
SimpleRateThrottle的init方法
def __init__(self):
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
SimpleRateThrottle的get_rate方法
def get_rate(self):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
SimpleRateThrottle的parse_rate方法
def parse_rate(self, rate):
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
基于APIView写分页
class BookView(ViewSetMixin, APIView):
def list(self, request):
books = Book.objects.all()
paginator = CommonLimitOffsetPagination()
page = paginator.paginate_queryset(books, request, self)
if page is not None:
serializer = BookSerializer(instance=page, many=True)
return Response({
'total': paginator.count,
'next': paginator.get_next_link(),
'previous': paginator.get_previous_link(),
'results': serializer.data
})
异常处理
APIView的dispatch执行三大认证和视图类的方法时,如果出了异常会被异常捕获,之后做统一处理
drf中就内置了一个函数做上面的操作,但是这个函数只处理drf的异常,程序出错或者主动抛的异常都不会被处理,如果需要处理的话就需要重新写一个函数
def common_exception_handler(exc, context):
request = context.get('request')
id = request.user
ip = request.META.get('REMOTE_ADDR')
print(f'{datetime.now()},{id},{ip},{request.method},{request.path},{context.get("view").__class__.__name__},{str(exc)}')
res = exception_handler(exc, context)
if res:
res = Response(data={'code': 888, 'msg': res.data.get('detail', '请联系系统管理员')})
else:
res = Response(data={'code': 999, 'msg': '系统错误,请联系系统管理员'})
return res
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)