cookiecutter-flask生成的框架里边自带了一个CRUDMixin类

单元测试的必要性

之前曾经写过一篇讲单元测试的,正好最近也在实践和摸索。我似乎有种洁癖,就是我会严格遵守流程性的东西,比如测试,注释和文档等。目前就职的公司在我接手项目的时候是没有一行单元测试的,我挺诧异的。我大概也能估计到目前国内的python项目团队很多是不太规范的。当然流程不够规范可能不会有什么大问题,但是绝对会给代码维护造成困难,我是踩了坑的,所以要保持谨慎。虽然这次工期比较紧,半个月内搞出来一个CRM系统,但是目前这一周多的进度还是严格遵守了规范并完善了测试,并且进展还是挺顺利的,感觉单元测试确实能减少bug出现率。至于会不会浪费开发和维护时间,还需要自己权衡。

flask单元测试 由于我直接偷懒使用了[cookiecutter-flask]生成框架,自带了一个tests文件夹,我就直接照葫芦画瓢就好。首先在tests文件夹下有一个py.test使用的conftest.py文件(推荐你使用pytest做测试,相当便捷) # -*- coding: utf-8 -*- """Defines fixtures available to all tests. http://doc.pytest.org/en/latest/fixture.html?highlight=fixture """ import pytest from webtest import TestApp from crm_backend.app import create_app from crm_backend.database import db as _db from crm_backend.settings import TestConfig from .factories import UserFactory @pytest.yield_fixture(scope='function') def app(): """An application for the tests.""" _app = create_app(TestConfig) ctx = _app.test_request_context() ctx.push() yield _app ctx.pop() @pytest.yield_fixture def client(app): """A Flask test client. An instance of :class:`flask.testing.TestClient` by default. """ with app.test_client() as client: yield client @pytest.fixture(scope='function') def testapp(app): """A Webtest app.""" return TestApp(app) @pytest.yield_fixture(scope='function') def db(app): """A database for the tests.""" _db.app = app with app.app_context(): _db.create_all() yield _db # Explicitly close DB connection _db.session.close() _db.drop_all() @pytest.fixture def user(db): """A user for the tests.""" user = UserFactory(password='myprecious') db.session.commit() return user

对于普通的python函数或者类,可以直接使用简单的test函数,由于编写的是web项目,麻烦的地方就在于和数据库以及后端请求的交互。在cookiecutter中使用了pytest的fixture特性来处理和数据库的交互问题,使用了webtest库来处理请求问题。分别来看看如何测试Model和View层,我这里使用了flask restful,所以改成了api层。

使用py.test测试Model层

下边是cookiecutter-flask自动生成的关于user的Model单元测试。这里有一点需要注意,测试类TestUser使用了fixture db,这个fixture在conftest.py中定义的,使用的测试配置 SQLALCHEMY_DATABASE_URI = ‘sqlite:///:memory:’,所有操作都是在内存中进行,db使用这个模拟的sqllite内存数据库。其他貌似也没啥好说的了,都是基本的crud操作,照着写测试就行,没啥好说的:

# -*- coding: utf-8 -*- """Model unit tests.""" import datetime as dt import pytest from crm_backend.user.models import Role, User from .factories import UserFactory @pytest.mark.usefixtures('db') class TestUser: """User tests.""" def test_get_by_id(self): """Get user by ID.""" user = User('foo', 'foo@bar.com') user.save() retrieved = User.get_by_id(user.id) assert retrieved == user def test_created_at_defaults_to_datetime(self): """Test creation date.""" user = User(username='foo', email='foo@bar.com') user.save() assert bool(user.created_at) assert isinstance(user.created_at, dt.datetime) 测试flask接口

这里使用的是 WebTest 这个库进行测试,而没有使用flask自带的test_client,WebTest还是比较方便的,常见的也就是get、post、put方法和请求数据的提交,也比较简单,代码见示例:

# -*- coding: utf-8 -*- """ flask flask_restful api的单元测试 """ import pytest from crm_backend.extensions import api from crm_backend.advertiser.api import ( AdvertiserListApi, AdvertiserApi, BusinessLeadListApi, ) from crm_backend.advertiser.models import ( Advertiser, ) from crm_backend.employee.models import ( Employee, ) @pytest.mark.usefixtures('db') class TestAdvertiserListApi: def test_get(self, testapp): self.test_post(testapp) # 先创建一个advertiser url = api.url_for(AdvertiserListApi) res = testapp.get( url, { # 'fields': "id,name", # 'filter': """[{"field":"status","op":"eq","q":1}]""", 'limit': 1000, # 'order': "", 'page': 1 }, expect_errors=True ) assert len(res.json['data']['items']) == 1 def test_post(self, testapp): url = api.url_for(AdvertiserListApi) bd = Employee.create( name='e1', email='e1@bar.com', password='foobarbaz123', team=Employee.TeamEnum.__dict__['CN-Beijing1'], is_leader=True, # set leader role=Employee.RoleEnum.BD ) res = testapp.post_json( url, { 'name': 'advertiser_wang', 'contact_name': 'xiaoliu', 'phone': '18818881888', 'email': 'tes@qq.com', }, expect_errors=True ) a = Advertiser.get_by_id(1) assert res.json['data']['id'] == a.id assert a.name == 'advertiser_wang' assert a.bd == bd assert a.is_client return a @pytest.mark.usefixtures('db') class TestAdvertiserApi: def test_get(self, testapp): a = Advertiser.create(name='adervertiser_wang') url = api.url_for(AdvertiserApi, advertiser_id=a.id) res = testapp.get(url) assert res.json['id'] == str(a.id) def test_put(self, testapp): a = TestAdvertiserListApi().test_post(testapp) # 先创建个用户再更新 url = api.url_for(AdvertiserApi, advertiser_id=a.id) res = testapp.put_json( url, { 'name': 'new_advertiser_wang', 'contact_name': 'xiaoliu', 'phone': '18818881888', 'email': 'tes@qq.com', }, expect_errors=True ) # 测试名称已经更新 assert Advertiser.get_by_id(a.id).name == 'new_advertiser_wang' 使用marshmallow.Schema dumps返回数据

marshmallow is an ORM/ODM/framework-agnostic library for converting complex datatypes, such as objects, to and from native Python datatypes.

阅读flask restful文档的时候发现提到了这么个 marshmallow 东西我就直接在项目中使用了。

在做后台接口时,一般会碰到两个问题,一个就是参数(表单)验证,还有一个就是数据返回。参数或者表单验证都可以用wtforms完成,或者可以尝试flask eve作者写的看门狗 Cerberus ,这个Cerberus是专门用来搞字段校验的,不涉及表单。数据返回可能不同项目有不同的做法。

marshmallow的作用就是用来序列化自定义的一些Python类实例。比如我们从数据库用sqlalchemy查到一个对象列表以后,需要按照指定格式返回前端需要的数据和类型,之前的做法都是自己用函数转成个dict,现在这种模式化的东西可以直接使用marshmalow里的Schema来做,而且非常灵活,需要返回不同格式或者类型的数据直接可以自定义一个schema解决。给个官方文档的例子:

from datetime import date from marshmallow import Schema, fields, pprint class ArtistSchema(Schema): name = fields.Str() class AlbumSchema(Schema): title = fields.Str() release_date = fields.Date() artist = fields.Nested(ArtistSchema()) bowie = dict(name='David Bowie') album = dict(artist=bowie, title='Hunky Dory', release_date=date(1971, 12, 17)) schema = AlbumSchema() result = schema.dump(album) pprint(result.data, indent=2) # { 'artist': {'name': 'David Bowie'}, # 'release_date': '1971-12-17', # 'title': 'Hunky Dory'}

实际上我感觉和最近比较火的 graphql 有点像,通过定义一系列查询模式直接返回数据。这样我们就不用自己转成dict了,不直观也不够通用。使用这种Schema以后你就可以写个统一的查询函数了,需要不同的数据格式只要把Schema类作为参数传给函数就好,我甚至尝试用一个统一的分页查询函数解决了所有Model的分页查询和过滤问题。

增强flask_sqlalchemy自带的Modle类

cookiecutter-flask生成的框架里边自带了一个CRUDMixin类,用来给Model增加常用的增删改查,我稍微加了几个函数用来解决一些通用的查询。比如我的query_paginate_and_dump一个函数解决了几乎大部分的查询问题。(借鉴了他人的一些代码)

# -*- coding: utf-8 -*- """Database module, including the SQLAlchemy database object and DB-related utilities.""" import datetime as dt from marshmallow import Schema from sqlalchemy import desc, or_ from sqlalchemy.sql.sqltypes import Date, DateTime from sqlalchemy.orm import relationship from werkzeug import cached_property from .compat import basestring from .extensions import db from .utils import date_str_to_obj, datetime_str_to_obj # Alias common SQLAlchemy names Column = db.Column relationship = relationship OPERATOR_FUNC_DICT = { '=': (lambda cls, k, v: getattr(cls, k) == v), '==': (lambda cls, k, v: getattr(cls, k) == v), 'eq': (lambda cls, k, v: getattr(cls, k) == v), '!=': (lambda cls, k, v: getattr(cls, k) != v), 'ne': (lambda cls, k, v: getattr(cls, k) != v), 'neq': (lambda cls, k, v: getattr(cls, k) != v), '>': (lambda cls, k, v: getattr(cls, k) > v), 'gt': (lambda cls, k, v: getattr(cls, k) > v), '>=': (lambda cls, k, v: getattr(cls, k) >= v), 'gte': (lambda cls, k, v: getattr(cls, k) >= v), '<': (lambda cls, k, v: getattr(cls, k) < v), 'lt': (lambda cls, k, v: getattr(cls, k) < v), '<=': (lambda cls, k, v: getattr(cls, k) <= v), 'lte': (lambda cls, k, v: getattr(cls, k) <= v), 'or': (lambda cls, k, v: or_(getattr(cls, k) == value for value in v)), 'in': (lambda cls, k, v: getattr(cls, k).in_(v)), 'nin': (lambda cls, k, v: ~getattr(cls, k).in_(v)), 'like': (lambda cls, k, v: getattr(cls, k).like('%{}%'.format(v))), 'nlike': (lambda cls, k, v: ~getattr(cls, k).like(v)), '+': (lambda cls, k, v: getattr(cls, k) + v), 'incr': (lambda cls, k, v: getattr(cls, k) + v), '-': (lambda cls, k, v: getattr(cls, k) - v), 'decr': (lambda cls, k, v: getattr(cls, k) - v), } def parse_operator(cls, filter_name_dict): """ 用来返回sqlalchemy query对象filter使用的表达式 Args: filter_name_dict (dict): 过滤条件dict { 'last_name': {'eq': 'wang'}, # 如果是dic使用key作为操作符 'age': {'>': 12} } Returns: binary_expression_list (lambda list) """ def _change_type(cls, field, value): """ 有些表字段比如DateTime类型比较的时候需要转换类型, 前端传过来的都是字符串,Date等类型没法直接相比较,需要转成Date类型 Args: cls (class): Model class field (str): Model class field value (str): value need to compare """ field_type = getattr(cls, field).type if isinstance(field_type, Date): return date_str_to_obj(value) elif isinstance(field_type, DateTime): return datetime_str_to_obj(value) else: return value binary_expression_list = [] for field, op_dict in filter_name_dict.items(): for op, op_val in op_dict.items(): op_val = _change_type(cls, field, op_val) if op in OPERATOR_FUNC_DICT: binary_expression_list.append( OPERATOR_FUNC_DICT[op](cls, field, op_val) ) return binary_expression_list class CRUDMixin(object): """Mixin that adds convenience methods for CRUD (create, read, update, delete) operations.""" @classmethod def create(cls, **kwargs): """Create a new record and save it the database.""" instance = cls(**kwargs) return instance.save() @classmethod def create_from_dict(cls, d): """Create a new record and save it the database.""" assert isinstance(d, dict) instance = cls(**d) return instance.save() def update(self, commit=True, **kwargs): """Update specific fields of a record.""" for attr, value in kwargs.items(): setattr(self, attr, value) return commit and self.save() or self def save(self, commit=True): """Save the record.""" db.session.add(self) if commit: db.session.commit() return self def delete(self, commit=True): """Remove the record from the database.""" db.session.delete(self) return commit and db.session.commit() def to_dict(self, fields_list=None): """ Args: fields (str list): 指定返回的字段 """ column_list = fields_list or [ column.name for column in self.__table__.columns ] return { column_name: getattr(self, column_name) for column_name in column_list } @classmethod def create_or_update(cls, query_dict, update_dict=None): instance =2881064151 db.session.query(cls).filter_by(**query_dict).first() if instance: # update if update_dict is not None: return instance.update(**update_dict) else: return instance else: # create new instance query_dict.update(update_dict or {}) return cls.create(**query_dict) @classmethod def query_paginate(cls, page=1, limit=20, fields=None, order_by_list=None, filter_name_dict=None): """ 通用的分页查询函数 Args: page (int): 页数 limit (int): 每页个数 order_by_list (tuple list): 用来指定排序的字段,可以传多个 [ ('id', 1), ('name', -1) ],1表示正序,-1表示逆序 or [ ('id', 'asc'), ('name', 'desc') ],1表示正序,-1表示逆序 filter_name_dict (dict): 过滤条件

posted @ 2016-11-24 18:02  韩国服务器-Time  阅读(1281)  评论(0编辑  收藏  举报