仿写简易ORM
在最近学习Mysql,利用之前的学习的基础,综合起来,挑战自我,写了一个简易版的ORM;
只是可以满足最基础的增删查改操作。以下是源码:
数据库连接池
from conf import setting import pymysql from DBUtils.PooledDB import PooledDB Pool = PooledDB( creator=pymysql, # 连接使用的数据库模块 maxconnections=6, # 连接池允许的最大连接数 mincached=5, # 初始化时,连接池中至少创建的空闲连接,0表示不创建 maxcached=3, # 连接池中最多的闲置连接0表示不限制 maxusage=3, # 一个连接最多被重用的次数,None表示不限制 maxshared=3, # 连接池中做多共享的连接数量 setsession=[], # 开始会话前执行的命令列表 ping=0, # ping mysql服务端,检查连接是否可用 blocking=True, # 如果没有连接可用后,是否阻塞 host=setting.host, port=setting.port, user=setting.user, password=setting.password, database=setting.database, charset=setting.charset, autocommit=setting.autocommit,)
配置文件(setting.py)
host = '127.0.0.1' port = 3306 user = 'root' password = 'root' database = 'youku2' charset = 'utf8' autocommit = True
实例化数据库连接,获取游标(mysql_pool.py)
import db_pool import pymysql class Mysql: def __init__(self): self.conn = db_pool.Pool.connection() self.cursor = self.conn.cursor(cursor=pymysql.cursors.DictCursor) def close(self): self.cursor.close() self.conn.close() def select(self, sql, args=None): self.cursor.execute(sql, args) res = self.cursor.fetchall() return res def execute(self,sql,args): affected = None try: self.cursor.execute(sql, args) affected = self.cursor.rowcount except Exception as e: print(e) finally: self.close() return affected
表的实例化,通过自定义元类实现对表生成的控制(orm.py)
import mysql_singleton # 封装sql语句的复杂性,提供应用程序和数据库之间的封装接口,减少程序员的工作量 # 一:表的字段的创建 class Field: """ 字段的定义 create table t1(id int primary key, name varchar(200) primary key not null default); """ def __init__(self, column_name, column_type, primary_key, default): self.column_name = column_name # 字段的名称 self.column_type = column_type # 字段的类型 self.primary_key = primary_key # 是否主键(约束条件) self.default = default # 默认值 class StringField(Field): """ 数据字段类型为字符型的应用程序接口封装 """ def __init__(self, column_name=None, column_type="varchar(200)", primary_key=False, default=None): super().__init__(column_name, column_type, primary_key, default) class IntegerField(Field): """ 数据字段类型为整数型的应用程序接口封装 """ def __init__(self, column_name=None, column_type="int", primary_key=False, default=0): super().__init__(column_name, column_type, primary_key, default) # 把表看做一个类的对象,这个对象有着表名,字段名,主键所在列等等属性 class ModelMetaclass(type): """ 把表看做一个类,通过自定义元类(类的类)的方式,从而可以控制表的创建 实例化出来一个类(也就是创建一张表)需要一个类的名称,类的基类(object), 类的名称空间(类体),类在定义阶段就会执行代码,获取类体的名字存入名称空间 通过这个类可以实例化出来一个表 表会有表名(元类实例化出来的类的类名),会有主键所在字段,会有字段名 """ def __new__(cls, class_name, bases, namespace): if class_name == 'Model': return type.__new__(cls, class_name, bases, namespace) t_name = namespace.get('t_name', None) # 从类体代码中找到是否有定义表名(table) if not t_name: t_name = class_name # 如果没有找到自定义的表名,那么以类名为表名(默认设置) primary_key = None mappings = dict() for k, v in namespace.items(): if isinstance(v, Field): # v 是不是Field的对象 mappings[k] = v if v.primary_key: # 找到主键 if primary_key: raise TypeError('主键重复:%s' % k) primary_key = k for k in mappings.keys(): namespace.pop(k) if not primary_key: raise TypeError('没有主键') namespace['t_name'] = t_name namespace['primary_key'] = primary_key namespace['mappings'] = mappings return type.__new__(cls, class_name, bases, namespace) class Model(dict, metaclass=ModelMetaclass): # 元类是ModelMetaclass,父类是dict """ 一: 在类model的定义阶段(代码从上至下运行到这一段)元类的__new__会在 __init__执行之前截获执行步骤,转而执行__new__内的代码 运行了model的类体代码,获取了一系列的名字放入了类model的名称空间 二: 把表看做是一个类(Model)的实例化对象,这个类的元类是(ModelMetaclass), 类在之前的学习就已经知道,是数据和数据的处理方法的高度封装, 那么,实例化出表的类,应该有表和对表的一系列增删查改的操作方法的高度封装 """ def __init__(self, **kw): """ cls[key] = value :param kw: """ super(Model, self).__init__(**kw) # 继承父类的__init__方法,把值传入 def __getattr__(self, key): """ __getattr__ 拦截点号运算。当对未定义的属性名称和实例进行点号运算时,就会用 属性名作为字符串调用这个方法。如果继承树可以找到该属性,则不调用此方法 """ try: return self[key] except KeyError: raise AttributeError('没有属性:%s' % key) def __setattr__(self, key, value): """ __setattr__会拦截所有属性的的赋值语句。如果定义了这个方法, self.name = value 就会变成self,__setattr__("name", value). 这个需要注意。当在__setattr__方法内对属性进行赋值是,不可使 用self.name = value,因为他会再次调用self,__setattr__("name", value),则 会形成无穷递归循环,最后导致堆栈溢出异常。应该通过对属性字典做索引运算来赋值任 何实例属性,也就是使用self.__dict__['name'] = value """ self[key] = value @classmethod def select_all(cls, **kwargs): """ 一:支持单条查询或查询所有 sql语句:select * from 表名 where id(字段) = 1(查询id) select * from 表名(查询所有) 二: 当类调用这个方法时,会把调用的类名自动传入 cls = 调用的类的名 **接受关键字参数(例如id= 1),赋值给kwargs """ ms = mysql_singleton.Mysql().singleton() if kwargs: """ 如果有传入的参数,取出参数,拼接sql语句 """ key = list(kwargs.keys())[0] value = kwargs[key] sql = "select * from %s where %s=?" % (cls.t_name, key) sql = sql.replace('?', '%s') re = ms.select(sql, value) else: sql = "select * from %s" % cls.t_name re = ms.select(sql) # 用fetchall方法取得SQL语句执行结果(字典形式) return [cls(**r) for r in re] # for循环取出各个元素(字典形式),打散成如id = 1 ,name = "ls" @classmethod def select_one(cls, **kwargs): key = list(kwargs.keys())[0] value = kwargs[key] ms = mysql_singleton.Mysql().singleton() sql = "select * from %s where %s=?" % (cls.t_name, key) sql = sql.replace('?', '%s') re = ms.select(sql, value) if re: return cls(**re[0]) else: return None def save(self): """ 给表中插入数据 sql语句:insert into t1(id,name) values(1,"egon") :return: """ ms = mysql_singleton.Mysql().singleton() fields = [] params = [] args = [] for k, v in self.mapping.items(): """ self指的是表,k为字段,v为字段的约束条件(Field的实例化) """ fields.append(v.name) # ['id','name'] params.append('?') # ['?','?'] args.append(getattr(self, k, v.default)) # ['1','egon'] # 返回对象(表)的属性(字段)k的值,没有则返回默认值v.default # getattr(object, name[, default])函数用于返回一个对象属性值。 sql = "insert into %s (%s) values (%s)" % (self.table_name, ','.join(fields), ','.join(params)) # join()方法用于将序列中的元素以指定的字符连接生成一个新的字符串。 sql = sql.replace('?', '%s') ms.execute(sql, args) def update(self): """ sql语句:update t1 set name = "ls",id = 2 where name = "egon" :return: """ ms = mysql_singleton.Mysql().singleton() fields = [] args = [] pr = None for k, v in self.mapping.items(): if v.primary_key: pr = getattr(self, k, v.default) # 返回对象(表)的属性(字段)k的值,没有则返回默认值 # getattr(object, name[, default])函数用于返回一个对象属性值。 else: fields.append(v.name + '=?') args.append(getattr(self, k, v.default)) sql = "update %s set %s where %s = %s" % ( self.table_name, ', '.join(fields), self.primary_key, pr) sql = sql.replace('?', '%s') print(sql) ms.execute(sql, args) class User(Model): # User类对应着一个表,User类的实例化对象对应表中的一条记录 t_name = "user" id = IntegerField('id', primary_key=True) name = StringField("name") if __name__ == '__main__': # 数据查询 user = User.select_one(id=1) print(user) print(user.name) # 数据插入 user = User(name='ls') user.save() # 数据修改 user = User.select_one(id=1) user.name = 'mysql' user.update()
附:sql文件
SET foreign_key_checks = 0; DROP TABLE IF EXISTS userinfo,movie,notice,download_record; CREATE TABLE userinfo ( id INT PRIMARY KEY NOT NULL auto_increment, `name` VARCHAR (32), `password` VARCHAR (64), is_vip INT, locked INT, user_type VARCHAR (32) ) ENGINE = INNODB,charset = 'utf8'; CREATE TABLE movie ( id INT PRIMARY KEY NOT NULL auto_increment, `name` VARCHAR (32), `path` VARCHAR (255), is_free INT DEFAULT 0, is_delete INT DEFAULT 0, create_time timestamp default current_timestamp, user_id int, file_md5 VARCHAR (64) ) charset = 'utf8'; create table notice( id int not null primary key auto_increment, `name` varchar(64), content varchar(255), create_time timestamp default current_timestamp, user_id int )charset = 'utf8'; create table download_record( id int not null PRIMARY key auto_increment, user_id int, movie_id int )charset = 'utf8';
Ideal are like the stars --- we never reach them ,but like mariners , we chart our course by them