DRF之排序类源码分析
【一】排序类介绍
- 在Django REST framework (DRF)中,排序类用于处理API端点的排序操作,允许客户端请求按特定字段对数据进行升序或降序排序。
- 排序类本质上是一种特殊的过滤类:它继承自 `BaseFilterBackend`,同样通过视图的 `filter_backends` 启用,并在 `filter_queryset` 中对查询集进行排序(而不是筛选)
- DRF提供了内置的排序类,并且你也可以自定义排序类以满足特定的需求。
【二】内置排序类OrderingFilter
`rest_framework.filters.OrderingFilter`:这是DRF内置的排序类(需在视图的 `filter_backends` 或全局配置中显式启用)。
- 它允许客户端在API请求中使用 `?ordering=` 参数来指定要排序的字段。
- 例如,`?ordering=-created_at` 将按 `created_at` 字段降序排序;多个字段用逗号分隔,如 `?ordering=field1,-field2`。
【1】使用
from rest_framework.filters import OrderingFilter
from rest_framework.generics import ListAPIView


class MyModelListView(ListAPIView):
    """List endpoint whose results can be sorted with ?ordering=field1,-field2."""

    # MyModel / MyModelSerializer are placeholders for your own model and serializer.
    queryset = MyModel.objects.all()
    serializer_class = MyModelSerializer
    # Enable the built-in ordering backend for this view only.
    filter_backends = [OrderingFilter]
    # Only these fields may appear in ?ordering=; any other names are dropped
    # by OrderingFilter.remove_invalid_fields().
    ordering_fields = ['field1', 'field2']
- 执行流程
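- 当一个API请求到达时,Django REST framework 会先执行视图的 `get_queryset` 方法获取查询集,再把查询集依次交给 `filter_backends` 中的每个过滤类处理。
- 如果使用了 `OrderingFilter` 排序类,它会检查请求中是否包含 `?ordering=` 参数。
- 如果请求中包含 `?ordering=` 参数,`OrderingFilter` 会先剔除不允许排序的字段(不在 `ordering_fields` 中的字段;未设置 `ordering_fields` 时,以序列化器中的可读字段为准),再按剩余字段对查询集排序;如果参数缺失或没有任何有效字段,则退回到视图的 `ordering` 属性(默认排序)。
- 排序后的查询集将传递给视图进行进一步处理和返回。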
【2】源码分析
class OrderingFilter(BaseFilterBackend):
    """Filter backend that orders a queryset from an ``?ordering=`` query parameter.

    Clients pass a comma-separated list of field names; a leading ``-``
    requests descending order. Requested names are checked against the
    allowed fields (``get_valid_fields``); when none survive, the view's
    ``ordering`` attribute is used instead.
    """

    # Query parameter name; taken from the ORDERING_PARAM API setting.
    ordering_param = api_settings.ORDERING_PARAM
    # Fields clients may order by. None -> derive from the serializer,
    # '__all__' -> every concrete model field plus queryset annotations.
    ordering_fields = None
    # Title/description used in the generated schema.
    ordering_title = _('Ordering')
    ordering_description = _('Which field to use when ordering the results.')
    # Template rendered by to_html() for the browsable API control.
    template = 'rest_framework/filters/ordering.html'

    def get_ordering(self, request, queryset, view):
        """
        Ordering is set by a comma delimited ?ordering=... query parameter.

        The `ordering` query parameter can be overridden by setting
        the `ordering_param` value on the OrderingFilter or by
        specifying an `ORDERING_PARAM` value in the API settings.

        Returns the validated list of ordering terms, or the view's default
        ordering when the parameter is absent or every term is invalid.
        """
        params = request.query_params.get(self.ordering_param)
        if params:
            fields = [param.strip() for param in params.split(',')]
            ordering = self.remove_invalid_fields(queryset, fields, view, request)
            if ordering:
                return ordering

        # No parameter, or all requested fields were rejected.
        return self.get_default_ordering(view)

    def get_default_ordering(self, view):
        """Return the view's ``ordering`` attribute.

        A bare string is wrapped in a one-element tuple so callers can always
        unpack the result into ``order_by(*ordering)``. Returns None when the
        view defines no ``ordering``.
        """
        ordering = getattr(view, 'ordering', None)
        if isinstance(ordering, str):
            return (ordering,)
        return ordering

    def get_default_valid_fields(self, queryset, view, context=None):
        """Derive the orderable fields from the view's serializer.

        Used when ``ordering_fields`` is not configured. Returns
        ``(ordering_key, label)`` pairs for every serializer field that is
        not write-only, not sourced from ``'*'`` and not backed by a Python
        ``property`` on the model (those cannot be used in ``order_by``).

        :param context: serializer context; ``None`` is treated as ``{}``.
        :raises ImproperlyConfigured: if no serializer class can be found.
        """
        if context is None:
            context = {}

        if hasattr(view, 'get_serializer_class'):
            try:
                serializer_class = view.get_serializer_class()
            except AssertionError:
                # Presumably raised when the view has no serializer_class
                # configured; treat it the same as "no serializer".
                serializer_class = None
        else:
            serializer_class = getattr(view, 'serializer_class', None)

        if serializer_class is None:
            msg = (
                "Cannot use %s on a view which does not have either a "
                "'serializer_class', an overriding 'get_serializer_class' "
                "or 'ordering_fields' attribute."
            )
            raise ImproperlyConfigured(msg % self.__class__.__name__)

        model_class = queryset.model
        # Names of Python properties on the model ('pk' is kept, since the ORM
        # accepts it as an alias for the primary key).
        model_property_names = {
            attr for attr in dir(model_class)
            if isinstance(getattr(model_class, attr), property) and attr != 'pk'
        }

        return [
            # Dotted sources ('author.name') become ORM lookups ('author__name');
            # an empty source falls back to the serializer field name.
            (field.source.replace('.', '__') or field_name, field.label)
            for field_name, field in serializer_class(context=context).fields.items()
            if (
                not getattr(field, 'write_only', False)
                and field.source != '*'
                and field.source not in model_property_names
            )
        ]

    def get_valid_fields(self, queryset, view, context=None):
        """Return the ``(ordering_key, label)`` pairs clients may order by.

        The view's ``ordering_fields`` takes precedence over this filter's.
        ``None`` derives the fields from the serializer; ``'__all__'`` allows
        every concrete model field plus queryset annotations; otherwise each
        item is either a field name or an explicit ``(name, label)`` pair.
        """
        if context is None:
            context = {}

        valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)

        if valid_fields is None:
            # Nothing configured: fall back to the serializer's readable fields.
            return self.get_default_valid_fields(queryset, view, context)

        if valid_fields == '__all__':
            # View explicitly allows ordering on any model field.
            valid_fields = [
                (field.name, field.verbose_name) for field in queryset.model._meta.fields
            ]
            # Annotations get a human-readable string label,
            # e.g. 'author__name' -> 'Author Name'.
            valid_fields += [
                (key, key.replace('__', ' ').title())
                for key in queryset.query.annotations
            ]
            return valid_fields

        # Plain strings use the field name as their own label.
        return [
            (item, item) if isinstance(item, str) else item
            for item in valid_fields
        ]

    def remove_invalid_fields(self, queryset, fields, view, request):
        """Keep only the requested terms that refer to an orderable field.

        A leading '-' is ignored for the check but preserved in the output;
        the original order of ``fields`` is kept.
        """
        valid_fields = {
            item[0] for item in self.get_valid_fields(queryset, view, {'request': request})
        }

        def term_valid(term):
            if term.startswith("-"):
                term = term[1:]
            return term in valid_fields

        return [term for term in fields if term_valid(term)]

    def filter_queryset(self, request, queryset, view):
        """Apply the resolved ordering to ``queryset`` (unchanged if there is none)."""
        ordering = self.get_ordering(request, queryset, view)
        if ordering:
            return queryset.order_by(*ordering)
        return queryset

    def get_template_context(self, request, queryset, view):
        """Build the context for the browsable-API ordering control.

        ``current`` is the first active ordering term (or None); ``options``
        lists an ascending and a descending choice for every valid field.
        """
        current = self.get_ordering(request, queryset, view)
        current = None if not current else current[0]
        options = []
        context = {
            'request': request,
            'current': current,
            'param': self.ordering_param,
        }
        for key, label in self.get_valid_fields(queryset, view, context):
            options.append((key, '%s - %s' % (label, _('ascending'))))
            options.append(('-' + key, '%s - %s' % (label, _('descending'))))
        context['options'] = options
        return context

    def to_html(self, request, queryset, view):
        """Render ``self.template`` with the ordering template context."""
        template = loader.get_template(self.template)
        context = self.get_template_context(request, queryset, view)
        return template.render(context)

    def get_schema_fields(self, view):
        """Describe the ordering query parameter for the coreapi schema generator."""
        assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
        assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
        return [
            coreapi.Field(
                name=self.ordering_param,
                required=False,
                location='query',
                schema=coreschema.String(
                    title=force_str(self.ordering_title),
                    description=force_str(self.ordering_description)
                )
            )
        ]

    def get_schema_operation_parameters(self, view):
        """Describe the ordering query parameter as an OpenAPI parameter object."""
        return [
            {
                'name': self.ordering_param,
                'required': False,
                'in': 'query',
                'description': force_str(self.ordering_description),
                'schema': {
                    'type': 'string',
                },
            },
        ]
【三】自定义排序类
【1】使用
from rest_framework.filters import OrderingFilter
from rest_framework.generics import ListAPIView


class CustomOrderingFilter(OrderingFilter):
    """OrderingFilter subclass showing where custom ordering logic plugs in.

    The raw query parameter is never handed to ``order_by`` directly: it is
    split on commas and passed through ``remove_invalid_fields`` so clients
    can only sort by the fields the view allows. Passing it unchecked would
    let clients order by any (even non-exposed) model field, and an unknown
    name would make the database query fail.
    """

    def get_ordering(self, request, queryset, view):
        # Respect the configured parameter name instead of hard-coding 'ordering'.
        params = request.query_params.get(self.ordering_param)
        if params:
            fields = [param.strip() for param in params.split(',')]
            # Custom transformations of `fields` would go here.
            ordering = self.remove_invalid_fields(queryset, fields, view, request)
            if ordering:
                return ordering
        # Nothing usable was requested: defer to the default behaviour
        # (which ends up using the view's `ordering` attribute).
        return super().get_ordering(request, queryset, view)


class MyModelListView(ListAPIView):
    """List endpoint ordered through CustomOrderingFilter."""

    # MyModel / MyModelSerializer are placeholders for your own model and serializer.
    queryset = MyModel.objects.all()
    serializer_class = MyModelSerializer
    filter_backends = [CustomOrderingFilter]
    # Whitelist of fields clients may order by.
    ordering_fields = ['field1', 'field2']
【2】分析
- 继承 OrderingFilter
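- 重写 `get_ordering` 方法
- 在其中实现自定义的排序逻辑;客户端传入的字段应先经过 `remove_invalid_fields` 校验,否则客户端可以按任意(包括未公开的)字段排序,非法字段名还会导致数据库查询报错
- `get_ordering` 返回排序字段列表,由 `filter_queryset` 调用 `order_by` 完成排序,再将排序后的查询集交给视图进一步处理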
本文作者:ssrheart
本文链接:https://www.cnblogs.com/ssrheart/p/18153599
版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 2.5 中国大陆许可协议进行许可。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步