Python 元类编程实现一个简单的 ORM
概述
什么是ORM?
ORM全称“Object Relational Mapping”,即对象-关系映射,就是把关系数据库的一行映射为一个对象,也就是一个类对应一个表,这样,写代码更简单,不用直接操作SQL语句。
现在我们就要实现简易版ORM。
效果
class Person(Model): """ 定义类的属性到列的映射 """ pid = IntegerField('id') names = StringField('username') email = StringField('email') password = StringField('password') p = Person(pid=10086, names='晓明', email='10086@163.com', password='123456') p.save()
通过执行save()方法 动态生成sql插入语句, 是不是很神奇, 那我们现在开始解析原理吧
步骤
首先我们要定义一个 Field 类 它负责保存数据库表的字段名和字段类型:
class Field(object): def __init__(self, name, column_type): self.name = name self.column_type = column_type def __str__(self): return '<%s:%s>' % (self.__class__.__name__, self.name)
在 Field
的基础上,进一步定义各种类型的 Field
,比如 StringField
,IntegerField
等等:
class StringField(Field): def __init__(self, name): super(StringField, self).__init__(name, 'varchar(100)') class IntegerField(Field): def __init__(self, name): super(IntegerField, self).__init__(name, 'bigint')
下一步,就是编写最复杂的 ModelMetaclass
:
class ModelMetaclass(type): def __new__(cls, name, bases, attrs): if name == "Model": return type.__new__(cls, name, bases, attrs) mappings = dict() print("Found class: %s" % name) for k, v in attrs.items(): if isinstance(v, Field): print("Found mapping: %s ==> %s" % (k, v)) mappings[k] = v for k in mappings.keys(): attrs.pop(k) attrs["__table__"] = name # 表名和类名一致 attrs["__mappings__"] = mappings # 保存属性和列的映射关系 return type.__new__(cls, name, bases, attrs)
最后就是基类 Model:
class Model(metaclass=ModelMetaclass): def __init__(self, **kwargs): _setattr = setattr if kwargs: for k, v in kwargs.items(): _setattr(self, k, v) super(Model, self).__init__() def save(self): fields = [] params = [] args = [] for k, v in self.__mappings__.items(): fields.append(k) params.append("?") args.append(getattr(self, k, None)) sql = "insert into %s (%s) values (%s)" % (self.__table__, ','.join(fields), ",".join(params)) print('插入语句: %s' % sql) print('参数: %s' % str(args)) def update(self): fields = [] args = [] for k, v in self.__mappings__.items(): if getattr(self, k, None): fields.append(k+"=?") args.append(getattr(self, k, None)) sql = "update %s set %s" % (self.__table__, ','.join(fields)) print("更新语句: %s " % sql) print("参数: %s" % args) def filter(self, *args): pass def delete(self): pass
当用户定义一个 class Person(Model) 继承父类时,Python解释器会在当前类 Person 的定义中找 __metaclass__,如果没有找到,就继续到父类中找 __metaclass__,实在找不到就用默认 type 类。
我们在父类 Model 中定义了 __metaclass__ 的 ModelMetaclass 来创建 Person 类,所以 metaclass 隐式地继承到子类。
在 ModelMetaclass
中,一共做了几件事情:
-
排除掉对
Model
类的修改; -
在当前类(比如
Person
)中查找定义的类的所有属性,如果找到一个 Field 属性,就把它保存到一个__mappings__
的dict中,同时从类属性中删除该Field属性,否则,容易造成运行时错误; -
把表名保存到
__table__
中,这里简化为表名默认为类名。
在Model
类中,就可以定义各种操作数据库的方法,比如save()
,delete()
,find()
,update()
等等。
我们实现了save(), update()
方法,把一个实例保存到数据库中。因为有表名,属性到字段的映射和属性值的集合,就可以构造出INSERT语句和UPDATE
语句。
编写代码试试:
class UserInfo(Model): """ 定义类的属性到列的映射 """ uid = IntegerField('uid') name = StringField('username') email = StringField('email') password = StringField('password') class Person(Model): """ 定义类的属性到列的映射 """ pid = IntegerField('id') names = StringField('username') email = StringField('email') password = StringField('password') p = Person(pid=10086, names='晓明', email='10086@163.com', password='123456') p.save() u2 = UserInfo(password='123456') u2.update()
输出
Found class: UserInfo Found mapping: uid ==> <IntegerField:uid> Found mapping: name ==> <StringField:username> Found mapping: email ==> <StringField:email> Found mapping: password ==> <StringField:password> Found class: Person Found mapping: pid ==> <IntegerField:id> Found mapping: names ==> <StringField:username> Found mapping: email ==> <StringField:email> Found mapping: password ==> <StringField:password> 插入语句: insert into Person (pid,names,email,password) values (?,?,?,?) 参数: [10086, '晓明', '10086@163.com', '123456'] 更新语句: update UserInfo set password=? 参数: ['123456']
结束语
就这样一个小巧的ORM就这么完成了。是不是学到了很多呢 ?这里利用的是元编程,很多Python框架都运用了元编程达到动态操作类。
注:上述代码列子 结合了廖雪峰的列子和少量的django ORM源码。
完整代码
class Field(object): def __init__(self, name, column_type): self.name = name self.column_type = column_type def __str__(self): return '<%s:%s>' % (self.__class__.__name__, self.name) class StringField(Field): def __init__(self, name): super(StringField, self).__init__(name, 'varchar(100)') class IntegerField(Field): def __init__(self, name): super(IntegerField, self).__init__(name, 'bigint') class ModelMetaclass(type): def __new__(cls, name, bases, attrs): if name == "Model": return type.__new__(cls, name, bases, attrs) mappings = dict() print("Found class: %s" % name) for k, v in attrs.items(): if isinstance(v, Field): print("Found mapping: %s ==> %s" % (k, v)) mappings[k] = v for k in mappings.keys(): attrs.pop(k) attrs["__table__"] = name # 表名和类名一致 attrs["__mappings__"] = mappings # 保存属性和列的映射关系 return type.__new__(cls, name, bases, attrs) class Model(metaclass=ModelMetaclass): def __init__(self, **kwargs): _setattr = setattr if kwargs: for k, v in kwargs.items(): _setattr(self, k, v) super(Model, self).__init__() def save(self): fields = [] params = [] args = [] for k, v in self.__mappings__.items(): fields.append(k) params.append("?") args.append(getattr(self, k, None)) sql = "insert into %s (%s) values (%s)" % (self.__table__, ','.join(fields), ",".join(params)) print('插入语句: %s' % sql) print('参数: %s' % str(args)) def update(self): fields = [] args = [] for k, v in self.__mappings__.items(): if getattr(self, k, None): fields.append(k+"=?") args.append(getattr(self, k, None)) sql = "update %s set %s" % (self.__table__, ','.join(fields)) print("更新语句: %s " % sql) print("参数: %s" % args) def filter(self, *args): pass def delete(self): pass class UserInfo(Model): """ 定义类的属性到列的映射 """ uid = IntegerField('uid') name = StringField('username') email = StringField('email') password = StringField('password') class Person(Model): """ 定义类的属性到列的映射 """ pid = IntegerField('id') names = StringField('username') email = StringField('email') password = StringField('password') p = Person(pid=10086, names='晓明', email='10086@163.com', password='123456') p.save() u2 = UserInfo(password='123456') u2.update()