[Django] 22 - DRF: Add Tag API
本节添加 Tag(标签)API,也为下一步给 Ingredient 添加类似的 API 做准备。
Model
- test_models.py
def test_create_tag(self):
    """Creating a Tag succeeds and its string form is its name."""
    owner = create_user()
    new_tag = models.Tag.objects.create(user=owner, name='Tag1')

    self.assertEqual(str(new_tag), new_tag.name)
- models.py
# 90-1: define the Tag table.
class Tag(models.Model):
    """A user-owned label used to filter recipes."""

    name = models.CharField(max_length=255)
    # Deleting the owning user deletes their tags as well (CASCADE).
    user = models.ForeignKey(
        to=settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
    )

    def __str__(self):
        """Represent the tag by its name."""
        return self.name
Update Recipe: add a many-to-many `tags` field that links recipes to Tag.
class Recipe(models.Model):
    """A recipe owned by a single user."""

    # Deleting the owning user deletes their recipes as well (CASCADE).
    user = models.ForeignKey(
        to=settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
    )
    title = models.CharField(max_length=255)
    description = models.TextField(blank=True)
    time_minutes = models.IntegerField()
    price = models.DecimalField(max_digits=5, decimal_places=2)
    link = models.CharField(max_length=255, blank=True)
    # 90-2: many-to-many link to Tag; remember to register Tag in the admin.
    tags = models.ManyToManyField(to='Tag')

    def __str__(self):
        """Represent the recipe by its title."""
        return self.title
View
- test_tags_api.py
TAGS_URL = reverse('recipe:tag-list')


def detail_url(tag_id):
    """Return the detail URL for the tag whose primary key is *tag_id*."""
    return reverse('recipe:tag-detail', args=(tag_id,))


def create_user(email='user@example.com', password='testpass123'):
    """Create and return a user with the given credentials."""
    user_model = get_user_model()
    return user_model.objects.create_user(email=email, password=password)
class PublicTagsApiTests(TestCase):
    """Tag API behaviour for clients that are not authenticated."""

    def setUp(self):
        # A fresh API client with no credentials attached.
        self.client = APIClient()

    def test_auth_required(self):
        """Listing tags without credentials is rejected with 401."""
        response = self.client.get(TAGS_URL)

        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
本教程中访问 Tag API 需要先注册并登录(认证);未认证的请求应返回 401,上面的测试验证的就是这一点。
class PrivateTagsApiTests(TestCase):
    """Tag API behaviour for an authenticated user."""

    def setUp(self):
        self.user = create_user()
        self.client = APIClient()
        # Every request from this client is made as self.user.
        self.client.force_authenticate(self.user)

    def test_retrieve_tags(self):
        """The list endpoint returns tags ordered by name, descending."""
        for tag_name in ('Vegan', 'Dessert'):
            Tag.objects.create(user=self.user, name=tag_name)

        response = self.client.get(TAGS_URL)

        expected = TagSerializer(Tag.objects.all().order_by('-name'), many=True)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data, expected.data)

    def test_tags_limited_to_user(self):
        """Only tags owned by the authenticated user are listed."""
        other_user = create_user(email='user2@example.com')
        Tag.objects.create(user=other_user, name='Fruity')
        own_tag = Tag.objects.create(user=self.user, name='Comfort Food')

        response = self.client.get(TAGS_URL)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data), 1)
        first = response.data[0]
        self.assertEqual(first['name'], own_tag.name)
        self.assertEqual(first['id'], own_tag.id)

    # --- detail endpoint (update / delete) --------------------------------

    def test_update_tag(self):
        """PATCHing a tag changes its name."""
        tag = Tag.objects.create(user=self.user, name='After Dinner')
        payload = {'name': 'Dessert'}

        response = self.client.patch(detail_url(tag.id), payload)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # The in-memory instance is stale after the request; reload it.
        tag.refresh_from_db()
        self.assertEqual(tag.name, payload['name'])

    def test_delete_tag(self):
        """DELETE removes the tag."""
        tag = Tag.objects.create(user=self.user, name='Breakfast')

        response = self.client.delete(detail_url(tag.id))

        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
        self.assertFalse(Tag.objects.filter(user=self.user).exists())
- serializers.py
Recipe 的序列化器也要相应地添加 tags 字段。
from rest_framework import serializers

from core.models import (
    Recipe,
    Tag,
)


class TagSerializer(serializers.ModelSerializer):
    """Serialize a Tag as its id and name."""

    class Meta:
        model = Tag
        fields = ['id', 'name']
        # Clients may read the id but never set it.
        read_only_fields = ['id']
# ----------------------------------------------------------------------
class RecipeSerializer(serializers.ModelSerializer):
    """Serializer for recipes, with writable nested tags.

    The create/update overrides live here (not only on the detail
    serializer) because this class declares the writable nested ``tags``
    field; without them DRF's default create()/update() reject nested
    writes with an AssertionError.
    """

    tags = TagSerializer(many=True, required=False)

    class Meta:
        model = Recipe
        fields = ['id', 'title', 'time_minutes', 'price', 'link', 'tags']
        read_only_fields = ['id']

    def _get_or_create_tags(self, tags, recipe):
        """Attach each tag in *tags* to *recipe*, creating missing ones.

        Lookup and creation are scoped to the requesting user, so tags are
        never shared across users.
        """
        auth_user = self.context['request'].user
        for tag in tags:
            # get_or_create returns (object, created); only the object is used.
            tag_obj, _ = Tag.objects.get_or_create(user=auth_user, **tag)
            recipe.tags.add(tag_obj)

    def create(self, validated_data):
        """Create a recipe and attach its (possibly new) tags."""
        # Pop the nested data first: Recipe.objects.create() cannot take it.
        tags = validated_data.pop('tags', [])
        recipe = Recipe.objects.create(**validated_data)
        self._get_or_create_tags(tags, recipe)
        return recipe

    def update(self, instance, validated_data):
        """Update a recipe and return it.

        A supplied ``tags`` list replaces the existing tags (an empty list
        clears them); when ``tags`` is absent the current tags are kept.
        """
        tags = validated_data.pop('tags', None)
        if tags is not None:
            # Overwrite semantics: drop the old tags, then attach the new set.
            instance.tags.clear()
            self._get_or_create_tags(tags, instance)

        for attr, value in validated_data.items():
            setattr(instance, attr, value)
        instance.save()
        return instance


class RecipeDetailSerializer(RecipeSerializer):
    """Serializer for the recipe detail view (adds the description).

    Inherits the tag-aware create/update from RecipeSerializer.
    """

    class Meta(RecipeSerializer.Meta):
        fields = RecipeSerializer.Meta.fields + ['description']
- views.py
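这里没有使用 ModelViewSet,而是组合了若干 mixins(DestroyModelMixin、UpdateModelMixin、ListModelMixin)和 GenericViewSet,只提供列表、更新和删除操作。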
class TagViewSet(mixins.DestroyModelMixin,
                 mixins.UpdateModelMixin,  # enables the PATCH used in test_update_tag
                 mixins.ListModelMixin,
                 viewsets.GenericViewSet):
    """List, update and delete the authenticated user's tags."""

    serializer_class = serializers.TagSerializer
    queryset = Tag.objects.all()
    authentication_classes = [TokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Restrict tags to the requesting user, sorted by name descending."""
        owned = self.queryset.filter(user=self.request.user)
        return owned.order_by('-name')
最后在 router 中注册 TagViewSet 即可。
之后还要注意 Recipe 因为新增 tags 字段而带来的一些变化。
Nested Serializers 的局限性
- Read only by default
- Custom logic (overriding `create()` / `update()`) is needed to make them writable
View
- test_recipe_api.py
def test_create_recipe_with_new_tags(self):
    """POSTing a recipe with unseen tag names creates those tags."""
    payload = {
        'title': 'Thai Prawn Curry',
        'time_minutes': 30,
        'price': Decimal('2.50'),
        'tags': [{'name': 'Thai'}, {'name': 'Dinner'}],
    }
    res = self.client.post(RECIPES_URL, payload, format='json')

    # Read the recipe back and check each tag name was stored for this user.
    self.assertEqual(res.status_code, status.HTTP_201_CREATED)
    user_recipes = Recipe.objects.filter(user=self.user)
    self.assertEqual(user_recipes.count(), 1)
    recipe = user_recipes[0]
    self.assertEqual(recipe.tags.count(), 2)
    for tag_data in payload['tags']:
        tag_found = recipe.tags.filter(
            name=tag_data['name'],
            user=self.user,
        ).exists()
        self.assertTrue(tag_found)
def test_create_recipe_with_existing_tags(self):
    """POSTing a recipe reuses an existing tag and creates the missing one."""
    # Pre-create 'Indian' so the serializer has to reuse this exact tag.
    tag_indian = Tag.objects.create(user=self.user, name='Indian')
    payload = {
        'title': 'Pongal',
        'time_minutes': 60,
        'price': Decimal('4.50'),
        'tags': [{'name': 'Indian'}, {'name': 'Breakfast'}],
    }
    res = self.client.post(RECIPES_URL, payload, format='json')

    self.assertEqual(res.status_code, status.HTTP_201_CREATED)
    user_recipes = Recipe.objects.filter(user=self.user)
    self.assertEqual(user_recipes.count(), 1)
    recipe = user_recipes[0]
    self.assertEqual(recipe.tags.count(), 2)
    self.assertIn(tag_indian, recipe.tags.all())
    for tag_data in payload['tags']:
        tag_found = recipe.tags.filter(
            name=tag_data['name'],
            user=self.user,
        ).exists()
        self.assertTrue(tag_found)
-
- serializers.py
针对 Tag.objects.create,序列化器也要相应地重写 create() 方法(见上方 serializers.py 中的代码)。
def test_create_tag_on_update(self):
    """PATCHing a recipe with an unknown tag name creates that tag."""
    recipe = create_recipe(user=self.user)
    payload = {'tags': [{'name': 'Lunch'}]}

    res = self.client.patch(detail_url(recipe.id), payload, format='json')

    self.assertEqual(res.status_code, status.HTTP_200_OK)
    lunch_tag = Tag.objects.get(user=self.user, name='Lunch')
    self.assertIn(lunch_tag, recipe.tags.all())
def test_update_recipe_assign_tag(self):
    """PATCHing tags replaces the recipe's tags with an existing tag."""
    tag_breakfast = Tag.objects.create(user=self.user, name='Breakfast')
    recipe = create_recipe(user=self.user)
    recipe.tags.add(tag_breakfast)
    tag_lunch = Tag.objects.create(user=self.user, name='Lunch')
    payload = {'tags': [{'name': 'Lunch'}]}
    url = detail_url(recipe.id)

    # This PATCH overwrites the tag set of the recipe at `url`.
    res = self.client.patch(url, payload, format='json')

    self.assertEqual(res.status_code, status.HTTP_200_OK)
    self.assertIn(tag_lunch, recipe.tags.all())
    self.assertNotIn(tag_breakfast, recipe.tags.all())
def test_clear_recipe_tags(self):
    """PATCHing an empty tag list removes every tag from the recipe."""
    dessert = Tag.objects.create(user=self.user, name='Dessert')
    recipe = create_recipe(user=self.user)
    recipe.tags.add(dessert)

    res = self.client.patch(detail_url(recipe.id), {'tags': []}, format='json')

    self.assertEqual(res.status_code, status.HTTP_200_OK)
    self.assertEqual(recipe.tags.count(), 0)
AssertionError: The '.update()' method does not support writable nested fields by default.
So we need to override `.update()` (and `.create()`) ourselves.
这里重写了 update() 方法,对 tags 采用"整体覆盖(overwrite)"的方式:先 clear() 清空原有的 tags,再重新添加请求中给出的 tags。
def _get_or_create_tags(self, tags, recipe):
    """Attach each tag in *tags* to *recipe*, creating any that are missing.

    Each item of *tags* is a dict of Tag field values (e.g. ``{'name': ...}``).
    Lookup and creation are scoped to the requesting user.
    """
    auth_user = self.context['request'].user
    for tag in tags:
        # get_or_create returns (object, created); only the object is needed.
        tag_obj, _ = Tag.objects.get_or_create(user=auth_user, **tag)
        recipe.tags.add(tag_obj)


def create(self, validated_data):
    """Create a recipe and attach its tags.

    'tags' is popped first because the default ModelSerializer.create()
    does not support writable nested fields.
    """
    tags = validated_data.pop('tags', [])
    recipe = Recipe.objects.create(**validated_data)
    self._get_or_create_tags(tags, recipe)
    return recipe


def update(self, instance, validated_data):
    """Update *instance* from *validated_data* and return it.

    If 'tags' is present it overwrites the existing tags (an empty list
    clears them); if it is absent the current tags are left unchanged.
    """
    tags = validated_data.pop('tags', None)
    if tags is not None:
        # Overwrite semantics: drop the old tags, then attach the new set.
        instance.tags.clear()
        self._get_or_create_tags(tags, instance)

    for attr, value in validated_data.items():
        setattr(instance, attr, value)
    instance.save()
    return instance
End.