手撸 ORM
ORM（对象关系映射）的作用是建立类与数据库表之间的映射关系：
一个类对应一张表，由该类实例化出的每个对象对应表中的一条记录。
from orm_demo import mysql_control
class Field:
    """Metadata for one table column.

    Holds the column name, its SQL type string, whether it is the primary
    key, and the value to use when a row does not supply one.
    """

    def __init__(self, name, column_type, primary_key, default):
        # Plain attribute storage, set in declaration order.
        self.name, self.column_type = name, column_type
        self.primary_key, self.default = primary_key, default
class IntegerField(Field):
    """Integer column: SQL type ``int``, not a primary key, default value 0 unless overridden."""

    def __init__(self, name, column_type="int", primary_key=False, default=0):
        super().__init__(
            name=name,
            column_type=column_type,
            primary_key=primary_key,
            default=default,
        )
class StringField(Field):
    """String column: SQL type ``varchar(64)``, not a primary key, default None unless overridden."""

    def __init__(self, name, column_type='varchar(64)', primary_key=False, default=None):
        super().__init__(
            name=name,
            column_type=column_type,
            primary_key=primary_key,
            default=default,
        )
class OrmMetaClass(type):
    """Metaclass that converts ``Field`` class attributes into table metadata.

    A class literally named ``Models`` (the base model) is created untouched.
    For every other class it:

    * gathers each ``Field`` attribute into ``mappings``
      (attribute name -> Field) and removes it from the class namespace;
    * stores the primary-key field's ``name`` as ``primary_key``;
    * sets ``table_name`` to the class's own ``table_name`` attribute when
      present, otherwise to the class name.

    Raises:
        TypeError: if more than one field is marked as primary key, or none is.
    """

    def __new__(cls, class_name, class_base, class_attr):
        if class_name == 'Models':
            # The base model declares no columns of its own.
            return type.__new__(cls, class_name, class_base, class_attr)

        table_name = class_attr.get('table_name', class_name)
        mappings = {
            attr: value
            for attr, value in class_attr.items()
            if isinstance(value, Field)
        }

        primary_key = None
        for field in mappings.values():
            if not field.primary_key:
                continue
            if primary_key:
                raise TypeError('只能有一个主键')
            primary_key = field.name

        # Remove the Field objects from the class body so that, on instances,
        # attribute lookup of a column name is no longer satisfied by the
        # class attribute.
        for attr in mappings:
            del class_attr[attr]

        if not primary_key:
            raise TypeError('必须要有一个主键')

        class_attr.update(
            table_name=table_name,
            primary_key=primary_key,
            mappings=mappings,
        )
        return type.__new__(cls, class_name, class_base, class_attr)
class Models(dict, metaclass=OrmMetaClass):
    """Base class for ORM models; each instance holds one row as a dict.

    Subclasses declare columns as ``Field`` class attributes. ``OrmMetaClass``
    moves them into the class-level ``mappings`` dict (attribute name ->
    Field) and sets ``table_name`` and ``primary_key`` (the primary-key
    field's ``name``). Row values are dict items keyed by ``Field.name`` and
    can also be read/written as attributes (``obj.x`` <-> ``obj['x']``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __getattr__(self, item):
        """Return ``self.get(item)``; only called when normal lookup fails.

        A missing column therefore reads as None instead of raising
        AttributeError.
        """
        return self.get(item)

    def __setattr__(self, key, value):
        """Store ``obj.key = value`` as the dict item ``obj[key]``."""
        self[key] = value

    @classmethod
    def select(cls, **kwargs):
        """Return rows of ``cls.table_name`` as a list of ``cls`` instances.

        With no keyword arguments every row is fetched. Otherwise the first
        keyword argument becomes a ``column = value`` filter; the value is
        passed to the driver as a parameter.

        NOTE(review): keyword arguments after the first are silently ignored,
        and the column name is interpolated into the SQL text, so it must not
        come from untrusted input.
        """
        mysql_obj = mysql_control.Mysql()
        if not kwargs:
            sql = 'select * from %s' % cls.table_name
            res = mysql_obj.select(sql)
        else:
            key, value = next(iter(kwargs.items()))
            sql = 'select * from %s where %s=?' % (cls.table_name, key)
            sql = sql.replace('?', '%s')
            res = mysql_obj.select(sql, value)
        # Each row is expanded as keyword arguments, so rows are presumably
        # dicts keyed by column name (e.g. a DictCursor) -- confirm.
        return [cls(**r) for r in res]

    def save(self):
        """Insert this object as a new row of ``self.table_name``.

        Every mapped column is written under its ``Field.name``. A column
        without a value falls back to its Field's ``default``. A missing
        primary key is still sent as None, as before. Presumably this lets
        the database generate the key; confirm against the schema.
        """
        mysql = mysql_control.Mysql()
        columns = []
        values = []
        placeholders = []
        for field in self.mappings.values():
            columns.append(field.name)
            # Read the dict item directly. getattr(self, name, default) never
            # reached the default, because __getattr__ returns None for missing
            # keys. It also returned a bound method for columns named like a
            # dict method or class attribute (e.g. 'items').
            fallback = None if field.primary_key else field.default
            values.append(self.get(field.name, fallback))
            placeholders.append('?')
        sql = 'insert into %s(%s) values(%s)' % (
            self.table_name, ','.join(columns), ','.join(placeholders))
        sql = sql.replace('?', '%s')
        mysql.execute(sql, values)

    def sql_update(self):
        """Write this object's non-primary-key columns back to its row.

        The row is matched by ``self.primary_key`` (the primary-key column
        name) against this object's value for that column. Every value,
        including the key value, is passed to the driver as a parameter.
        """
        mysql = mysql_control.Mysql()
        assignments = []
        values = []
        pk_value = None
        for field in self.mappings.values():
            if field.primary_key:
                pk_value = self.get(field.name)
            else:
                assignments.append(field.name + '=?')
                values.append(self.get(field.name))
        # Bind the key value too. Formatting it into the SQL text left string
        # keys unquoted, which broke the query and allowed SQL injection.
        values.append(pk_value)
        sql = 'update %s set %s where %s=?' % (
            self.table_name, ','.join(assignments), self.primary_key)
        sql = sql.replace('?', '%s')
        mysql.execute(sql, values)