我的第一个python web开发框架(32)——定制ORM(八)

  写到这里,基本的ORM功能就完成了,不知大家有没有发现,这个ORM每个方法都是在with中执行的,也就是说每个方法都是一个完整的事务,当它执行完成以后也会将事务提交,那么如果我们想要进行一个复杂的事务时,它并不能做到,所以我们还需要对它进行改造,让它支持sql事务。

  那么应该怎么实现呢?我们都知道要支持事务,就必须让不同的sql语句在同一个事务中执行,也就是说,我们需要在一个with中执行所有的sql语句,失败则回滚,成功再提交事务。

  由于我们的逻辑层各个类都是继承ORM基类来实现的,而事务的开关放在各个类中就不合适,可能会存在问题,所以在执行事务时,直接调用db_helper模块,使用with初始化好数据库链接,然后在方法里编写并执行各个sql语句。

  当前逻辑层基类(ORM模块)的sql语句都是在方法中生成(拼接)的,然后在方法的with模块中执行,所以我们需要再次对整个类进行改造,将所有的sql生成方法提炼出来,成为单独的方法,然后在事务中,我们不直接执行获取结果,而是通过ORM生成对应的sql语句,在with中执行这样语句。(当然还有其他方法也能实现事务,不过在这里不做进一步的探讨,因为当前这种是最简单实现事务的方式之一,多层封装处理,有可能会导致系统变的更加复杂,代码更加难懂)

  代码改造起来很简单,比如说获取记录方法

 1     def get_model(self, wheres):
 2         """通过条件获取一条记录"""
 3         # 如果有条件,则自动添加where
 4         if wheres:
 5             wheres = ' where ' + wheres
 6 
 7         # 合成sql语句
 8         sql = "select %(column_name_list)s from %(table_name)s %(wheres)s" % \
 9               {'column_name_list': self.__column_name_list, 'table_name': self.__table_name, 'wheres': wheres}
10         # 初化化数据库链接
11         result = self.select(sql)
12         if result:
13             return result[0]
14         return {}

  我们可以将它拆分成get_model_sql()和get_model()两个方法,一个处理sql组合,一个执行获取结果,前者可以给事务调用,后者直接给对应的程序调用

 1     def get_model_sql(self, wheres):
 2         """通过条件获取一条记录"""
 3         # 如果有条件,则自动添加where
 4         if wheres:
 5             wheres = ' where ' + wheres
 6 
 7         # 合成sql语句
 8         sql = "select %(column_name_list)s from %(table_name)s %(wheres)s" % \
 9               {'column_name_list': self.__column_name_list, 'table_name': self.__table_name, 'wheres': wheres}
10         return sql
11 
12     def get_model(self, wheres):
13         """通过条件获取一条记录"""
14         # 生成sql
15         sql = self.get_model_sql(wheres)
16         # 初化化数据库链接
17         result = self.select(sql)
18         if result:
19             return result[0]
20         return {}

  其他代码不一一细述,大家自己看看重构后的结果

  1 #!/usr/bin/env python
  2 # coding=utf-8
  3 
  4 from common import db_helper, cache_helper, encrypt_helper
  5 
  6 
  7 class LogicBase():
  8     """逻辑层基础类"""
  9 
 10     def __init__(self, db, is_output_sql, table_name, column_name_list='*', pk_name='id'):
 11         """类初始化"""
 12         # 数据库参数
 13         self.__db = db
 14         # 是否输出执行的Sql语句到日志中
 15         self.__is_output_sql = is_output_sql
 16         # 表名称
 17         self.__table_name = str(table_name).lower()
 18         # 查询的列字段名称,*表示查询全部字段,多于1个字段时用逗号进行分隔,除了字段名外,也可以是表达式
 19         self.__column_name_list = str(column_name_list).lower()
 20         # 主健名称
 21         self.__pk_name = str(pk_name).lower()
 22         # 缓存列表
 23         self.__cache_list = self.__table_name + '_cache_list'
 24 
 25     #####################################################################
 26     ### 生成Sql ###
 27     def get_model_sql(self, wheres):
 28         """通过条件获取一条记录"""
 29         # 如果有条件,则自动添加where
 30         if wheres:
 31             wheres = ' where ' + wheres
 32 
 33         # 合成sql语句
 34         sql = "select %(column_name_list)s from %(table_name)s %(wheres)s" % \
 35               {'column_name_list': self.__column_name_list, 'table_name': self.__table_name, 'wheres': wheres}
 36         return sql
 37 
 38     def get_model_for_pk_sql(self, pk, wheres=''):
 39         """通过主键值获取数据库记录实体"""
 40         # 组装查询条件
 41         wheres = '%s = %s' % (self.__pk_name, str(pk))
 42         return self.get_model_sql(wheres)
 43 
 44     def get_value_sql(self, column_name, wheres=''):
 45         """
 46         获取指定条件的字段值————多于条记录时,只取第一条记录
 47         :param column_name: 单个字段名,如:id
 48         :param wheres: 查询条件
 49         :return: 7 (指定的字段值)
 50         """
 51         if wheres:
 52             wheres = ' where ' + wheres
 53 
 54         sql = 'select %(column_name)s from %(table_name)s %(wheres)s limit 1' % \
 55               {'column_name': column_name, 'table_name': self.__table_name, 'wheres': wheres}
 56         return sql
 57 
 58     def get_value_list_sql(self, column_name, wheres=''):
 59         """
 60         获取指定条件记录的字段值列表
 61         :param column_name: 单个字段名,如:id
 62         :param wheres: 查询条件
 63         :return: [1,3,4,6,7]
 64         """
 65         if not column_name:
 66             column_name = self.__pk_name
 67         elif wheres:
 68             wheres = ' where ' + wheres
 69 
 70         sql = 'select array_agg(%(column_name)s) as list from %(table_name)s %(wheres)s' % \
 71               {'column_name': column_name, 'table_name': self.__table_name, 'wheres': wheres}
 72         return sql
 73 
 74     def add_model_sql(self, fields, returning=''):
 75         """新增数据库记录"""
 76         ### 拼接sql语句 ###
 77         # 初始化变量
 78         key_list = []
 79         value_list = []
 80         # 将传入的字典参数进行处理,把字段名生成sql插入字段名数组和字典替换数组
 81         # PS:字符串使用字典替换参数时,格式是%(name)s,这里会生成对应的字串
 82         # 比如:
 83         #   传入的字典为: {'id': 1, 'name': '名称'}
 84         #   那么生成的key_list为:'id','name'
 85         #   而value_list为:'%(id)s,%(name)s'
 86         #   最终而value_list为字符串对应名称位置会被替换成相应的值
 87         for key in fields.keys():
 88             key_list.append(key)
 89             value_list.append('%(' + key + ')s')
 90         # 设置sql拼接字典,并将数组(lit)使用join方式进行拼接,生成用逗号分隔的字符串
 91         parameter = {
 92             'table_name': self.__table_name,
 93             'pk_name': self.__pk_name,
 94             'key_list': ','.join(key_list),
 95             'value_list': ','.join(value_list)
 96         }
 97         # 如果有指定返回参数,则添加
 98         if returning:
 99             parameter['returning'] = ', ' + returning
100         else:
101             parameter['returning'] = ''
102 
103         # 生成可以使用字典替换的字符串
104         sql = "insert into %(table_name)s (%(key_list)s) values (%(value_list)s) returning %(pk_name)s %(returning)s" % parameter
105         # 将生成好的字符串替字典参数值,生成最终可执行的sql语句
106         return sql % fields
107 
108     def edit_sql(self, fields, wheres='', returning=''):
109         """
110         批量编辑数据库记录
111         :param fields: 要更新的字段(字段名与值存储在字典中)
112         :param wheres: 更新条件
113         :param returning: 更新成功后,返回的字段名
114         :param is_update_cache: 是否同步更新缓存
115         :return:
116         """
117         ### 拼接sql语句 ###
118         # 拼接字段与值
119         field_list = [key + ' = %(' + key + ')s' for key in fields.keys()]
120         # 设置sql拼接字典
121         parameter = {
122             'table_name': self.__table_name,
123             'pk_name': self.__pk_name,
124             'field_list': ','.join(field_list)
125         }
126         # 如果存在更新条件,则将条件添加到sql拼接更换字典中
127         if wheres:
128             parameter['wheres'] = ' where ' + wheres
129         else:
130             parameter['wheres'] = ''
131 
132         # 如果有指定返回参数,则添加
133         if returning:
134             parameter['returning'] = ', ' + returning
135         else:
136             parameter['returning'] = ''
137 
138         # 生成sql语句
139         sql = "update %(table_name)s set %(field_list)s %(wheres)s returning %(pk_name)s %(returning)s" % parameter
140         return sql % fields
141 
142     def edit_model_sql(self, pk, fields, wheres='', returning=''):
143         """编辑单条数据库记录"""
144         if wheres:
145             wheres = self.__pk_name + ' = ' + str(pk) + ' and ' + wheres
146         else:
147             wheres = self.__pk_name + ' = ' + str(pk)
148 
149         return self.edit_sql(fields, wheres, returning)
150 
151     def delete_sql(self, wheres='', returning=''):
152         """
153         批量删除数据库记录
154         :param wheres: 删除条件
155         :param returning: 删除成功后,返回的字段名
156         :param is_update_cache: 是否同步更新缓存
157         :return:
158         """
159         # 如果存在条件
160         if wheres:
161             wheres = ' where ' + wheres
162 
163         # 如果有指定返回参数,则添加
164         if returning:
165             returning = ', ' + returning
166 
167         # 生成sql语句
168         sql = "delete from %(table_name)s %(wheres)s returning %(pk_name)s %(returning)s" % \
169               {'table_name': self.__table_name, 'wheres': wheres, 'pk_name': self.__pk_name, 'returning': returning}
170         return sql
171 
172     def delete_model_sql(self, pk, wheres='', returning=''):
173         """删除单条数据库记录"""
174         if wheres:
175             wheres = self.__pk_name + ' = ' + str(pk) + ' and ' + wheres
176         else:
177             wheres = self.__pk_name + ' = ' + str(pk)
178 
179         return self.delete_sql(wheres, returning)
180 
181     def get_list_sql(self, column_name_list='', wheres='', orderby=None, table_name=None):
182         """
183         获取指定条件的数据库记录集
184         :param column_name_list:      查询字段
185         :param wheres:      查询条件
186         :param orderby:     排序规则
187         :param table_name:     查询数据表,多表查询时需要设置
188         :return:
189         """
190         # 初始化查询数据表名称
191         if not table_name:
192             table_name = self.__table_name
193         # 初始化查询字段名
194         if not column_name_list:
195             column_name_list = self.__column_name_list
196         # 初始化查询条件
197         if wheres:
198             # 如果是字符串,表示该查询条件已组装好了,直接可以使用
199             if isinstance(wheres, str):
200                 wheres = 'where ' + wheres
201             # 如果是list,则表示查询条件有多个,可以使用join将它们用and方式组合起来使用
202             elif isinstance(wheres, list):
203                 wheres = 'where ' + ' and '.join(wheres)
204         # 初始化排序
205         if not orderby:
206             orderby = self.__pk_name + ' desc'
207         #############################################################
208 
209         ### 按条件查询数据库记录
210         sql = "select %(column_name_list)s from %(table_name)s %(wheres)s order by %(orderby)s " % \
211               {'column_name_list': column_name_list,
212                'table_name': table_name,
213                'wheres': wheres,
214                'orderby': orderby}
215         return sql
216 
217     def get_count_sql(self, wheres=''):
218         """获取指定条件记录数量"""
219         if wheres:
220             wheres = ' where ' + wheres
221         sql = 'select count(1) as total from %(table_name)s %(wheres)s ' % \
222               {'table_name': self.__table_name, 'wheres': wheres}
223         return sql
224 
225     def get_sum_sql(self, fields, wheres):
226         """获取指定条件记录数量"""
227         sql = 'select sum(%(fields)s) as total from %(table_name)s where %(wheres)s ' % \
228               {'table_name': self.__table_name, 'wheres': wheres, 'fields': fields}
229         return sql
230 
231     def get_min_sql(self, fields, wheres):
232         """获取该列记录最小值"""
233         sql = 'select min(%(fields)s) as min from %(table_name)s where %(wheres)s ' % \
234               {'table_name': self.__table_name, 'wheres': wheres, 'fields': fields}
235         return sql
236 
237     def get_max_sql(self, fields, wheres):
238         """获取该列记录最大值"""
239         sql = 'select max(%(fields)s) as max from %(table_name)s where %(wheres)s ' % \
240               {'table_name': self.__table_name, 'wheres': wheres, 'fields': fields}
241         return sql
242 
243     #####################################################################
244 
245 
246     #####################################################################
247     ### 执行Sql ###
248 
249     def select(self, sql):
250         """执行sql查询语句(select)"""
251         with db_helper.PgHelper(self.__db, self.__is_output_sql) as db:
252             # 执行sql语句
253             result = db.execute(sql)
254             if not result:
255                 result = []
256         return result
257 
258     def execute(self, sql):
259         """执行sql语句,并提交事务"""
260         with db_helper.PgHelper(self.__db, self.__is_output_sql) as db:
261             # 执行sql语句
262             result = db.execute(sql)
263             if result:
264                 db.commit()
265             else:
266                 result = []
267         return result
268 
269     def copy(self, values, columns):
270         """批量更新数据"""
271         with db_helper.PgHelper(self.__db, self.__is_output_sql) as db:
272             # 执行sql语句
273             result = db.copy(values, self.__table_name, columns)
274         return result
275 
276     def get_model(self, wheres):
277         """通过条件获取一条记录"""
278         # 生成sql
279         sql = self.get_model_sql(wheres)
280         # 执行查询操作
281         result = self.select(sql)
282         if result:
283             return result[0]
284         return {}
285 
286     def get_model_for_pk(self, pk, wheres=''):
287         """通过主键值获取数据库记录实体"""
288         if not pk:
289             return {}
290         # 生成sql
291         sql = self.get_model_for_pk_sql(pk, wheres)
292         # 执行查询操作
293         result = self.select(sql)
294         if result:
295             return result[0]
296         return {}
297 
298     def get_value(self, column_name, wheres=''):
299         """
300         获取指定条件的字段值————多于条记录时,只取第一条记录
301         :param column_name: 单个字段名,如:id
302         :param wheres: 查询条件
303         :return: 7 (指定的字段值)
304         """
305         if not column_name:
306             return None
307 
308         # 生成sql
309         sql = self.get_value_sql(column_name, wheres)
310         result = self.select(sql)
311         # 如果查询成功,则直接返回记录字典
312         if result:
313             return result[0].get(column_name)
314 
315     def get_value_list(self, column_name, wheres=''):
316         """
317         获取指定条件记录的字段值列表
318         :param column_name: 单个字段名,如:id
319         :param wheres: 查询条件
320         :return: [1,3,4,6,7]
321         """
322         # 生成sql
323         sql = self.get_value_list_sql(column_name, wheres)
324         result = self.select(sql)
325         # 如果查询失败或不存在指定条件记录,则直接返回初始值
326         if result and isinstance(result, list):
327             return result[0].get('list')
328         else:
329             return []
330 
331     def add_model(self, fields, returning=''):
332         """新增数据库记录"""
333         # 生成sql
334         sql = self.add_model_sql(fields, returning)
335         result = self.execute(sql)
336         if result:
337             return result[0]
338         return {}
339 
340     def edit(self, fields, wheres='', returning='', is_update_cache=True):
341         """
342         批量编辑数据库记录
343         :param fields: 要更新的字段(字段名与值存储在字典中)
344         :param wheres: 更新条件
345         :param returning: 更新成功后,返回的字段名
346         :param is_update_cache: 是否同步更新缓存
347         :return:
348         """
349         # 生成sql
350         sql = self.edit_sql(fields, wheres, returning)
351         result = self.execute(sql)
352         if result:
353             # 判断是否删除对应的缓存
354             if is_update_cache:
355                 # 循环删除更新成功的所有记录对应的缓存
356                 for model in result:
357                     self.del_model_for_cache(model.get(self.__pk_name, 0))
358                 # 同步删除与本表关联的缓存
359                 self.del_relevance_cache()
360         return result
361 
362     def edit_model(self, pk, fields, wheres='', returning='', is_update_cache=True):
363         """编辑单条数据库记录"""
364         if not pk:
365             return {}
366         # 生成sql
367         sql = self.edit_model_sql(pk, fields, wheres, returning)
368         result = self.execute(sql)
369         if result:
370             # 判断是否删除对应的缓存
371             if is_update_cache:
372                 # 删除更新成功的所有记录对应的缓存
373                 self.del_model_for_cache(result[0].get(self.__pk_name, 0))
374                 # 同步删除与本表关联的缓存
375                 self.del_relevance_cache()
376         return result
377 
378     def delete(self, wheres='', returning='', is_update_cache=True):
379         """
380         批量删除数据库记录
381         :param wheres: 删除条件
382         :param returning: 删除成功后,返回的字段名
383         :param is_update_cache: 是否同步更新缓存
384         :return:
385         """
386         # 生成sql
387         sql = self.delete_sql(wheres, returning)
388         result = self.execute(sql)
389         if result:
390             # 同步删除对应的缓存
391             if is_update_cache:
392                 for model in result:
393                     self.del_model_for_cache(model.get(self.__pk_name, 0))
394                 # 同步删除与本表关联的缓存
395                 self.del_relevance_cache()
396         return result
397 
398     def delete_model(self, pk, wheres='', returning='', is_update_cache=True):
399         """删除单条数据库记录"""
400         if not pk:
401             return {}
402         # 生成sql
403         sql = self.delete_model_sql(pk, wheres, returning)
404         result = self.execute(sql)
405         if result:
406             # 同步删除对应的缓存
407             if is_update_cache:
408                 self.del_model_for_cache(result[0].get(self.__pk_name, 0))
409                 # 同步删除与本表关联的缓存
410                 self.del_relevance_cache()
411         return result
412 
413     def get_list(self, column_name_list='', wheres='', page_number=None, page_size=None, orderby=None, table_name=None):
414         """
415         获取指定条件的数据库记录集
416         :param column_name_list:      查询字段
417         :param wheres:      查询条件
418         :param page_number:   分页索引值
419         :param page_size:    分页大小, 存在值时才会执行分页
420         :param orderby:     排序规则
421         :param table_name:     查询数据表,多表查询时需要设置
422         :return: 返回记录集总数量与分页记录集
423             {'records': 0, 'total': 0, 'page': 0, 'rows': []}
424         """
425         # 初始化输出参数:总记录数量与列表集
426         data = {
427             'records': 0,  # 总记录数
428             'total': 0,  # 总页数
429             'page': 1,  # 当前页面索引
430             'rows': [],  # 查询结果(记录列表)
431         }
432         # 初始化查询数据表名称
433         if not table_name:
434             table_name = self.__table_name
435         # 初始化查询字段名
436         if not column_name_list:
437             column_name_list = self.__column_name_list
438         # 初始化查询条件
439         if wheres:
440             # 如果是字符串,表示该查询条件已组装好了,直接可以使用
441             if isinstance(wheres, str):
442                 wheres = 'where ' + wheres
443             # 如果是list,则表示查询条件有多个,可以使用join将它们用and方式组合起来使用
444             elif isinstance(wheres, list):
445                 wheres = 'where ' + ' and '.join(wheres)
446         # 初始化排序
447         if not orderby:
448             orderby = self.__pk_name + ' desc'
449         # 初始化分页查询的记录区间
450         paging = ''
451 
452         with db_helper.PgHelper(self.__db, self.__is_output_sql) as db:
453             #############################################################
454             # 判断是否需要进行分页
455             if not page_size is None:
456                 ### 执行sql,获取指定条件的记录总数量
457                 sql = 'select count(1) as records from %(table_name)s %(wheres)s ' % \
458                       {'table_name': table_name, 'wheres': wheres}
459                 result = db.execute(sql)
460                 # 如果查询失败或不存在指定条件记录,则直接返回初始值
461                 if not result or result[0]['records'] == 0:
462                     return data
463 
464                 # 设置记录总数量
465                 data['records'] = result[0].get('records')
466 
467                 #########################################################
468                 ### 设置分页索引与页面大小 ###
469                 if page_size <= 0:
470                     page_size = 10
471                 # 计算总分页数量:通过总记录数除于每页显示数量来计算总分页数量
472                 if data['records'] % page_size == 0:
473                     page_total = data['records'] // page_size
474                 else:
475                     page_total = data['records'] // page_size + 1
476                 # 判断页码是否超出限制,超出限制查询时会出现异常,所以将页面索引设置为最后一页
477                 if page_number < 1 or page_number > page_total:
478                     page_number = page_total
479                 # 记录总页面数量
480                 data['total'] = page_total
481                 # 记录当前页面值
482                 data['page'] = page_number
483                 # 计算当前页面要显示的记录起始位置(limit指定的位置)
484                 record_number = (page_number - 1) * page_size
485                 # 设置查询分页条件
486                 paging = ' limit ' + str(page_size) + ' offset ' + str(record_number)
487             #############################################################
488 
489             ### 按条件查询数据库记录
490             sql = "select %(column_name_list)s from %(table_name)s %(wheres)s order by %(orderby)s %(paging)s" % \
491                   {'column_name_list': column_name_list,
492                    'table_name': table_name,
493                    'wheres': wheres,
494                    'orderby': orderby,
495                    'paging': paging}
496             result = db.execute(sql)
497             if result:
498                 data['rows'] = result
499                 # 不需要分页查询时,直接在这里设置总记录数
500                 if page_size is None:
501                     data['records'] = len(result)
502 
503         return data
504 
505     def get_count(self, wheres=''):
506         """获取指定条件记录数量"""
507         # 生成sql
508         sql = self.get_count_sql(wheres)
509         result = self.select(sql)
510         # 如果查询存在记录,则返回true
511         if result:
512             return result[0].get('total')
513         return 0
514 
515     def get_sum(self, fields, wheres):
516         """获取指定条件记录数量"""
517         # 生成sql
518         sql = self.get_sum_sql(fields, wheres)
519         result = self.select(sql)
520         # 如果查询存在记录,则返回true
521         if result and result[0].get('total'):
522             return result[0].get('total')
523         return 0
524 
525     def get_min(self, fields, wheres):
526         """获取该列记录最小值"""
527         # 生成sql
528         sql = self.get_min_sql(fields, wheres)
529         result = self.select(sql)
530         # 如果查询存在记录,则返回true
531         if result and result[0].get('min'):
532             return result[0].get('min')
533 
534     def get_max(self, fields, wheres):
535         """获取该列记录最大值"""
536         # 生成sql
537         sql = self.get_max_sql(fields, wheres)
538         result = self.select(sql)
539         # 如果查询存在记录,则返回true
540         if result and result[0].get('max'):
541             return result[0].get('max')
542 
543     #####################################################################
544 
545 
546     #####################################################################
547     ### 缓存操作方法 ###
548 
549     def get_cache_key(self, pk):
550         """获取缓存key值"""
551         return ''.join((self.__table_name, '_', str(pk)))
552 
553     def set_model_for_cache(self, pk, value, time=43200):
554         """更新存储在缓存中的数据库记录,缓存过期时间为12小时"""
555         # 生成缓存key
556         key = self.get_cache_key(pk)
557         # 存储到nosql缓存中
558         cache_helper.set(key, value, time)
559 
560     def get_model_for_cache(self, pk):
561         """从缓存中读取数据库记录"""
562         # 生成缓存key
563         key = self.get_cache_key(pk)
564         # 从缓存中读取数据库记录
565         result = cache_helper.get(key)
566         # 缓存中不存在记录,则从数据库获取
567         if not result:
568             result = self.get_model_for_pk(pk)
569             self.set_model_for_cache(pk, result)
570         if result:
571             return result
572         else:
573             return {}
574 
575     def get_model_for_cache_of_where(self, where):
576         """
577         通过条件获取记录实体(我们经常需要使用key、编码或指定条件来获取记录,这时可以通过当前方法来获取)
578         :param where: 查询条件
579         :return: 记录实体
580         """
581         # 生成实体缓存key
582         model_cache_key = self.__table_name + encrypt_helper.md5(where)
583         # 通过条件从缓存中获取记录id
584         pk = cache_helper.get(model_cache_key)
585         # 如果主键id存在,则直接从缓存中读取记录
586         if pk:
587             return self.get_model_for_cache(pk)
588 
589         # 否则从数据库中获取
590         result = self.get_model(where)
591         if result:
592             # 存储条件对应的主键id值到缓存中
593             cache_helper.set(model_cache_key, result.get(self.__pk_name))
594             # 存储记录实体到缓存中
595             self.set_model_for_cache(result.get(self.__pk_name), result)
596             return result
597 
598     def get_value_for_cache(self, pk, column_name):
599         """获取指定记录的字段值"""
600         return self.get_model_for_cache(pk).get(column_name)
601 
602     def del_model_for_cache(self, pk):
603         """删除缓存中指定数据"""
604         # 生成缓存key
605         key = self.get_cache_key(pk)
606         # log_helper.info(key)
607         # 存储到nosql缓存中
608         cache_helper.delete(key)
609 
610     def add_relevance_cache_in_list(self, key):
611         """将缓存名称存储到列表里————主要存储与记录变更关联的"""
612         # 从nosql中读取全局缓存列表
613         cache_list = cache_helper.get(self.__cache_list)
614         # 判断缓存列表是否有值,有则进行添加操作
615         if cache_list:
616             # 判断是否已存储列表中,不存在则执行添加操作
617             if not key in cache_list:
618                 cache_list.append(key)
619                 cache_helper.set(self.__cache_list, cache_list)
620         # 无则直接创建全局缓存列表,并存储到nosql中
621         else:
622             cache_list = [key]
623             cache_helper.set(self.__cache_list, cache_list)
624 
625     def del_relevance_cache(self):
626         """删除关联缓存————将和数据表记录关联的,个性化缓存全部删除"""
627         # 从nosql中读取全局缓存列表
628         cache_list = cache_helper.get(self.__cache_list)
629         # 清除已删除缓存列表
630         cache_helper.delete(self.__cache_list)
631         if cache_list:
632             # 执行删除操作
633             for cache in cache_list:
634                 cache_helper.delete(cache)
635 
636     #####################################################################
View Code

  从完整代码可以看到,重构后的类多了很多sql生成方法,它们其实是从原方法中分享出sql合成代码,将它们独立出来而已。

 

  接下来我们编写单元测试代码,执行一下事务看看效果

 1 #!/usr/bin/evn python
 2 # coding=utf-8
 3 
 4 import unittest
 5 from common import db_helper
 6 from common.string_helper import string
 7 from config import db_config
 8 from logic import product_logic, product_class_logic
 9 
10 
11 class DbHelperTest(unittest.TestCase):
12     """数据库操作包测试类"""
13 
14     def setUp(self):
15         """初始化测试环境"""
16         print('------ini------')
17 
18     def tearDown(self):
19         """清理测试环境"""
20         print('------clear------')
21 
22     def test(self):
23         ##############################################
24         # 只需要看这里,其他代码是测试用例的模板代码 #
25         ##############################################
26         # 测试事务
27         # 使用with方法,初始化数据库链接
28         with db_helper.PgHelper(db_config.DB, db_config.IS_OUTPUT_SQL) as db:
29             # 实例化product表操作类ProductLogic
30             _product_logic = product_logic.ProductLogic()
31             # 实例化product_class表操作类product_class_logic
32             _product_class_logic = product_class_logic.ProductClassLogic()
33             # 初始化产品分类主键id
34             id = 1
35 
36             # 获取产品分类信息(为了查看效果,所以加了这段获取分类信息)
37             sql = _product_class_logic.get_model_for_pk_sql(id)
38             print(sql)
39             # 执行sql语句
40             result = db.execute(sql)
41             if not result:
42                 print('不存在指定的产品分类')
43                 return
44             print('----产品分类实体----')
45             print(result)
46             print('-------------------')
47 
48             # 禁用产品分类
49             fields = {
50                 'is_enable': 0
51             }
52             sql = _product_class_logic.edit_model_sql(id, fields, returning='is_enable')
53             print(sql)
54             # 执行sql语句
55             result = db.execute(sql)
56             if not result:
57                 # 执行失败,执行回滚操作
58                 db.rollback()
59                 print('禁用产品分类失败')
60                 return
61             # 执行缓存清除操作
62             _product_class_logic.del_model_for_cache(id)
63             _product_class_logic.del_relevance_cache()
64             print('----执行成功后的产品分类实体----')
65             print(result)
66             print('-------------------------------')
67 
68             # 同步禁用产品分类对应的所有产品
69             sql = _product_logic.edit_sql(fields, 'product_class_id=' + str(id), returning='is_enable')
70             print(sql)
71             # 执行sql语句
72             result = db.execute(sql)
73             if not result:
74                 # 执行失败,执行回滚操作
75                 db.rollback()
76                 print('同步禁用产品分类对应的所有产品失败')
77                 return
78             # 执行缓存清除操作
79             for model in result:
80                 _product_class_logic.del_model_for_cache(model.get('id'))
81             _product_class_logic.del_relevance_cache()
82             print('----执行成功后的产品实体----')
83             print(result)
84             print('---------------------------')
85 
86             db.commit()
87             print('执行成功')
88         ##############################################
89 
90 if __name__ == '__main__':
91     unittest.main()

  细心的朋友可能会发现,在事务处理中,进行编辑操作以后,会执行缓存的清除操作,这是因为我们在ORM中所绑定的缓存自动清除操作,是在对应的执行方法中,而不是sql生成方法里,所以在进行事务时,如果你使用了缓存的方法,在这里就需要手动添加清除缓存操作,不然就会产生脏数据。

 

  执行结果:

 1 ------ini------
 2 select * from product_class  where id = 1
 3 ----产品分类实体----
 4 [{'add_time': datetime.datetime(2018, 8, 17, 16, 14, 54), 'id': 1, 'is_enable': 1, 'name': '饼干'}]
 5 -------------------
 6 update product_class set is_enable = 0  where id = 1 returning id , is_enable
 7 ----执行成功后的产品分类实体----
 8 [{'id': 1, 'is_enable': 0}]
 9 -------------------------------
10 update product set is_enable = 0  where product_class_id=1 returning id , is_enable
11 ----执行成功后的产品实体----
12 [{'id': 2, 'is_enable': 0}, {'id': 7, 'is_enable': 0}, {'id': 14, 'is_enable': 0}, {'id': 15, 'is_enable': 0}]
13 ---------------------------
14 执行成功
15 ------clear------

 

 

  本文对应的源码下载(一些接口进行了重构,有些还没有处理,所以源码可能直接运行不了,下一章节会讲到所有代码使用ORM模块重构内容)

 

版权声明:本文原创发表于 博客园,作者为 AllEmpty 本文欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则视为侵权。

python开发QQ群:669058475(本群已满)、733466321(可以加2群)    作者博客:http://www.cnblogs.com/EmptyFS/

 

posted @ 2018-08-17 16:43  AllEmpty  阅读(1747)  评论(0编辑  收藏  举报