仿写简易ORM

在最近学习Mysql,利用之前的学习的基础,综合起来,挑战自我,写了一个简易版的ORM;

只是可以满足最基础的增删查改操作。以下是源码:

数据库连接池

from conf import setting
import pymysql
from DBUtils.PooledDB import PooledDB
Pool = PooledDB(
    creator=pymysql,  # 连接使用的数据库模块
    maxconnections=6,  # 连接池允许的最大连接数
    mincached=5,  # 初始化时,连接池中至少创建的空闲连接,0表示不创建
    maxcached=3,  # 连接池中最多的闲置连接0表示不限制
    maxusage=3,  # 一个连接最多被重用的次数,None表示不限制
    maxshared=3,  # 连接池中做多共享的连接数量
    setsession=[],  # 开始会话前执行的命令列表
    ping=0,  # ping mysql服务端,检查连接是否可用
    blocking=True,  # 如果没有连接可用后,是否阻塞
    host=setting.host,
    port=setting.port,
    user=setting.user,
    password=setting.password,
    database=setting.database,
    charset=setting.charset,
    autocommit=setting.autocommit,)

配置文件(setting.py)

host = '127.0.0.1'
port = 3306
user = 'root'
password = 'root'
database = 'youku2'
charset = 'utf8'
autocommit = True

实例化数据库连接,获取游标(mysql_pool.py)

import db_pool
import pymysql
class Mysql:
    def __init__(self):
        self.conn = db_pool.Pool.connection()
        self.cursor = self.conn.cursor(cursor=pymysql.cursors.DictCursor)
    def close(self):
        self.cursor.close()
        self.conn.close()
    def select(self, sql, args=None):
        self.cursor.execute(sql, args)
        res = self.cursor.fetchall()
        return res
    def execute(self,sql,args):
        affected = None
        try:
            self.cursor.execute(sql, args)
            affected = self.cursor.rowcount
        except Exception as e:
            print(e)
        finally:
            self.close()
        return affected

表的实例化,通过自定义元类实现对表生成的控制(orm.py)

import mysql_singleton


# 封装sql语句的复杂性,提供应用程序和数据库之间的封装接口,减少程序员的工作量
# 一:表的字段的创建
class Field:
    """
    字段的定义
    create table t1(id int primary key,
    name varchar(200) primary key not null default);
    """

    def __init__(self, column_name, column_type, primary_key, default):
        self.column_name = column_name  # 字段的名称
        self.column_type = column_type  # 字段的类型
        self.primary_key = primary_key  # 是否主键(约束条件)
        self.default = default  # 默认值


class StringField(Field):
    """
    数据字段类型为字符型的应用程序接口封装
    """

    def __init__(self, column_name=None, column_type="varchar(200)", primary_key=False, default=None):
        super().__init__(column_name, column_type, primary_key, default)


class IntegerField(Field):
    """
    数据字段类型为整数型的应用程序接口封装
    """

    def __init__(self, column_name=None, column_type="int", primary_key=False, default=0):
        super().__init__(column_name, column_type, primary_key, default)


# 把表看做一个类的对象,这个对象有着表名,字段名,主键所在列等等属性
class ModelMetaclass(type):
    """
    把表看做一个类,通过自定义元类(类的类)的方式,从而可以控制表的创建
    实例化出来一个类(也就是创建一张表)需要一个类的名称,类的基类(object),
    类的名称空间(类体),类在定义阶段就会执行代码,获取类体的名字存入名称空间
    通过这个类可以实例化出来一个表
    表会有表名(元类实例化出来的类的类名),会有主键所在字段,会有字段名
    """

    def __new__(cls, class_name, bases, namespace):
        if class_name == 'Model':
            return type.__new__(cls, class_name, bases, namespace)
        t_name = namespace.get('t_name', None)  # 从类体代码中找到是否有定义表名(table)
        if not t_name:
            t_name = class_name  # 如果没有找到自定义的表名,那么以类名为表名(默认设置)
        primary_key = None
        mappings = dict()
        for k, v in namespace.items():
            if isinstance(v, Field):  # v 是不是Field的对象
                mappings[k] = v
                if v.primary_key:
                    # 找到主键
                    if primary_key:
                        raise TypeError('主键重复:%s' % k)
                    primary_key = k
        for k in mappings.keys():
            namespace.pop(k)
        if not primary_key:
            raise TypeError('没有主键')
        namespace['t_name'] = t_name
        namespace['primary_key'] = primary_key
        namespace['mappings'] = mappings
        return type.__new__(cls, class_name, bases, namespace)


class Model(dict, metaclass=ModelMetaclass):  # 元类是ModelMetaclass,父类是dict
    """
    一:
        在类model的定义阶段(代码从上至下运行到这一段)元类的__new__会在
        __init__执行之前截获执行步骤,转而执行__new__内的代码
        运行了model的类体代码,获取了一系列的名字放入了类model的名称空间
    二:
        把表看做是一个类(Model)的实例化对象,这个类的元类是(ModelMetaclass),
        类在之前的学习就已经知道,是数据和数据的处理方法的高度封装,
        那么,实例化出表的类,应该有表和对表的一系列增删查改的操作方法的高度封装
    """

    def __init__(self, **kw):
        """
        cls[key] = value
        :param kw:
        """
        super(Model, self).__init__(**kw)  # 继承父类的__init__方法,把值传入

    def __getattr__(self, key):
        """
        __getattr__ 拦截点号运算。当对未定义的属性名称和实例进行点号运算时,就会用
        属性名作为字符串调用这个方法。如果继承树可以找到该属性,则不调用此方法
        """
        try:
            return self[key]
        except KeyError:
            raise AttributeError('没有属性:%s' % key)

    def __setattr__(self, key, value):
        """
        __setattr__会拦截所有属性的的赋值语句。如果定义了这个方法,
        self.name = value 就会变成self,__setattr__("name", value).
        这个需要注意。当在__setattr__方法内对属性进行赋值是,不可使
        用self.name = value,因为他会再次调用self,__setattr__("name", value),则
        会形成无穷递归循环,最后导致堆栈溢出异常。应该通过对属性字典做索引运算来赋值任
        何实例属性,也就是使用self.__dict__['name'] = value
        """
        self[key] = value

    @classmethod
    def select_all(cls, **kwargs):
        """
        一:支持单条查询或查询所有
        sql语句:select * from 表名 where id(字段) = 1(查询id)
                select * from 表名(查询所有)
        二:
            当类调用这个方法时,会把调用的类名自动传入
            cls = 调用的类的名
            **接受关键字参数(例如id= 1),赋值给kwargs
        """
        ms = mysql_singleton.Mysql().singleton()
        if kwargs:
            """
            如果有传入的参数,取出参数,拼接sql语句
            """
            key = list(kwargs.keys())[0]
            value = kwargs[key]
            sql = "select * from %s where %s=?" % (cls.t_name, key)
            sql = sql.replace('?', '%s')
            re = ms.select(sql, value)
        else:
            sql = "select * from %s" % cls.t_name
            re = ms.select(sql)
            # 用fetchall方法取得SQL语句执行结果(字典形式)
        return [cls(**r) for r in re]
        # for循环取出各个元素(字典形式),打散成如id = 1 ,name = "ls"

    @classmethod
    def select_one(cls, **kwargs):
        key = list(kwargs.keys())[0]
        value = kwargs[key]
        ms = mysql_singleton.Mysql().singleton()
        sql = "select * from %s where %s=?" % (cls.t_name, key)
        sql = sql.replace('?', '%s')
        re = ms.select(sql, value)
        if re:
            return cls(**re[0])
        else:
            return None

    def save(self):
        """
        给表中插入数据
        sql语句:insert into t1(id,name) values(1,"egon")
        :return:
        """
        ms = mysql_singleton.Mysql().singleton()
        fields = []
        params = []
        args = []
        for k, v in self.mapping.items():
            """
            self指的是表,k为字段,v为字段的约束条件(Field的实例化)
            """
            fields.append(v.name)  # ['id','name']
            params.append('?')  # ['?','?']
            args.append(getattr(self, k, v.default))  # ['1','egon']
            # 返回对象(表)的属性(字段)k的值,没有则返回默认值v.default
            # getattr(object, name[, default])函数用于返回一个对象属性值。
        sql = "insert into %s (%s) values (%s)" % (self.table_name, ','.join(fields), ','.join(params))
        # join()方法用于将序列中的元素以指定的字符连接生成一个新的字符串。
        sql = sql.replace('?', '%s')
        ms.execute(sql, args)

    def update(self):
        """
        sql语句:update t1 set name = "ls",id = 2 where name = "egon"
        :return:
        """
        ms = mysql_singleton.Mysql().singleton()
        fields = []
        args = []
        pr = None
        for k, v in self.mapping.items():
            if v.primary_key:
                pr = getattr(self, k, v.default)
                # 返回对象(表)的属性(字段)k的值,没有则返回默认值
                # getattr(object, name[, default])函数用于返回一个对象属性值。
            else:
                fields.append(v.name + '=?')
                args.append(getattr(self, k, v.default))
        sql = "update %s set %s where %s = %s" % (
            self.table_name, ', '.join(fields), self.primary_key, pr)
        sql = sql.replace('?', '%s')
        print(sql)
        ms.execute(sql, args)


class User(Model):
    # User类对应着一个表,User类的实例化对象对应表中的一条记录
    t_name = "user"
    id = IntegerField('id', primary_key=True)
    name = StringField("name")


if __name__ == '__main__':
    # 数据查询
    user = User.select_one(id=1)
    print(user)
    print(user.name)
    # 数据插入
    user = User(name='ls')
    user.save()
    # 数据修改
    user = User.select_one(id=1)
    user.name = 'mysql'
    user.update()

附:sql文件

SET foreign_key_checks = 0;
DROP TABLE IF EXISTS userinfo,movie,notice,download_record;
CREATE TABLE userinfo (
    id INT PRIMARY KEY NOT NULL auto_increment,
    `name` VARCHAR (32),
    `password` VARCHAR (64),
    is_vip INT,
    locked INT,
    user_type VARCHAR (32)
) ENGINE = INNODB,charset = 'utf8';
CREATE TABLE movie (
    id INT PRIMARY KEY NOT NULL auto_increment,
    `name` VARCHAR (32),
    `path` VARCHAR (255),
    is_free INT DEFAULT 0,
    is_delete INT DEFAULT 0,
    create_time timestamp default current_timestamp,
    user_id int,
    file_md5 VARCHAR (64)
) charset = 'utf8';
create table notice(
    id int not null primary key auto_increment,
    `name` varchar(64),
    content varchar(255),
    create_time timestamp  default current_timestamp,
    user_id int
)charset = 'utf8';
create table download_record(
    id int not null PRIMARY key auto_increment,
    user_id int,
    movie_id int
)charset = 'utf8';
posted @ 2018-06-07 21:32  Leslie-x  阅读(153)  评论(0编辑  收藏  举报