Django-restframework源码分析笔记
在 APIview 类中的属性有一条是:
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
定义了一个类属性为authentication_classes,值从api_settings中的一个属性,查看api_settings发现也是个类对象,
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
是有 APISettings 实例化出来的一个对象,传入了三个参数,重点为后面两个参数
-
DEFAULTS是一个字典,其中有一个 key为DEFAULT_AUTHENTICATION_CLASSES,对应的值为一个元祖:
(
'rest_framework.authentication.SessionAuthentication',
'rest_framework.authentication.BasicAuthentication'
) -
IMPORT_STRINGS 为一个元祖,里面也包含有DEFAULT_AUTHENTICATION_CLASSES,这个元祖是一些列的设置参数通过字符串的形式导入(可能采用字符串导入表示法的设置列表。)
将这三个参数传到 APISettings 中进行初始化:
def init(self, user_settings=None, defaults=None, import_strings=None):
if user_settings:
self._user_settings = self.__check_user_settings(user_settings)
self.defaults = defaults or DEFAULTS
self.import_strings = import_strings or IMPORT_STRINGS
self._cached_attrs = set()
将对象赋值给 api_settings,之后通过该对象获取相关属性。
那么在 APIView 中是如何获取到 api_settings.DEFAULT_AUTHENTICATION_CLASS的呢?
去 APISettings 类中查找该属性,发现并没有该属性,但是具有__getattr__方法,该方法接收一个属性名,返回该属性的值,具体如下:
def getattr(self, attr):
if attr not in self.defaults:
raise AttributeError("Invalid API setting: '%s'" % attr)
try:
# Check if present in user settings
val = self.user_settings[attr]
except KeyError:
# Fall back to defaults
val = self.defaults[attr]
# Coerce import strings into classes
if attr in self.import_strings:
val = perform_import(val, attr)
# Cache the result
self._cached_attrs.add(attr)
setattr(self, attr, val)
return val
在 try 中执行self.user_settings[attr],是个类方法,使用property装饰器装饰为属性:
@property
def user_settings(self):
if not hasattr(self, '_user_settings'):
self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
return self._user_settings
该方法里面经过反射判断, self中是否含有_user_settings,因为在实例化该对象时,传入的三个参数第一个为 None,所以在 init 方法中self 没有该属性,所以执行self._user_settings = getattr(settings, 'REST_FRAMWORK', {}),该 settings 是个模块,该模块在执行项目的时候会加载全局配置和自定义配置,而在自定义配置中的INSTALLED_APPS中注册了rest_framework,所以可以返回需要寻找的值,而在下面的 if attr in self.import_strings:,我前面特意说了第三个参数里面也有DEFAULT_AUTHENTICATION_CLASSES,所以这个判断也成立,执行里面的val = perform_import(val, attr)方法,方法如下:
def perform_import(val, setting_name):
"""
If the given setting is a string import notation,
then perform the necessary import or imports.
"""
if val is None:
return None
elif isinstance(val, six.string_types):
return import_from_string(val, setting_name)
elif isinstance(val, (list, tuple)):
return [import_from_string(item, setting_name) for item in val]
return val
该方法把之前获取的 val 当做参数传进去,在里面进行 val 的类型判断,该 val 是一个列表,于是执行下面的方法,列表生成式生成一个列表,之后把结果进行缓存(把该 attr 加到一个集合中),之后给 api_settings对象设置一个 key 为 attr,value 为 val 的属性,并将 val 返回。这个值根据后续的执行判断是一个列表,列表里面有两个类的内存地址,分别是[SessionAuthentication, BasicAuthentication],
在return [import_from_string(item, setting_name) for item in val]该列表生成式中,item 是上面列表中的类,setting_name是DEFAULT_AUTHENTICATION_CLASSES,
最终结果就是将这两个类导入到内存中,也就是[SessionAuthentication, BasicAuthentication],上面那个列表里面是类似于['rest_framework.authentication.SessionAuthentication', 'rest_framework.authentication.BasicAuthentication']这样的。
最终得到的结果为:
authentication_classes = [SessionAuthentication, BasicAuthentication]。
之后再 APIView 中的 dispatch 方法中执行到 self.initial时,把经过 Request 类封装产生的对象和相关参数传进去,
initial 具体方法:
def initial(self, request, *args, kwargs):
"""
Runs anything that needs to occur prior to calling the method handler.
"""
self.format_kwarg = self.get_format_suffix(kwargs)
# Perform content negotiation and store the accepted info on the request
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
# Determine the API version, if versioning is in use.
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme
# Ensure that the incoming request is permitted
self.perform_authentication(request)
self.check_permissions(request)
self.check_throttles(request)
这个方法可以说是 restframework中最重要的方法了,因为这里面有认证、权限以及频率的校验。
- 先去self.perform_authentication方法中看。
该方法将Request 对象当做参数传进去,在该方法里面只有一行代码:
request.user,这个方法是 Request 对象的方法,于是去 Request 类中寻找。
request.user是经过 property 装饰器装饰成的数据属性,
在该方法里面的代码如下:
@property
def user(self):
"""
Returns the user associated with the current request, as authenticated
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
with wrap_attributeerrors():
self._authenticate()
return self._user
首先判断 self, 也就是 request.user的 request 是否有_user这个属性,在 Request 类中的初始化方法中寻找,代码如下:
def init(self, request, parsers=None, authenticators=None,
negotiator=None, parser_context=None):
assert isinstance(request, HttpRequest), (
'The request
argument must be an instance of '
'django.http.HttpRequest
, not {}.{}
.'
.format(request.class.module, request.class.name)
)
self._request = request
self.parsers = parsers or ()
self.authenticators = authenticators or ()
self.negotiator = negotiator or self._default_negotiator()
self.parser_context = parser_context
self._data = Empty
self._files = Empty
self._full_data = Empty
self._content_type = Empty
self._stream = Empty
if self.parser_context is None:
self.parser_context = {}
self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
force_user = getattr(request, '_force_auth_user', None)
force_token = getattr(request, '_force_auth_token', None)
if force_user is not None or force_token is not None:
forced_auth = ForcedAuthentication(force_user, force_token)
self.authenticators = (forced_auth,)
该对象中没有_user属性,进入到_authenticate方法中,
def _authenticate(self):
"""
Attempt to authenticate the request using each authentication instance
in turn.
"""
for authenticator in self.authenticators:
try:
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
if user_auth_tuple is not None:
self._authenticator = authenticator
self.user, self.auth = user_auth_tuple
return
self._not_authenticated()
在该方法中,之前花了大力气找到的authentication_classes = [SessionAuthentication, BasicAuthentication]可以派上用场了,
for authenticator in self.authenticators:
self.authenticators是 request 对象的一个属性,在初始化方法中,
self.authenticators = authenticators or {},这个 authenticators 是实例化 Request 对象时传进来的参数,回到APIView 中的 dispatch 方法里面的 self.initialize_request方法,代码如下:
def initialize_request(self, request, *args, **kwargs):
"""
Returns the initial request object.
"""
parser_context = self.get_parser_context(request)
return Request(
request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(),
parser_context=parser_context
)
在实例化 Request 对象时,传进了五个参数,一个是原生 request 对象,重点关注authenticators = self.get_authenticators(),点进去查看源码:
def get_authenticators(self):
"""
Instantiates and returns the list of authenticators that this view can use.
"""
return [auth() for auth in self.authentication_classes]
这个方法里面是个列表生成式,里面有个self.authentication_classes,这是 APIView 类中的一个方法,查看源码,这个self.authentication_classes就是一个类属性,就是authentication_classes,所以获取到这个属性对应的值为[SessionAuthentication, BasicAuthentication],
循环该列表,因为该列表中的对象是类,所以加括号就是实例化产生对应的对象。之后再_authenticate方法中循环该列表,就是一个个类对象并调用类对象的authenticate(self)方法,将 Request 对象作为参数传进去,
现在找具体生成的类对象,用SessionAuthentication
authenticate 方法如下:
def authenticate(self, request):
"""
Returns a User
if the request session currently has a logged in user.
Otherwise returns None
.
"""
# Get the session-based user from the underlying HttpRequest object
user = getattr(request._request, 'user', None)
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
return None
self.enforce_csrf(request)
# CSRF passed with authenticated user
return (user, None)
该方法的参数就是经过 Request 实例化产生的 request,该方法就是具体的认证方法,
当 request._request.user不为空是,执行self.enforce_csrf方法,具体代码如下:
def enforce_csrf(self, request):
"""
Enforce CSRF validation for session based authentication.
"""
check = CSRFCheck()
# populates request.META['CSRF_COOKIE'], which is used in process_view()
check.process_request(request)
reason = check.process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
在该方法里面实例化了一个 CSRFCheck对象,mro 列表查到,在它的父父类中找到 init 方法,然后调用check.process_request方法,这个方法很眼熟,就是出中间件request的方法,具体代码如下(也是通过 mro 列表查到的):
def process_request(self, request):
csrf_token = self._get_token(request)
# print(csrf_token, '*****')
if csrf_token is not None:
# Use same token next time.
request.META['CSRF_COOKIE'] = csrf_token
该方法是处理 csrf_token的,首先执行self._get_token(request)方法,获取 token,具体代码如下:
def _get_token(self, request):
if settings.CSRF_USE_SESSIONS:
try:
return request.session.get(CSRF_SESSION_KEY)
except AttributeError:
raise ImproperlyConfigured(
'CSRF_USE_SESSIONS is enabled, but request.session is not '
'set. SessionMiddleware must appear before CsrfViewMiddleware '
'in MIDDLEWARE%s.' % ('_CLASSES' if settings.MIDDLEWARE is None else '')
)
else:
try:
cookie_token = request.COOKIES[settings.CSRF_COOKIE_NAME]
except KeyError:
return None
csrf_token = _sanitize_token(cookie_token)
if csrf_token != cookie_token:
# Cookie token needed to be replaced;
# the cookie needs to be reset.
request.csrf_cookie_needs_reset = True
return csrf_token
首先进行 if settings.CSRF_USE_SESSIONS判断,settings是由 LazySettings 类产生的一个对象,查找 init 方法,在父类LazyObject中找到 init 方法,代码如下:
def init(self):
# Note: if a subclass overrides init(), it will likely need to
# override copy() and deepcopy() as well.
self._wrapped = empty
if settings.CSRF_USE_TOKEN,因为该类中以及父类中没有该属性,但是在 LazySettings中有__getattr__方法,于是会调用该方法来获取属性值,代码如下:
def getattr(self, name):
"""
Return the value of a setting and cache it in self.dict.
"""
if self._wrapped is empty:
self._setup(name)
val = getattr(self._wrapped, name)
self.dict[name] = val
return val
该函数里面会判断self._wrapped是否为空对象,因为在父类 LazyObject 中的 init 方法中设置了self._wrapped = empty,所以判断为 True,会执行self._setup(name),该 name 是要查找的属性名,在这里也就是 CSRF_USE_SESSIONS,_setup代码如下:
def _setup(self, name=None):
"""
Load the settings module pointed to by the environment variable. This
is used the first time we need any settings at all, if the user has not
previously configured the settings manually.
"""
settings_module = os.environ.get(ENVIRONMENT_VARIABLE)
if not settings_module:
desc = ("setting %s" % name) if name else "settings"
raise ImproperlyConfigured(
"Requested %s, but settings are not configured. "
"You must either define the environment variable %s "
"or call settings.configure() before accessing settings."
% (desc, ENVIRONMENT_VARIABLE))
self._wrapped = Settings(settings_module)
settings_module = os.environ.get(ENVIRONMENT_VARIABLE)是为了获取相关配置模块,然后生成 Settings(settings_module)对象赋值给 self._wrapped,所以 self._wrapped是一个 Settings 对象,所以在执行val = getattr(self._wrapped, name)时就是在该 Settings 中找到属性名为CSRF_USE_TOKEN的值,在 django.conf.settings的全局配置中该值为 False,所以会走下面的 else 代码块。具体如下(_get_token方法):
else:
try:
cookie_token = request.COOKIES[settings.CSRF_COOKIE_NAME]
except KeyError:
return None
csrf_token = _sanitize_token(cookie_token)
if csrf_token != cookie_token:
# Cookie token needed to be replaced;
# the cookie needs to be reset.
request.csrf_cookie_needs_reset = True
return csrf_token
在 try 中会通过封装的 Request 对象来获取 cookie,在全局配置中,CSRF_COOKIE_NAME的值为
csrftoken,会把这个值当做 key 来查找值并赋值给cookie_token,之后执行该函数,代码如下:
def _sanitize_token(token):
# Allow only ASCII alphanumerics
if re.search('[^a-zA-Z0-9]', force_text(token)):
return _get_new_csrf_token()
elif len(token) == CSRF_TOKEN_LENGTH:
return token
elif len(token) == CSRF_SECRET_LENGTH:
# Older Django versions set cookies to values of CSRF_SECRET_LENGTH
# alphanumeric characters. For backwards compatibility, accept
# such values as unsalted secrets.
# It's easier to salt here and be consistent later, rather than add
# different code paths in the checks, although that might be a tad more
# efficient.
return _salt_cipher_secret(token)
return _get_new_csrf_token()