简单封装 DBUtils 和 pymysql,并实现简单的逆向工程:生成 class 类的 py 文件

  这里使用的 Python 版本是:Python 3.6.0b2。

  涉及的三方库:DBUtils、pymysql

1.ConfigurationParser

  通过调用Python内置的 xml.dom.minidom 对 xml 文件进行解析,获取xml内容。此类用于下面 BaseDao 获取数据库连接信息。

 1 import sys
 2 import re
 3 import pymysql
 4 import xml.dom.minidom
 5 
 6 from xml.dom.minidom import parse
 7 
 8 class ConfigurationParser(object):
 9     """
10     解析xml
11     - @return configDict = {"jdbcConnectionDictList":jdbcConnectionDictList,"tableList":tableList}
12     """
13     def __init__(self, configFilePath=None):
14         if configFilePath:
15             self.__configFilePath = configFilePath
16         else:
17             self.__configFilePath = sys.path[0] + "/config/config.xml"
18         pass
19 
20     def parseConfiguration(self):
21         """
22         解析xml,返回jdbc配置信息以及需要生成python对象的表集合
23         """
24         # 解析xml文件,获取Document对象
25         DOMTree = xml.dom.minidom.parse(self.__configFilePath)    # <class 'xml.dom.minidom.Document'>
26         # 获取 generatorConfiguration 节点的NodeList对象
27         configDOM = DOMTree.getElementsByTagName("generatorConfiguration")[0]  #<class 'xml.dom.minicompat.NodeList'>
28 
29         # 获取 jdbcConnection 节点的 property 节点集合
30         jdbcConnectionPropertyList = configDOM.getElementsByTagName("jdbcConnection")[0].getElementsByTagName("property")
31         # 循环 jdbcConnection 节点的 property 节点集合,获取属性名称和属性值
32         jdbcConnectionDict = {}
33         for property in jdbcConnectionPropertyList:
34             name = property.getAttributeNode("name").nodeValue.strip().lower()
35             if property.hasAttribute("value"):
36                 value = property.getAttributeNode("value").nodeValue
37                 if re.match("[0-9]",value) and name != "password" and name != "host":
38                     value = int(value)
39             else:
40                 value = property.childNodes[0].data
41                 if re.match("[0-9]",value) and name != "password" and name != "host":
42                     value = int(value)
43             if name == "charset":
44                 if re.match("utf-8|UTF8", value, re.I):
45                     value = "utf8"
46             elif name == "port":
47                 value = int(value)
48             elif name == "creator":
49                 if value == "pymysql":
50                     value = pymysql
51             jdbcConnectionDict[name] = value
52         # print(jdbcConnectionDict)
53         return jdbcConnectionDict
54 
if __name__ == "__main__":
    # Running this module directly prints the parsed connection settings.
    parsedConfig = ConfigurationParser().parseConfiguration()
    print(parsedConfig)

  config.xml

 1 <?xml version="1.0" encoding="utf-8"?>
 2 <generatorConfiguration>
 3     <jdbcConnection>
 4         <property name="creator">pymysql</property>
 5         <property name="host">127.0.0.1</property>
 6         <property name="database">rcddup</property>
 7         <property name="port">3306</property>
 8         <property name="user">root</property>
 9         <property name="password">root</property>
10         <property name="charset">Utf-8</property>
11         <property name="mincached">0</property>
12         <property name="maxcached">10</property>
13         <property name="maxshared">0</property>
14         <property name="maxconnections">20</property>
15     </jdbcConnection>
16 </generatorConfiguration>

2.BaseDao

  BaseDao是在 DBUtils 的基础上对 pymysql 操作数据库进行了一些简单的封装。

  其中 queryUtil 用于拼接 SQL 语句,log4py 用于在控制台输出信息,page 为分页对象。

  由于在 DBUtils 连接上执行 SQL 查询得到的是元组形式的结果,BaseDao 在查询返回后利用 setattr() 按列名把每一行的值写入实体类(如 User)的私有属性,从而把查询结果转换成想要的类对象。此外,insert、delete、updateByPrimaryKey 会通过 json.loads(str(obj)) 读取实体的字段,因此实体类的 __str__() 需要按照特定格式重写,返回包含全部字段的 JSON 字符串(下文会给出 User 类的代码示例)。

  1 import pymysql
  2 import time
  3 import json
  4 
  5 from DBUtils.PooledDB import PooledDB
  6 from configParser import ConfigurationParser
  7 from queryUtil import QueryUtil
  8 from log4py import Logger
  9 from page import Page
 10 
 11 
 12 global PRIMARY_KEY_DICT_LIST
 13 PRIMARY_KEY_DICT_LIST = []
 14 
 15 class BaseDao(object):
 16     """
 17     Python 操作数据库基类方法
 18     - @Author RuanCheng
 19     - @UpdateDate 2017/5/17
 20     """
 21     __logger = None
 22     __parser = None                 # 获取 xml 文件信息对象
 23     __poolConfigDict = None         # 从 xml 中获取的数据库连接信息的字典对象
 24     __pool = None                   # 数据库连接池
 25     __obj = None                    # 实体类
 26     __className = None              # 实体类类名
 27     __tableName = None              # 实体类对应的数据库名
 28     __primaryKeyDict = {}           # 数据库表的主键字典对象
 29     __columnList = []
 30 
 31     def __init__(self, obj=None):
 32         """
 33         初始化方法:
 34         - 1.初始化配置信息
 35         - 2.初始化 className
 36         - 3.初始化数据库表的主键
 37         """
 38         if not obj:
 39             raise Exception("BaseDao is missing a required parameter --> obj(class object).\nFor example [super().__init__(User)].")
 40         else:
 41             self.__logger = Logger(self.__class__)                                      # 初始化日志对象
 42             self.__logger.start()                                                       # 开启日志
 43             if not self.__parser:                                                       # 解析 xml
 44                 self.__parser = ConfigurationParser()
 45                 self.__poolConfigDict = self.__parser.parseConfiguration()
 46                 print(self.__poolConfigDict)
 47                 self.__pool = PooledDB(**self.__poolConfigDict)
 48             # 初始化参数
 49             if (self.__obj == None) or ( self.__obj != obj):
 50                 global PRIMARY_KEY_DICT_LIST
 51                 if (not PRIMARY_KEY_DICT_LIST) or (PRIMARY_KEY_DICT_LIST.count == 0):
 52                     self.__init_primary_key_dict_list()                                 # 初始化主键字典列表
 53                 self.__init_params(obj)                                                 # 初始化参数
 54                 self.__init_columns()                                                   # 初始化字段列表
 55                 self.__logger.end()                                                     # 结束日志
 56         pass
 57     ################################################# 外部调用方法 #################################################
 58     def selectAll(self):
 59         """
 60         查询所有
 61         """
 62         sql = QueryUtil.queryAll(self.__tableName, self.__columnList)
 63         return self.__executeQuery(sql)
 64 
 65     def selectByPrimaryKey(self, value):
 66         """
 67         按主键查询
 68         - @Param: value 主键
 69         """
 70         if (not value) or (value == ""):
 71             raise Exception("selectByPrimaryKey() is missing a required paramter 'value'.")
 72         sql = QueryUtil.queryByPrimaryKey(self.__primaryKeyDict, value, self.__columnList)
 73         return self.__executeQuery(sql)
 74 
 75     def selectCount(self):
 76         """
 77         查询总记录数
 78         """
 79         sql = QueryUtil.queryCount(self.__tableName);
 80         return self.__execute(sql)[0][0]
 81 
 82     def selectAllByPage(self, page=None):
 83         """
 84         分页查询
 85         """
 86         if (not page) or (not isinstance(page,Page)):
 87             raise Exception("Paramter [page] is not correct. Parameter [page] must a Page object instance. ")
 88         sql = QueryUtil.queryAllByPage(self.__tableName, self.__columnList, page)
 89         return self.__executeQuery(sql, logEnable=True)
 90 
 91     def insert(self, obj):
 92         """
 93         新增
 94         - @Param: obj 实体对象
 95         """
 96         if (not obj) or (obj == ""):
 97             raise Exception("insert() is missing a required paramter 'obj'.")
 98         sql = QueryUtil.queryInsert(self.__primaryKeyDict, json.loads(str(obj)))
 99         return self.__executeUpdate(sql)
100     
101     def delete(self, obj=None):
102         """
103         根据实体删除
104         - @Param: obj 实体对象
105         """
106         if (not obj) or (obj == ""):
107             raise Exception("delete() is missing a required paramter 'obj'.")
108         sql = QueryUtil.queryDelete(self.__primaryKeyDict, json.loads(str(obj)))
109         return self.__executeUpdate(sql)
110 
111     def deleteByPrimaryKey(self, value=None):
112         """
113         根据主键删除
114         - @Param: value 主键
115         """
116         if (not value) or (value == ""):
117             raise Exception("deleteByPrimaryKey() is missing a required paramter 'value'.")
118         sql = QueryUtil.queryDeleteByPrimaryKey(self.__primaryKeyDict, value)
119         return self.__executeUpdate(sql)
120     
121     def updateByPrimaryKey(self, obj=None):
122         """
123         根据主键更新
124         - @Param: obj 实体对象
125         """
126         if (not obj) or (obj == ""):
127             raise Exception("updateByPrimaryKey() is missing a required paramter 'obj'.")
128         sql = QueryUtil.queryUpdateByPrimaryKey(self.__primaryKeyDict, json.loads(str(obj)))
129         return self.__executeUpdate(sql)
130 
131     ################################################# 内部调用方法 #################################################
132     def __execute(self, sql="", logEnable=True):
133         """
134         执行 SQL 语句(用于内部初始化参数使用):
135         - @Param: sql 执行sql
136         - @Param: logEnable 是否开启输出日志
137         - @return 查询结果
138         """
139         if not sql:
140             raise Exception("Execute method is missing a required parameter --> sql.")
141         try:
142             self.__logger.outSQL(sql, enable=logEnable)
143             conn = self.__pool.connection()
144             cur = conn.cursor()
145             cur.execute(sql)
146             result = cur.fetchall()
147             resultList = []
148             for r in result:
149                 resultList.append(r)
150             return resultList
151         except Exception as e:
152             conn.rollback()
153             raise Exception(e)
154         finally:
155             cur.close()
156             conn.close()
157             pass
158 
159     def __executeQuery(self, sql="", logEnable=True):
160         """
161         执行查询 SQL 语句:
162         - @Param: sql 执行sql
163         - @Param: logEnable 是否开启输出日志
164         - @return 查询结果
165         """
166         if not sql:
167             raise Exception("Execute method is missing a required parameter --> sql.")
168         try:
169             self.__logger.outSQL(sql, enable=logEnable)
170             conn = self.__pool.connection()
171             cur = conn.cursor()
172             cur.execute(sql)
173             resultTuple = cur.fetchall()
174             resultList = list(resultTuple)
175             objList = []
176             
177             for result in resultList:
178                 i = 0
179                 obj = self.__obj()
180                 for col in self.__columnList:
181                     prop = '_%s__%s'%(self.__className, col)
182                     setattr(obj, prop, result[i])
183                     i += 1
184                 objList.append(obj)
185             if not objList:
186                 return None
187             elif len(objList) == 1:
188                 return objList[0]
189             else:
190                 return objList
191         except Exception as e:
192             conn.rollback()
193             raise Exception(e)
194         finally:
195             cur.close()
196             conn.close()
197             pass
198     
199     def __executeUpdate(self, sql=None, logEnable=True):
200         """
201         执行修改 SQL 语句:
202         - @Param: sql 执行sql
203         - @Param: logEnable 是否开启输出日志
204         - @return 影响行数
205         """
206         try:
207             self.__logger.outSQL(sql, enable=logEnable)
208             conn = self.__pool.connection()
209             cur = conn.cursor()
210             return cur.execute(sql)
211             pass
212         except Exception as e:
213             conn.rollback()
214             raise Exception(e)
215             pass
216         finally:
217             conn.commit()
218             cur.close()
219             conn.close()
220             pass
221 
222     def __init_params(self, obj):
223         """
224         初始化参数
225         - @Param:obj class 对象
226         """
227         self.__obj = obj
228         self.__className = obj.__name__
229         for i in PRIMARY_KEY_DICT_LIST:
230             if i.get("className") == self.__className:
231                 self.__primaryKeyDict = i
232                 self.__className = i["className"]
233                 self.__tableName = i["tableName"]
234                 break
235 
236     def __init_primary_key_dict_list(self):
237         """
238         初始化数据库主键集合:
239         - pk_dict = {"className": {"tableName":tableName,"primaryKey":primaryKey,"auto_increment":auto_increment}}
240         """
241         global PRIMARY_KEY_DICT_LIST
242         sql = """
243             SELECT
244                 t.TABLE_NAME,
245                 c.COLUMN_NAME,
246                 c.ORDINAL_POSITION
247             FROM
248                 INFORMATION_SCHEMA.TABLE_CONSTRAINTS as t,
249                 INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS c
250             WHERE t.TABLE_NAME = c.TABLE_NAME
251                 AND t.TABLE_SCHEMA = "%s"
252                 AND c.CONSTRAINT_SCHEMA = "%s"
253         """%(self.__poolConfigDict.get("database"),self.__poolConfigDict.get("database"))
254         resultList = self.__execute(sql, logEnable=False)
255         for result in resultList:
256             pk_dict = dict()
257             pk_dict["tableName"] = result[0]
258             pk_dict["primaryKey"] = result[1]
259             pk_dict["ordinalPosition"] = result[2]
260             pk_dict["className"] = self.__convertToClassName(result[0])
261             PRIMARY_KEY_DICT_LIST.append(pk_dict)
262         self.__logger.outMsg("initPrimaryKey is done.")
263 
264     def __init_columns(self):
265         """
266         初始化表字段
267         """
268         sql = "SELECT column_name FROM  Information_schema.columns WHERE table_Name = '%s' AND TABLE_SCHEMA='%s'"
            
%(self.__tableName, self.__poolConfigDict["database"])
269 resultList = self.__execute(sql, logEnable=False) 270 for result in resultList: 271 self.__columnList.append(result[0]) 272 self.__logger.outMsg("init_columns is done.") 273 # print(self.__columnList) 274 pass 275 276 def __convertToClassName(self, tableName): 277 """ 278 表名转换方法: 279 - @Param: tableName 表名 280 - @return 转换后的类名 281 """ 282 result = None 283 if tableName.startswith("t_md_"): 284 result = tableName.replace("t_md_", "").replace("_","").lower() 285 elif tableName.startswith("t_ac_"): 286 result = tableName.replace("t_ac_","").replace("_","").lower() 287 elif tableName.startswith("t_"): 288 result = tableName.replace("t_","").replace("_","").lower() 289 else: 290 result = tableName 291 return result.capitalize()

 3.简单应用 UserDao

  创建一个 UserDao,继承 BaseDao 后调用父类初始化方法并把 User 类传给父类,就可以很方便地对 User 进行 CRUD 操作了。

 1 import random
 2 import math
 3 
 4 from baseDao import BaseDao
 5 from user import User
 6 from page import Page
 7 
 8 
 9 class UserDao(BaseDao):
10 
11     def __init__(self):
12         super().__init__(User)
13         pass
14 
userDao = UserDao()

######################################## CRUD examples (uncomment to try)

# print(userDao.selectAll())
# user = userDao.selectByPrimaryKey(1)
# print(user)

# print(userDao.insert(user))

# print(userDao.delete(user))
# print(userDao.deleteByPrimaryKey(4))

# user = userDao.selectByPrimaryKey(1)
# print(userDao.updateByPrimaryKey(user))

######################################## Update many rows by primary key

# strList = list("ABCDEFGHJKLMNPQRSTUVWXYZ张王李赵刘陈杨黄周吴")
# for index in range(1000):
#     user = User()
#     user.set_id(index + 1)
#     user.set_name("".join(random.choice(strList) for _ in range(random.randint(3, 8))))
#     user.set_status(1)
#     userDao.updateByPrimaryKey(user)

######################################## Update a single row

# user = User()
# user.set_id(2)
# user.set_name("测试更新")
# userDao.updateByPrimaryKey(user)

######################################## Paged query

# page = Page()
# pageNum = 1
# limit = 10
# page.set_page(pageNum)
# page.set_limit(limit)
# total_count = userDao.selectCount()
# page.set_total_count(total_count)
# page.set_total_page(math.ceil(total_count / limit))
# # QueryUtil.queryAllByPage reads the row offset from page.get_begin(), so it must be set.
# page.set_begin((pageNum - 1) * limit)
#
# # selectAllByPage returns None, a single User, or a list of Users.
# users = userDao.selectAllByPage(page)
# if not isinstance(users, list):
#     users = [users] if users else []
# for user in users:
#     print(user)

4. User

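  User 的属性设置为私有,通过 get/set 方法访问;最后重写 __str__() 方法,返回包含全部字段的 JSON 字符串,BaseDao 的 insert、delete、updateByPrimaryKey 会通过 json.loads(str(obj)) 解析它。查询时 BaseDao 则通过 setattr() 直接写入这些私有属性,因此返回的是 User 对象,而不是字典或字符串。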

 1 import json
 2 
 3 class User(object):
 4 
 5     def __init__(self):
 6         self.__id = None
 7         self.__name = None
 8         self.__status = None
 9         pass
10 
11     def get_id(self):
12         return self.__id
13 
14     def set_id(self, id):
15         self.__id = id
16 
17     def get_name(self):
18         return self.__name
19 
20     def set_name(self, name):
21         self.__name = name
22 
23     def get_status(self):
24         return self.__status
25 
26     def set_status(self, status):
27         self.__status = status
28 
29 
30     def __str__(self):
31         userDict = {'id':self.__id,'name':self.__name,'status':self.__status}
32         return json.dumps(userDict)

5.QueryUtil

  拼接 SQL 语句的工具类。

  1 from page import Page
  2 
  3 class QueryUtil(object):
  4 
  5     def __init__(self):
  6         pass
  7     
  8     @staticmethod
  9     def queryColumns(columnList):
 10         i = 1
 11         s = ""
 12         for col in columnList:
 13             if i != 1:
 14                 s += ", `%s`"%(col)
 15             else:
 16                 s += "`%s`"%(col)
 17             i += 1
 18         return s
 19     @staticmethod    
 20     def queryByPrimaryKey(primaryKeyDict, value, columnList):
 21         """
 22         拼接主键查询
 23         """
 24         sql = 'SELECT %s FROM `%s` WHERE `%s`="%s"'%(QueryUtil.queryColumns(columnList), primaryKeyDict["tableName"], primaryKeyDict["primaryKey"], str(value))
 25         return sql
 26 
 27     @staticmethod
 28     def queryAll(tableName, columnList):
 29         """
 30         拼接查询所有
 31         """
 32         return 'SELECT %s FROM %s'%(QueryUtil.queryColumns(columnList), tableName)
 33 
 34     @staticmethod
 35     def queryCount(tableName):
 36         """
 37         拼接查询记录数
 38         """
 39         return 'SELECT COUNT(*) FROM %s'%(tableName)
 40 
 41     @staticmethod
 42     def queryAllByPage(tableName, columnList, page=None):
 43         """
 44         拼接分页查询
 45         """
 46         if not page:
 47             page = Page()
 48         return 'SELECT %s FROM %s LIMIT %d,%d'%(QueryUtil.queryColumns(columnList), tableName, page.get_begin(), page.get_limit())
 49 
 50 
 51     @staticmethod
 52     def queryInsert(primaryKeyDict, objDict):
 53         """
 54         拼接新增
 55         """
 56         tableName = primaryKeyDict["tableName"]
 57         key = primaryKeyDict["primaryKey"]
 58         columns = list(objDict.keys())
 59         values = list(objDict.values())
 60 
 61         sql = "INSERT INTO `%s`("%(tableName)
 62         for i in range(0, columns.__len__()):
 63             if i == 0:
 64                 sql += '`%s`'%(columns[i])
 65             else:
 66                 sql += ',`%s`'%(columns[i])
 67         sql += ') VALUES('
 68         for i in range(0, values.__len__()):
 69             if values[i] == None or values[i] == "None":
 70                 value = "null"
 71             else:
 72                 value = '"%s"'%(values[i])
 73             if i == 0:
 74                 sql += value
 75             else:
 76                 sql += ',%s'%(value);
 77         sql += ')'
 78         return sql
 79     
 80     @staticmethod
 81     def queryDelete(primaryKeyDict, objDict):
 82         """
 83         拼接删除
 84         """
 86         tableName = primaryKeyDict["tableName"]
 87         key = primaryKeyDict["primaryKey"]
 88         columns = list(objDict.keys())
 89         values = list(objDict.values())
 90 
 91         sql = "DELETE FROM `%s` WHERE 1=1 "%(tableName)
 92         for i in range(0, values.__len__()):
 93             if values[i] != None and values[i] != "None":
 94                 sql += 'and `%s`="%s"'%(columns[i], values[i])
 95         return sql
 96 
 97     @staticmethod
 98     def queryDeleteByPrimaryKey(primaryKeyDict, value=None):
 99         """
100         拼接根据主键删除
101         """
103         sql = 'DELETE FROM `%s` WHERE `%s`="%s"'%(primaryKeyDict["tableName"], primaryKeyDict["primaryKey"], value)
104         return sql
105     
106     @staticmethod
107     def queryUpdateByPrimaryKey(primaryKeyDict, objDict):
108         """
109         拼接根据主键更新
110         UPDATE t_user SET name='test' WHERE id = 1007
111         """
112         tableName = primaryKeyDict["tableName"]
113         key = primaryKeyDict["primaryKey"]
114         columns = list(objDict.keys())
115         values = list(objDict.values())
116         keyValue = None
117         sql = "UPDATE `%s` SET"%(tableName)
118         for i in range(0, columns.__len__()):
119             if (values[i] != None) and (values[i] != "None"):
120                 if columns[i] != key:
121                     sql += ' `%s`="%s", '%(columns[i], values[i])
122                 else:
123                     keyValue = values[i]
124         sql = sql[0:len(sql)-2] + ' WHERE `%s`="%s"'%(key, keyValue)
125         return sql

6. Page

  分页对象

import json
import math

class Page(object):
    """
    Pagination state for BaseDao.selectAllByPage().

    QueryUtil.queryAllByPage() reads get_begin() (row offset) and get_limit()
    (page size); the other fields are informational and are serialised by
    __str__. Each setter ignores values below the field's minimum: page and
    total_page >= 1, total_count and begin >= 0, limit >= 1.
    """

    def __init__(self):
        self.__page = 1             # current page number (1-based)
        self.__total_page = 1
        self.__total_count = 0
        self.__begin = 0            # row offset used as LIMIT begin,limit
        self.__limit = 10           # rows per page
        self.__result = []

    def get_page(self):
        return self.__page

    def set_page(self, page):
        """Set the current page; values below 1 are ignored."""
        # `>= 1` rather than `> 1`: the old check made it impossible to go back to page 1.
        if page >= 1:
            self.__page = page

    def get_total_page(self):
        return self.__total_page

    def set_total_page(self, total_page):
        """Set the number of pages; values below 1 are ignored."""
        if total_page >= 1:
            self.__total_page = total_page

    def get_total_count(self):
        return self.__total_count

    def set_total_count(self, total_count):
        """Set the total row count; negative values are ignored (0 is allowed)."""
        if total_count >= 0:
            self.__total_count = total_count

    def get_begin(self):
        return self.__begin

    def set_begin(self, begin):
        """Set the row offset; negative values are ignored (0 is allowed)."""
        if begin >= 0:
            self.__begin = begin

    def get_limit(self):
        return self.__limit

    def set_limit(self, limit):
        """Set the page size; values below 1 are ignored."""
        if limit > 0:
            self.__limit = limit

    def get_result(self):
        return self.__result

    def set_result(self, result):
        self.__result = result

    def __str__(self):
        """Serialise the page as a JSON object."""
        pageDict = {'page': self.__page, 'total_page': self.__total_page, 'total_count': self.__total_count,
                    'begin': self.__begin, 'limit': self.__limit, 'result': self.__result}
        return json.dumps(pageDict)

 

7.Logger

  用于在控制台简单地输出日志信息。

 1 import time
 2 
 3 class Logger(object):
 4 
 5     def __init__(self, obj):
 6         self.__obj = obj
 7         self.__start = None
 8         pass
 9     
10     def start(self):
11         self.__start = time.time()
12         pass
13 
14     def end(self):
15         print("%s >>> [%s] Finished [Time consuming %dms]"%(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), self.__obj.__name__, time.time()-self.__start))
16         pass
17 
18     def outSQL(self, msg, enable=True):
19         """
20         输出 SQL 日志:
21         - @Param: msg SQL语句
22         - @Param: enable 日志开关
23         """
24         if enable:
25             print("%s >>> [%s] [SQL] %s"%(str(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())), self.__obj.__name__, msg))
26         pass
27     
28     def outMsg(self, msg, enable=True):
29         """
30         输出消息日志:
31         - @Param: msg 日志信息
32         - @Param: enable 日志开关
33         """
34         if enable:
35             print("%s >>> [%s] [Msg] %s"%(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), self.__obj.__name__, msg))
36         pass
37     
38         

8.Generator

  为了便于创建 user.py 文件,此处提供了自动生成方法:只需在配置文件中简单地配置数据库连接信息以及要生成的表,即可生成对应实体类的 py 文件。

  目前只实现了类对象文件的创建。

  1 import sys
  2 import re
  3 import pymysql
  4 import time
  5 import os
  6 import xml.dom.minidom
  7 
  8 from xml.dom.minidom import parse
  9 
 10 global _pythonPath
 11 global _daoPath
 12 global _servicePath
 13 global _controllerPath
 14 
 15 class Generator(object):
 16     """
 17     # python类生成器
 18     @param configDict 配置文件信息的字典对象
 19     """
 20     def __init__(self, configFilePath=None):
 21         if not configFilePath:
 22             self.__configDict = ConfigurationParser().parseConfiguration()
 23         else:
 24             if os.path.isabs(configFilePath):
 25                 self.__configDict = ConfigurationParser(configFilePath).parseConfiguration()
 26             else:
 27                 configFilePath = configFilePath.replace(".", sys.path[0])
 28             pass
 29     
 30     def run(self):
 31         """
 32         # 生成器执行方法
 33         """
 34         fieldDict = DBOperator(self.__configDict).queryFieldDict()
 35         PythonGenarator(self.__configDict, fieldDict).run()
 36         # DaoGenarator(self.__configDict).run()
 37         # ServiceGenarator(self.__configDict).run()
 38         # ControllerGenarator(self.__configDict).run()
 39         
 40 
 41 class PythonGenarator(object):
 42     """
 43     # pyEntity文件生成类
 44     @param configDict 配置文件信息的字典对象
 45     """
 46     def __init__(self, configDict, fieldDict):
 47         self.__configDict = configDict
 48         self.__fieldDict = fieldDict
 49         self.__content = ""
 50         pass
 51     
 52     def run(self):
 53         """
 54         执行 py 类生成方法
 55         """
 56         for filePath in self.__configDict["pythonPathList"]:
 57             if not os.path.exists(filePath):
 58                 os.makedirs(os.path.dirname(filePath), exist_ok=True)
 59             # 获取表名
 60             fileName = os.path.basename(filePath).split(".py")[0]
 61             # 表名(首字母大写)
 62             ClassName = fileName.capitalize()
 63             # 打开新建文件
 64             file = open(filePath, "w", encoding="utf-8")
 65             self.writeImport(file)                                  # 生成 import 内容
 66             self.writeHeader(file, ClassName)                       # 生成 class 头部内容
 67             self.writeInit(file, fileName, ClassName)               # 生成 class 的 init 方法
 68             tableDictString = self.writeGetSet(file, fileName)      # 生成 get/set 方法,并返回一个类属性的字典对象
 69             self.writeStr(file, fileName, tableDictString)          # 重写 class 的 str 方法
 70             file.write(self.__content)
 71             file.close()
 72             print("Generator --> %s"%(filePath))
 73         pass
 74 
 75     def writeImport(self,file ,importList = None):
 76         """
 77         # 写import部分
 78         """
 79         self.__content += "import json\r\n"
 80         pass
 81     
 82     def writeHeader(self, file, className, superClass = None):
 83         """
 84         # 写类头部(class ClassName(object):)
 85         """
 86         if not superClass:
 87             self.__content += "class %s(object):\r\n"%(className)
 88         else:
 89             self.__content += "class %s(%s):\r\n"%(className, superClass)
 90         pass
 91         
 92     def writeInit(self, file, fileName, className):
 93         """
 94         # 写类初始化方法
 95         """
 96         self.__content += "\tdef __init__(self):\n\t\t"
 97         for field in self.__fieldDict[fileName]:
 98             self.__content += "self.__%s = None\n\t\t"%(field)
 99         self.__content += "pass\r\n"
100         pass
101     
102     def writeGetSet(self, file, fileName):
103         """
104         # 写类getXXX(),setXXX()方法
105         @return tableDictString 表属性字典的字符串对象,用于写__str__()方法
106         """
107         tableDictString = ""
108         i = 1
109         for field in self.__fieldDict[fileName]:
110             if i != len(self.__fieldDict[fileName]):
111                 tableDictString += "'%s':self.__%s,"%(field,field)
112             else:
113                 tableDictString += "'%s':self.__%s"%(field,field)
114             Field = field.capitalize()
115             self.__content += "\tdef get_%(field)s(self):\n\t\treturn self.__%(field)s\n\n\tdef set_%(field)s(self, %(field)s):\n\t\tself.__%(field)s = %(field)s\n\n"%({"field":field})
116             i += 1
117         return tableDictString
118     
119     def writeStr(self, file, fileName, tableDictString):
120         """
121         # 重写__str__()方法
122         """
123         tableDictString = "{" + tableDictString + "}"
124         self.__content += "\n\tdef __str__(self):\n\t\t%sDict = %s\r\t\treturn json.dumps(%sDict)\n"%(fileName, tableDictString, fileName)
125         pass
126 
127 class DaoGenarator(object):
128     """
129     # pyDao文件生成类
130     @param configDict 配置文件信息的字典对象
131     """
132     def __init__(self, configDict):
133         self.__configDict = configDict
134         pass
135     
136     def run(self):
137         pass
138 
139 class ServiceGenarator(object):
140     """
141     # pyService文件生成类
142     @param configDict 配置文件信息的字典对象
143     """
144     def __init__(self, configDict):
145         self.__configDict = configDict
146         pass
147     
148     def run(self):
149         pass
150 
151 class ControllerGenarator(object):
152     """
153     # pyControlelr生成类
154     @param configDict 配置文件信息的字典对象
155     """
156     def __init__(self, configDict):
157         self.__configDict = configDict
158         pass
159     
160     def run(self):
161         pass
162 
163 class ConfigurationParser(object):
164     """
165     解析xml\n
166     @return configDict = {"jdbcConnectionDictList":jdbcConnectionDictList,"tableList":tableList}
167     """
168     def __init__(self, configFilePath=None):
169         if configFilePath:
170             self.__configFilePath = configFilePath
171         else:
172             self.__configFilePath = sys.path[0] + "/config/generatorConfig.xml"
173         self.__generatorBasePath = sys.path[0] + "/src/"
174         pass
175 
176     def parseConfiguration(self):
177         """
178         解析xml,返回jdbc配置信息以及需要生成python对象的表集合
179         """
180         # 解析xml文件,获取Document对象
181         DOMTree = xml.dom.minidom.parse(self.__configFilePath)    # <class 'xml.dom.minidom.Document'>
182         # 获取 generatorConfiguration 节点的NodeList对象
183         configDOM = DOMTree.getElementsByTagName("generatorConfiguration")[0]  #<class 'xml.dom.minicompat.NodeList'>
184 
185         # jdbcConnection 节点的 property 节点集合
186         jdbcConnectionPropertyList = configDOM.getElementsByTagName("jdbcConnection")[0].getElementsByTagName("property")
187 
188         # pythonGenerator节点对象
189         pythonDOM = configDOM.getElementsByTagName("pythonGenerator")[0]
190         _pythonPath = self.__getGeneratorPath(pythonDOM.getAttributeNode("targetPath").nodeValue)
191 
192         # serviceGenerator 节点对象
193         serviceDOM = configDOM.getElementsByTagName("serviceGenerator")[0]
194         _servicePath = self.__getGeneratorPath(serviceDOM.getAttributeNode("targetPath").nodeValue)
195         
196 
197         # pythonGenerator节点对象
198         daoDOM = configDOM.getElementsByTagName("daoGenerator")[0]
199         _daoPath = self.__getGeneratorPath(daoDOM.getAttributeNode("targetPath").nodeValue)
200 
201         # controllerGenerator 节点对象
202         controllerDOM = configDOM.getElementsByTagName("controllerGenerator")[0]
203         _controllerPath = self.__getGeneratorPath(controllerDOM.getAttributeNode("targetPath").nodeValue)
204         
205         # 循环 jdbcConnection 节点的 property 节点集合,获取属性名称和属性值
206         jdbcConnectionDict = {"host":None,"user":None,"password":None,"port":3306,"database":None,"charset":"utf8"}
207         for property in jdbcConnectionPropertyList:
208             name = property.getAttributeNode("name").nodeValue.strip().lower()
209             if property.hasAttribute("value"):
210                 value = property.getAttributeNode("value").nodeValue
211             else:
212                 value = property.childNodes[0].data
213             if name == "charset":
214                 if re.match("utf-8|UTF8", value, re.I):
215                     continue
216             elif name == "port":
217                 value = int(value)
218             jdbcConnectionDict[name] = value
219         # print(jdbcConnectionDict)
220 
221         
222         pythonPathList = []
223         daoPathList = []
224         servicePathList = []
225         controllerPathList = []
226 
227         # 获取 table 节点的集合
228         tableList = []
229         tableDOMList = configDOM.getElementsByTagName("table")
230         for tableDOM in tableDOMList:
231             table = {}
232             name = tableDOM.getAttributeNode("name").nodeValue.strip().lower()
233             alias = tableDOM.getAttributeNode("alias").nodeValue.strip().lower()
234             if (not alias) or alias == '' :
235                 prefix = name
236             else:
237                 prefix = alias
238             table["tableName"] = name
239             table["alias"] = alias
240             tableList.append(table)
241 
242 
243             pythonPath = "%s/%s.py" %(_pythonPath, prefix)
244             pythonPathList.append(pythonPath)
245             daoPath = "%s/%sDao.py" %(_daoPath, prefix)
246             daoPathList.append(daoPath)
247             servicePath = "%s/%sService.py" %(_servicePath, prefix)
248             servicePathList.append(servicePath)
249             controllerPath = "%s/%sController.py" %(_controllerPath, prefix)
250             controllerPathList.append(controllerPath)
251 
252         configDict = {
253                         "jdbcConnectionDict":jdbcConnectionDict,
254                         "tableList":tableList,
255                         "pythonPathList":pythonPathList,
256                         "daoPathList":daoPathList,
257                         "servicePathList":servicePathList,
258                         "controllerPathList":controllerPathList
259                     }
260         # print(configDict)
261         return configDict
262     
263     def __getGeneratorPath(self, targetPath):
264         return self.__generatorBasePath + targetPath.replace(".","/")
265 
class DBOperator(object):
    """Read table column metadata from MySQL for the code generator."""

    def __init__(self, configDict=None):
        """
        :param configDict: dict as returned by ``parseConfiguration()``; must
            provide ``jdbcConnectionDict`` (keyword arguments for
            ``pymysql.Connect``, including ``database``) and ``tableList``.
        :raises Exception: if ``configDict`` is None.
        """
        if configDict is None:
            raise Exception("Error in DBOperator >>> jdbcConnectionDict is None")
        self.__configDict = configDict

    def queryFieldDict(self):
        """
        Look up the column names of every configured table.

        Opens one connection, runs one information_schema query per table and
        always closes the connection, even if a query fails.

        :return: dict mapping each table's ``alias`` to the list of its column
            names, lower-cased, in column (ordinal) order.
        """
        jdbcConnectionDict = self.__configDict["jdbcConnectionDict"]
        # Schema and table name are passed as query parameters so pymysql escapes
        # them; ORDER BY makes the generated field order deterministic.
        sql = """SELECT COLUMN_NAME AS name, DATA_TYPE AS type
                 FROM information_schema.columns
                 WHERE table_schema = %s AND table_name = %s
                 ORDER BY ORDINAL_POSITION"""
        fieldDict = {}
        conn = pymysql.Connect(**jdbcConnectionDict)
        try:
            for table in self.__configDict["tableList"]:
                # The cursor context manager closes the cursor even if execute() raises.
                with conn.cursor() as cursor:
                    cursor.execute(sql, (jdbcConnectionDict["database"], table["tableName"]))
                    results = cursor.fetchall()
                fieldDict[table["alias"]] = [row[0].lower() for row in results]
        finally:
            conn.close()
        return fieldDict
313 
if __name__ == "__main__":
    # Script entry point: build the generator and run it with its default configuration.
    generator = Generator()
    generator.run()

  generatorConfig.xml

 1 <?xml version="1.0" encoding="utf-8"?>
 2 <generatorConfiguration>
 3     <jdbcConnection>
 4         <property name="host">127.0.0.1</property>
 5         <property name="database">rcddup</property>
 6         <property name="port">3306</property>
 7         <property name="user">root</property>
 8         <property name="password">root</property>
 9         <property name="charset">UTF-8</property>
10     </jdbcConnection>
11     <!-- targetPath 文件生成路径 -->
12     <pythonGenerator targetPath="cn.rcddup.entity"></pythonGenerator>
13     <daoGenerator targetPath="cn.rcddup.dao"></daoGenerator>
14     <serviceGenerator targetPath="cn.rcddup.service"></serviceGenerator>
15     <controllerGenerator targetPath="cn.rcddup.controller"> </controllerGenerator>
16 
17     <!-- name:数据库表名,alias:生成的 class 类名 -->
18     <table name="t_user" alias="User" ></table>
19 </generatorConfiguration>

  到这里,最近一段时间的 Python 学习成果就分享完了。有兴趣的可以加群:626787819。如果你是小白,可以来这里提问;如果你是大牛,希望不要嫌弃我们小白,一起交流学习。

  本程序代码在 github 上可以下载,下载地址:https://github.com/ruancheng77/baseDao
  

  创建于:2017-05-20

 

TABLE_SCHEMA='%s'
posted @ 2017-05-20 13:02  rcddup  阅读(2112)  评论(0)  编辑  收藏  举报