Python3 对象关系映射(ORM)

ORM

  • 对象关系映射 Object Relational Mapping
    • 表 ---> 类
    • 字段 ---> 属性
    • 记录 ---> 对象
# mysql_client.py

import pymysql


class MySQLClient:

    def __init__(self):
        # 建立连接
        self.client = pymysql.connect(
            host='localhost',
            port=3306,
            user='root',
            password='123',
            database='orm_demo',
            charset='utf8',
            autocommit=True
        )

        # 获取游标
        self.cursor = self.client.cursor(
            pymysql.cursors.DictCursor
        )

    # 提交查询sql语句并返回结果
    def my_select(self, sql, value=None):
        print('sql:', sql, '\nvalue:', value)
        self.cursor.execute(sql, value)
        res = self.cursor.fetchall()
        return res

    # 提交增加, 修改的sql语句
    def my_execute(self, sql, values):
        try:
            print('<sql>:', sql, '\n<values>:', values)
            self.cursor.execute(sql, values)
        except Exception as e:
            print(e)

    # 关闭连接
    def close(self):
        self.cursor.close()
        self.client.close()

# ORM.py
from mysql_client import MySQLClient

# 定义字段类
class Field:
    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        self.default = default


# 定义整数字段类
class IntegerField(Field):
    def __init__(self, name, column_type='int', primary_key=False, default=0):
        super().__init__(name, column_type, primary_key, default)


# 定义字符字段类
class StringField(Field):
    def __init__(self, name, column_type='varchar(64)', primary_key=False, default=''):
        super().__init__(name, column_type, primary_key, default)


# 定义元类
class OrmMetaClass(type):
    def __new__(cls, class_name, class_bases, class_dict):

        if class_name == 'Models':
            return super().__new__(cls, class_name, class_bases, class_dict)

        # 没有表名则默认等于类名
        table_name = class_dict.get('table_name', class_name)

        primary_key = None

        # 定义一个空字典, 用来存放字段对象
        mappings = {}

        for key, value in class_dict.items():
            # 筛选字段对象
            if isinstance(value, Field):
                mappings[key] = value

                if value.primary_key:
                    if primary_key:
                        raise TypeError('只能有一个主键!')

                    primary_key = value.name

        if not primary_key:
            raise TypeError('必须有一个主键!')

        # 删除class_dict中和mappings重复的字段属性
        for key in mappings.keys():
            class_dict.pop(key)

        # 将表名添加到class_dict中
        class_dict['table_name'] = table_name
        # 将主键添加到class_dict中
        class_dict['primary_key'] = primary_key
        # 将mappings添加到class_dict中
        class_dict['mappings'] = mappings

        return super().__new__(cls, class_name, class_bases, class_dict)


# 定义Models类
class Models(dict, metaclass=OrmMetaClass):
    def __getattr__(self, item):
        return self.get(item)

    def __setattr__(self, key, value):
        self[key] = value

    # 查询
    @classmethod
    def orm_select(cls, **kwargs):  # kwargs --> {'id': 1}

        mysql = MySQLClient()

        # select * from User
        if not kwargs:
            sql = 'select * from %s' % cls.table_name

            res = mysql.my_select(sql)

        else:
            key = list(kwargs.keys())[0]
            value = kwargs.get(key)

            sql = 'select * from %s where %s=?' % (cls.table_name, key)

            sql = sql.replace('?', '%s')

            res = mysql.my_select(sql, value)
            
            mysql.close()

            # 返回的是[{}]一个普通的字典(在列表内), 我们把这个字典传入Models类, 让其可以 "对象.属性"
        return [cls(**d) for d in res]  # d是一个字典

    # 增加
    def orm_insert(self):
        mysql = MySQLClient()

        # 存字段名
        keys = []
        # 存字段值
        values = []
        # 存 "?"
        args = []

        for k, v in self.mappings.items():
            # 过滤主键, 因为主键是自增的
            if not v.primary_key:
                # 字段名
                keys.append(v.name)

                # 字段值, 没有值则使用默认值
                values.append(
                    getattr(self, v.name, v.default)
                )

                # 有几个字段就存放几个"?"
                args.append('?')

        sql = 'insert into %s (%s) values (%s)' % (
            self.table_name,
            ','.join(keys),
            ','.join(args)
        )

        sql = sql.replace('?', '%s')
        print(sql)
        print(values)

        mysql.my_execute(sql, values)
        mysql.close()

    # 修改
    def orm_update(self):
        mysql = MySQLClient()

        # 存字段名
        keys = []
        # 存字段值
        values = []
        # 主键值
        primary_key = None

        for k, v in self.mappings.items():
            if v.primary_key:
                # 获取主键值
                primary_key = v.name + '= %s' % getattr(self, v.name)

            else:
                keys.append(v.name + '=?')
                values.append(
                    getattr(self, v.name)
                )

        sql = 'update %s set %s where %s' % (
            self.table_name,
            ','.join(keys),
            primary_key
        )

        # sql: update table set k1=%s, k2=%s where id=pk
        sql = sql.replace('?', '%s')

        mysql.my_execute(sql, values)
        mysql.close()


# 定义用户类
class User(Models):
    user_id = IntegerField('user_id', primary_key=True)
    user_name = StringField('user_name')
    password = StringField('password')


user_obj = User.orm_select(user_id=1)[0]  # 查询
print(user_obj)  # {'user_id': 1, 'user_name': 'blake', 'password': '123'}
user_obj.user_name = 'bigb'  # 修改user_name
user_obj.orm_update()  # 将修改提交到数据库
posted @ 2019-11-05 20:12  MrBigB  阅读(851)  评论(0编辑  收藏  举报