Python 轻量级 ORM

Python 下的 ORM 用得比较多的是 SQLAlchemy。用了一段时间后感觉不顺手，主要问题是 SQLAlchemy 太重，所以自己写了一个 ORM，实现方式与 NetSharp 类似。OQL 部分因为代码量较大，目前还没有完全实现。

下面是源代码

一、系统配置

configuration.py

# !/usr/bin/python
# -*- coding: UTF-8 -*-
# Connection settings used as defaults by orm.py.
# Each value may be overridden through an environment variable so that real
# credentials do not have to be committed to source control; the literals
# below are only the fallbacks.

import os

host = os.environ.get("ORM_DB_HOST", "192.168.4.1")
port = int(os.environ.get("ORM_DB_PORT", "3306"))
db = os.environ.get("ORM_DB_NAME", "wolf")
user = os.environ.get("ORM_DB_USER", "root")
pwd = os.environ.get("ORM_DB_PASSWORD", "xxx")
# written into the creator/updator columns by Persister.add/update
user_name = os.environ.get("ORM_USER_NAME", "xxx")

二、orm.py

   1 # !/usr/bin/python
   2 # -*- coding: UTF-8 -*-
   3 
   4 import sys
   5 from datetime import *
   6 from enum import Enum
   7 import logging
   8 
   9 import MySQLdb
  10 import MySQLdb.cursors
  11 
  12 import configuration
  13 
  14 #########################################################################
  15         
  16 class Field(object):
  17     
  18     property_name = None
  19     column_name = None
  20     group_name = None
  21     column_type_name = None
  22     header = None
  23     memoto = None
  24 
  25     is_primary_key = None
  26     is_auto = False
  27     is_name_equals = None
  28     is_required = None
  29     is_unique = None
  30 
  31     size = None
  32     precision = None
  33 
  34     def __init__(self,**kw):        
  35         for k,v in kw.iteritems():
  36             setattr(self,k,v) 
  37 
  38 class ShortField(Field) :
  39     
  40     def __init__(self,**kw):
  41         self.column_type_name="smallint"
  42         super(ShortField,self).__init__(**kw)
  43        
  44 
  45 class IntField(Field) :
  46     
  47     def __init__(self,**kw):
  48         self.column_type_name="int"
  49         super(IntField,self).__init__(**kw)
  50        
  51 
  52 class LongField(Field) :
  53     
  54     def __init__(self,**kw):
  55         self.column_type_name="bigint"
  56         super(LongField,self).__init__(**kw)
  57               
  58 
  59 class StringFiled(Field) :
  60     
  61     def __init__(self,**kw):
  62         self.column_type_name="nvarchar"
  63         self.size=50         
  64         super(StringFiled,self).__init__(**kw)
  65 
  66 
  67 class BoolFiled(Field) :
  68     def __init__(self,**kw):
  69         self.column_type_name="bool"         
  70         super(BoolFiled,self).__init__(**kw)
  71        
  72 
  73 class FloatFiled(Field) :
  74     def __init__(self,**kw):
  75         self.column_type_name="float"
  76         self.size=8
  77         self.precision=4         
  78         super(FloatFiled,self).__init__(**kw)
  79 
  80 
  81 class DoubleFiled(Field) :
  82     def __init__(self,**kw):
  83         self.column_type_name="double"
  84         self.size=8
  85         self.precision=4         
  86         super(DoubleFiled,self).__init__(**kw)
  87 
  88 
  89 class DecimalFiled(Field) :
  90     def __init__(self,**kw):
  91         self.column_type_name="decimal"
  92         self.size=14
  93         self.precision=8         
  94         super(DecimalFiled,self).__init__(**kw)
  95 
  96 class DateTimeFiled(Field) :
  97     def __init__(self,**kw):
  98        self.column_type_name="datetime"         
  99        super(DateTimeFiled,self).__init__(**kw)
 100 
 101 class BinaryFiled(Field) :
 102     def __init__(self,**kw):
 103        self.column_type_name="longblob"         
 104        super(BinaryFiled,self).__init__(**kw)
 105        
 106 class Reference(object) :
 107     
 108     property_name = None
 109     header = None
 110     foreign_key = None
 111     primary_key = None
 112 
 113     reference_type = None
 114 
 115     def __init__(self,**kw):        
 116         for k,v in kw.iteritems():
 117             setattr(self,k,v)
 118 
 119 class Subs(object) :
 120     
 121     property_name = None
 122     header = None
 123     foreign_key = None
 124     primary_key = None
 125 
 126     sub_type = None
 127 
 128     def __init__(self,**kw):        
 129         for k,v in kw.iteritems():
 130             setattr(self,k,v)   
 131 
 132 #########################################################################
 133 
 134 class EntityState(Enum):
 135     # 瞬时状态,不受数据库管理的状态
 136     # 或者是事务提交后不需要更新的实体
 137     Transient = 0
 138     # 事务提交后将被新增
 139     New = 1
 140 
 141     # 事务提交后将被修更新
 142     Persist = 2
 143 
 144     # 事务提交后将被删除
 145     Deleted = 3
 146 
 147 class Persistable(object) :
 148     
 149     __entity_status__ = EntityState.New
 150     
 151     def __init__(self,**kw):        
 152         for k,v in kw.iteritems():
 153             setattr(self,k,v)
 154 
 155     def to_new(self) :
 156         self.__entity_status__ = EntityState.New
 157 
 158     def to_persist(self) :
 159         self.__entity_status__ = EntityState.Persist
 160 
 161     def to_delete(self) :
 162         self.__entity_status__ = EntityState.Deleted
 163 
 164     def to_transient(self) :
 165         self.__entity_status__ = EntityState.Transient            
 166 
 167 
 168 class Entity(Persistable) :
 169     
 170     id = IntField(is_primary_key=True,is_auto=True)
 171     creator = StringFiled(size=100)
 172     create_time = DateTimeFiled()
 173     update_time = DateTimeFiled()
 174     updator = StringFiled(size=100)
 175 
 176 class BizEntity(Entity) :
 177     
 178     code = StringFiled()
 179     name = StringFiled(size=100) 
 180     memoto = StringFiled(size=500)  
 181 
 182 #########################################################################
 183 
 184 class NColumn :
 185     
 186     property_name = None
 187     column_name = None
 188     group_name = None
 189     column_type_name = None
 190     header = None
 191     memoto = None
 192     column_type_name = None
 193 
 194     is_primary_key = None
 195     is_auto = False
 196     is_name_equals = None
 197     is_required = None
 198     is_unique = None
 199 
 200     size = None
 201     precision = None
 202 
 203     def __repr__(self):
 204         return self.property_name + "["+self.column_name+"]"
 205 
 206 
 207 class NEntity :
 208     
 209     name = None
 210     table_name = None
 211     entity_id = None
 212     header= None
 213     is_view= None
 214     is_refcheck= None
 215     key_column = None
 216     auto_column = None
 217     order_by= None
 218     type = None
 219 
 220     columns = {}
 221     fields = {}
 222     subs = {}
 223     references = {}
 224 
 225     # def __repr__(self):
 226     #     return self.__name__ + "["+self.table_name+"]"
 227 
 228 class NReference :
 229     header = None
 230     foreign_key = None
 231     primary_key = None
 232 
 233     foreign_key_column = None
 234     primary_key_column = None
 235 
 236     reference_type = None
 237 
 238 class NSubs :
 239     header = None
 240     foreign_key = None
 241     primary_key = None
 242 
 243     foreign_key_column = None
 244     primary_key_column = None
 245 
 246     sub_type = None
 247 
 248 class EntityManager :
 249     
 250     entityMap = {}
 251 
 252     @classmethod
 253     def get_meta(cls,type) :
 254         
 255         
 256         ne = cls.entityMap.get(type.__name__)
 257         if ne == None :
 258             ne = cls.parse_entity(type)
 259             cls.entityMap[type.__name__]= ne
 260 
 261         return ne
 262 
 263     @classmethod
 264     def parse_entity(cls,type) :
 265         
 266         ne = NEntity()
 267         ne.type = type
 268 
 269         ne.table_name= type.__table_name__
 270         ne.name = type.__name__
 271         ne.entity_id=type.__name__
 272         ne.header= type.__doc__
 273         ne.is_view= False
 274         ne.is_refcheck= False
 275 
 276         ne.key_column = None
 277         ne.auto_column = None
 278         ne.order_by= None
 279 
 280         ne.columns = {}
 281         ne.fields = {}
 282         ne.full_columns = {}
 283         ne.full_fields = {}
 284         ne.subs = {}
 285         ne.references = {}
 286 
 287         for k in dir(type) :
 288             
 289             v = getattr(type,k)
 290 
 291             if isinstance(v, Field) :
 292 
 293                 c = cls.parse_field(k,v)
 294 
 295                 ne.full_fields[k]=c
 296                 ne.full_columns[c.column_name] = c
 297 
 298                 if c.is_primary_key :
 299                     ne.key_column = c
 300                 else :
 301                     ne.fields[k]=c
 302                     ne.columns[c.column_name] = c
 303 
 304                 if c.is_auto :
 305                     ne.auto_column = c
 306 
 307             if isinstance(v, Reference) :
 308                 r = cls.parse_reference(k,v)
 309                 ne.references[k]=r
 310 
 311             if isinstance(v, Subs) :
 312                 s = cls.parse_subs(k,v)
 313                 ne.subs[k]=s  
 314         
 315         for r in ne.references.values() :
 316             r.foreign_key_column = ne.full_fields[r.foreign_key]
 317 
 318         for r in ne.subs.values() :
 319             r.primary_key_column = ne.key_column
 320 
 321         return ne
 322 
 323     @classmethod
 324     def parse_field(cls,name,field) :
 325         
 326         c = NColumn()
 327         
 328         c.property_name = name
 329         c.column_name = name
 330 
 331         if field.column_name != None :
 332             c.column_name = name
 333 
 334         c.is_primary_key = False
 335         c.is_auto= False
 336         c.is_name_equals = True
 337         c.is_required = False
 338         c.group_name = ""
 339         c.column_type_name = field.column_type_name
 340         c.size= field.size
 341         c.precision = field.precision
 342         c.group_name = field.group_name
 343         c.is_primary_key = field.is_primary_key
 344         c.is_auto = field.is_auto
 345 
 346         return c
 347 
 348     @classmethod
 349     def parse_reference(cls,name,field) :
 350         r = NReference()
 351 
 352         r.property_name = name
 353         r.foreign_key = field.foreign_key
 354         r.header = field.header
 355         r.primary_key = field.primary_key
 356 
 357         if r.primary_key == None :
 358             r.primary_key = "id"
 359 
 360         r.reference_type = field.reference_type
 361 
 362         r.foreign_key_column = None
 363 
 364         rne = EntityManager.get_meta( field.reference_type )
 365         r.primary_key_column = rne.full_fields[r.primary_key]
 366 
 367         return r
 368 
 369     @classmethod
 370     def parse_subs(cls,name,field) :
 371         
 372         s = NSubs()
 373 
 374         s.property_name = name
 375         s.foreign_key = field.foreign_key
 376         s.header = field.header
 377         s.primary_key = field.primary_key
 378 
 379         if s.primary_key == None :
 380             s.primary_key = "id"
 381 
 382         s.primary_key_column = None
 383 
 384         s.sub_type = field.sub_type  
 385 
 386         rne = EntityManager.get_meta( field.sub_type )
 387         s.foreign_key_column = rne.full_fields[field.foreign_key]          
 388 
 389         return s
 390 
 391 class ORMException(Exception):
 392     def __init__(self,value):
 393             self.value=value
 394 
 395 #########################################################################
 396 
 397 class DataAccess :
 398     
 399     conn = None
 400     cursor = None
 401     isClosed = True
 402 
 403     def open(self,host=configuration.host,port=configuration.port,db=configuration.db,user=configuration.user,pwd=configuration.pwd) :
 404         
 405         self.conn = MySQLdb.connect(host=host,port=port,db=db,user=user,passwd=pwd,charset="utf8")
 406         self.cursor = self.conn.cursor()
 407         self.isClosed = False
 408 
 409     def execute(self,cmd,pars = None) :
 410         if self.isClosed :
 411             raise Exception("db is not opened!")
 412 
 413         logging.info(cmd)
 414 
 415         ret = self.cursor.execute(cmd,pars)
 416         return ret
 417 
 418     def get_last_row_id(self) :
 419         return int(self.cursor.lastrowid)
 420 
 421     def fetchone(self,cmd,pars = None) :
 422         if self.isClosed :
 423             raise Exception("db is not opened!")
 424 
 425         logging.info(cmd)
 426         self.cursor.execute(cmd,pars)
 427         row = self.cursor.fetchone()
 428 
 429         return row
 430 
 431     def fetchall(self,cmd,pars = None) :
 432         if self.isClosed :
 433             raise Exception("db is not opened!")
 434             
 435         logging.info(cmd)
 436         self.cursor.execute(cmd,pars)
 437         rows = self.cursor.fetchall()
 438 
 439         return rows
 440 
 441     def executeScalar(self,cmd,pars = None) :
 442         if self.isClosed :
 443             raise Exception("db is not opened!")
 444 
 445         self.cursor.execute(cmd,pars)
 446         row = self.cursor.fetchone()
 447 
 448         if row == None :
 449             return None
 450         if len(row)==0 :
 451             return None
 452         return row[0]
 453 
 454 
 455     def commit(self):
 456         self.conn.commit()
 457     
 458     def rolback(self):
 459         self.conn.rolback()
 460 
 461     def close(self) :
 462         if self.isClosed :
 463             pass;
 464         
 465         self.conn.close()
 466         self.isClosed=True
 467         self.conn = None
 468         self.cursor = None
 469 
 470 #########################################################################
 471 
 472 class SqlGenerator(object):
 473     
 474     @classmethod
 475     def generate_insert(cls,ne) :
 476         
 477         columns = ne.full_columns
 478         if ne.key_column.is_auto :
 479             columns = ne.columns
 480         
 481         sql = 'insert into %s (%s) values (%s);' % ( ne.table_name, ', '.join(columns.keys()), ', '.join(['%s'] * len(columns)))
 482 
 483         return sql
 484 
 485     @classmethod
 486     def generate_update(cls,ne) :
 487 
 488         sets = []
 489 
 490         for c in ne.columns.keys() :
 491             sets.append("%s = %s" % (c,'%s'))
 492 
 493         sql = 'update %s set %s where %s = %s;' % ( ne.table_name, ",".join(sets),ne.key_column.column_name,"%s")
 494         return sql
 495 
 496     @classmethod
 497     def generate_delete(cls,ne) :
 498         
 499         sql = "delete from %s where %s = %s" % (ne.table_name,ne.key_column.column_name,"%s")
 500 
 501         return sql
 502 
 503     @classmethod
 504     def generate_byid(cls,ne) :
 505         
 506         sql = "select %s from %s where %s = %s" % (",".join(ne.full_columns.keys()),ne.table_name,ne.key_column.column_name,'%s')
 507 
 508         return sql
 509 
 510 class Db(object) :
 511     
 512     dao = DataAccess()
 513 
 514     def open(self) :
 515         self.dao.open()
 516 
 517     def create_db(self,db_name) :
 518         
 519         sql = "CREATE SCHEMA %s DEFAULT CHARACTER SET ucs2 ;" % db_name
 520         self.dao.execute(sql)
 521 
 522     def drop_db(self,db_name) :
 523         sql ="DROP DATABASE %s;" % db_name
 524         self.dao.execute(sql)
 525     
 526     def create_table(self,cls) :
 527         
 528         ne= EntityManager.get_meta(cls)
 529         
 530         columns = []
 531         for column in ne.full_columns.values() :
 532             c = self.generate_column(column)
 533             columns.append(c)
 534             
 535         sql ="create table %s( %s )" % (ne.table_name,",".join(columns))
 536 
 537         self.dao.execute(sql,None)
 538 
 539     def generate_column(self,column) :
 540         sql = "%s %s" % (column.column_name , column.column_type_name)
 541 
 542         if column.size != None :
 543             if column.precision != None :
 544                 sql += "(%d,%d)" % (column.size ,column.precision)
 545             else:
 546                 sql += "(%d)" % column.size 
 547 
 548         if column.is_auto :
 549             sql += " AUTO_INCREMENT"
 550 
 551         if column.is_primary_key :
 552             sql +=" PRIMARY KEY"
 553 
 554         return sql
 555 
 556     def drop_table(self,cls) :
 557         ne= EntityManager.get_meta(cls)
 558         sql ="drop table if exists %s" % ne.table_name
 559 
 560         self.dao.execute(sql,None)
 561 
 562     def commit(self) :
 563         self.dao.commit()
 564 
 565     def close(self) :
 566         self.dao.close()  
 567 
 568 #########################################################################
 569 
 570 class SetQuery(object) :
 571 
 572     dao = None
 573 
 574     def __init__(self,dao) :
 575         self.dao = dao
 576 
 577     def query(self,ne,entity) :
 578         
 579         id = getattr(entity,ne.key_column.property_name)
 580         if id == None :
 581             raise ORMException("%s.id不能为空" % ne.table_name)
 582         
 583         #当前实体
 584         pars = [ id ] 
 585         sql = SqlGenerator.generate_byid(ne)
 586         row = self.dao.fetchone(sql,pars)
 587         if row == None :
 588             entity = None
 589             return None
 590         entity = self.read_row(ne,row,entity)
 591 
 592         self.query_iter(entity,ne)
 593 
 594         return entity
 595     
 596     def query_iter(self,entity,ne) :
 597         
 598         #查询引用实体
 599         for rm in ne.references.values() :
 600             filter = "%s = %s" % (rm.primary_key_column.column_name,"%s")
 601             pars = [getattr(entity,rm.foreign_key)]
 602             re = EntityManager.get_meta(rm.reference_type)
 603 
 604             sql = SqlGenerator.generate_byid(re)
 605             row = self.dao.fetchone(sql,pars) 
 606             ref = self.read_row(re,row)
 607             setattr(entity,rm.property_name,ref)        
 608         
 609         #查询子实体
 610         for sm in ne.subs.values() :
 611             filter = "%s = %s" % (sm.foreign_key_column.column_name,"%s")
 612             pars = [ getattr(entity,ne.key_column.property_name) ] 
 613             se = EntityManager.get_meta(sm.sub_type)
 614             subs = self.do_query(se,filter,pars)
 615 
 616             setattr(entity,sm.property_name,subs)
 617 
 618             for sub in subs :
 619                 self.query_iter(sub,se)
 620 
 621     def do_query(self,ne,filter,pars) :
 622         
 623         sql = "select %s from %s where %s" % (",".join(ne.full_columns.keys()),ne.table_name,filter)
 624         rows = self.dao.fetchall(sql,pars)
 625 
 626         entities = []
 627         for row in rows :
 628             entity = self.read_row(ne,row)
 629             entities.append(entity)
 630 
 631         return entities
 632 
 633     def read_row(self,ne,row,entity = None) :
 634         
 635         if entity == None :
 636             entity = ne.type()
 637 
 638         index = 0;
 639         for c in ne.full_columns.values() :
 640             cell_value = row[index]
 641             setattr(entity,c.property_name,cell_value)
 642             index +=1
 643 
 644         return entity  
 645 #########################################################################
 646           
 647 class Persister(object) :
 648     
 649     isClosed = True
 650     dao = DataAccess()
 651 
 652     def query_first(self,type,filters=None,pars=None):
 653         
 654         lis = self.query_list(type,filters,pars)
 655         if len(lis) == 0 :
 656             return None
 657         
 658         entity = lis[0]
 659         entity = self.byid(entity)
 660 
 661         return entity
 662 
 663     #实体列表查询
 664     #只支持基于主实体的一级查询
 665     def query_list(self,type,filters=None,pars=None) :
 666         
 667         ne = EntityManager.get_meta( type )
 668 
 669         columns = []
 670         for key in ne.full_columns.keys() :
 671             columns.append(ne.name+"."+key)
 672 
 673         sqls =["select %s from %s as %s" % (",".join(columns),ne.table_name,ne.name)]
 674         for r in ne.references.values() :
 675             re = EntityManager.get_meta(r.reference_type)
 676             sqls.append( "left join %s as %s on %s.%s = %s.%s" % (re.table_name, r.property_name,ne.name,r.foreign_key_column.column_name,r.property_name,re.key_column.column_name))
 677 
 678         if filters != None and filters.strip() != "" :
 679             sqls.append("where " + filters)
 680 
 681         sql = "\n".join(sqls)
 682 
 683         setter = SetQuery(self.dao)
 684 
 685         rows = self.dao.fetchall(sql,pars)
 686         entities = []
 687 
 688         for row in rows :
 689             entity = setter.read_row(ne,row)
 690             entities.append(entity)
 691 
 692         return entities
 693 
 694     def byid(self,entity) :
 695         
 696         ne = EntityManager.get_meta(entity.__class__)
 697         query = SetQuery(self.dao)
 698         entity = query.query(ne,entity)
 699         
 700         return entity        
 701 
 702     def save(self,entity) :
 703         
 704         if entity.__entity_status__ == None :
 705             raise ORMException("调用save方法必须设置实体__entity_state__")
 706         
 707         ne = EntityManager.get_meta(entity.__class__)
 708 
 709         for r in ne.references.values() :
 710             ref = getattr(entity,r.property_name)
 711             fk = getattr(entity,r.foreign_key)
 712             if (fk == None or isinstance(fk,Field)) and ref != None :
 713                 pk = getattr(ref,r.primary_key)
 714                 setattr(entity,r.foreign_key,pk)
 715 
 716         if entity.__entity_status__ == EntityState.New :
 717             self.add(entity)
 718             id = getattr(entity,ne.key_column.property_name)
 719 
 720             for r in ne.subs.values() :
 721                 subs = getattr(entity,r.property_name)
 722                 if isinstance(subs,Subs) :
 723                     break
 724                 if subs is None :
 725                     continue
 726                 for sub in subs :
 727                     sub.to_new()
 728                     setattr(sub,r.foreign_key_column.property_name,id)
 729                     self.save(sub)
 730             entity.to_transient()     
 731 
 732         if entity.__entity_status__ == EntityState.Persist :
 733             self.update(entity)
 734             id = getattr(entity,ne.key_column.property_name)
 735 
 736             for r in ne.subs.values() :
 737                 subs = getattr(entity,r.property_name)
 738                 if isinstance(subs,Subs) :
 739                     break
 740                 if subs is None :
 741                     continue                
 742                 for sub in subs :
 743                     setattr(sub,r.foreign_key_column.property_name,id)
 744                     self.save(sub) 
 745             entity.to_transient()
 746 
 747         if entity.__entity_status__ == EntityState.Deleted :
 748             for r in ne.subs.values() :
 749                 subs = getattr(entity,r.property_name)
 750                 for sub in subs :
 751                     sub.to_delete()
 752                     self.save(sub)
 753             self.delete(entity)                    
 754 
 755         if entity.__entity_status__ == EntityState.Transient :
 756             id = getattr(entity,ne.key_column.property_name)
 757             for r in ne.subs.values() :
 758                 subs = getattr(entity,r.property_name)
 759                 if isinstance(subs,Subs) :
 760                     break
 761                 if subs is None :
 762                     continue                
 763                 for sub in subs :
 764                     setattr(sub,r.foreign_key_column.property_name,id)
 765                     self.save(sub)
 766 
 767     def add(self,entity) :
 768         ne = EntityManager.get_meta(entity.__class__)
 769         sql = SqlGenerator.generate_insert(ne)
 770 
 771         if isinstance(entity,Entity) :
 772             entity.create_time = datetime.now()
 773             entity.creator = configuration.user_name
 774 
 775         pars = []
 776 
 777         fields = ne.full_fields
 778         if ne.key_column.is_auto :
 779             fields = ne.fields    
 780 
 781         for c in fields.keys() :
 782             field_value = getattr(entity,c)
 783             
 784             if isinstance(field_value,Field) :
 785                 field_value = None
 786 
 787             pars.append(field_value)
 788 
 789         self.dao.execute(sql,pars)
 790 
 791         if ne.auto_column != None :
 792             auto_id = self.dao.get_last_row_id()
 793             setattr(entity, ne.auto_column.column_name,auto_id)
 794 
 795     def update(self,entity) :
 796         
 797         ne = EntityManager.get_meta(entity.__class__)
 798         sql = SqlGenerator.generate_update(ne)
 799 
 800         if isinstance(entity,Entity) :
 801             entity.update_time = datetime.now()
 802             entity.updator = configuration.user_name
 803 
 804         pars = []
 805         for c in ne.fields.keys() :
 806             field_value = getattr(entity,c)
 807             # if field_value == None :
 808             #     field_value = "null"
 809             if isinstance(field_value,Field) :
 810                 field_value = None
 811 
 812             pars.append(field_value)
 813         
 814         pars.append( getattr(entity,ne.key_column.property_name) )
 815 
 816         self.dao.execute(sql,pars)
 817         
 818     def delete(self,entity) :
 819         
 820         ne = EntityManager.get_meta(entity.__class__)
 821         sql = SqlGenerator.generate_delete(ne)
 822 
 823         pars = [ getattr(entity,ne.key_column.property_name) ] 
 824 
 825         self.dao.execute(sql,pars)
 826 
 827     def execute(self,cmd,pars = None) :
 828         return self.dao.execute(cmd,pars)
 829 
 830     def fetchone(self,cmd,pars = None) :
 831         return self.dao.fetchone(cmd,pars)
 832 
 833     def fetchall(self,cmd,pars = None) :
 834         return self.dao.fetchall(cmd,pars)
 835 
 836     def executeScalar(self,cmd,pars = None) :
 837         return self.dao.executeScalar(cmd,pars)      
 838     
 839     def commit(self) :
 840         self.dao.commit()
 841     
 842     def open(self,host=configuration.host,port=configuration.port,db=configuration.db,user=configuration.user,pwd=configuration.pwd) :
 843         
 844         self.dao = DataAccess()
 845         self.dao.open(host=host,port=port,db=db,user=user,pwd=pwd)
 846         self.isClosed = False
 847 
 848     def close(self):
 849         self.dao.close()
 850         self.dao.isClosed=True
 851 
 852 #########################################################################     
 853 
 854 class PTable(Persistable):
 855     
 856     __table_name__ = 'tables'
 857 
 858     table_catalog =  StringFiled(size=512)              #
 859     table_schema = StringFiled(size=64)                 #
 860     table_name = StringFiled(size=64, is_primary_key=True)
 861     table_type = StringFiled(size=64)                   #
 862     engine = StringFiled(size=64)                       #
 863     version = LongField()                      #
 864     row_format = StringFiled(size=10)                   #
 865     table_rows = LongField()                   #
 866     avg_row_length = LongField()               #
 867     data_length = LongField()                  #
 868     max_data_length = LongField()              #
 869     index_length = LongField()                 #
 870     data_free = LongField()                    #
 871     auto_increment = LongField()               #
 872     create_time = DateTimeFiled()                        #
 873     update_time = DateTimeFiled()                        #
 874     check_time = DateTimeFiled()                         #
 875     table_collation = StringFiled(size=32)              #
 876     checksum = LongField()                     #
 877     create_options = StringFiled(size=2048)             #
 878     table_comment = StringFiled(size=2048)              #
 879 
 880 class PColumn(Persistable):
 881     
 882     __table_name__ = 'columns'
 883 
 884     table_schema = StringFiled(size=255)
 885     table_name = StringFiled(size=255)
 886     column_name = StringFiled(size=50, is_primary_key=True)
 887     data_type = StringFiled(size=255)
 888     character_maximum_length = StringFiled(size=255) #字符类型时,字段长度
 889     column_key = StringFiled(size=255) #PRI为主键,UNI为unique,MUL是什么意思?
 890     column_comment = StringFiled(size=255) #字段说明
 891     extra = StringFiled(size=255) #'auto_increment'
 892     numeric_precision = IntField()
 893     numeric_scale= IntField()
 894 
 895 class GedColumn(Persistable):
 896     batch = IntField()
 897     dbtype = StringFiled(size=50)    
 898 
 899 class dbtype(Persistable) :
 900     
 901     __table_name__ = 'dbtype'
 902 
 903     id = IntField()                                    #
 904     code = StringFiled(size=50, is_primary_key=True)                               #
 905     name = StringFiled(size=50)                               #
 906     host = StringFiled(size=50)                               #
 907     port = IntField()                                  #
 908     user = StringFiled(size=50)                               #
 909     passwd = StringFiled(size=50)                             #
 910     db = StringFiled(size=50)                                 #
 911     charset = StringFiled(size=50)                            #    
 912 
 913 
 914 class EntityGenerator(object) :
 915 
 916     pm = Persister()
 917     dic = {}
 918 
 919     def open(self) :
 920         
 921         self.pm.open(configuration.host,configuration.port,"information_schema",configuration.user,configuration.pwd)
 922 
 923         self.dic["tinyint"] = "BoolFiled"
 924         self.dic["smallint"] = "ShortField"
 925         self.dic["mediumint"] = "IntField"
 926         self.dic["int"] = "IntField"
 927         self.dic["integer"] = "IntField"
 928         self.dic["bigint"] = "LongField"
 929         self.dic["float"] = "FloatFiled"
 930         self.dic["double"] = "DoubleFiled"
 931         self.dic["decimal"] = "DecimalFiled"
 932         self.dic["date"] = "DateTimeFiled"
 933         self.dic["time"] = "DateTimeFiled"
 934         self.dic["year"] = "IntField"
 935         self.dic["datetime"] = "DateTimeFiled"
 936         self.dic["timestamp"] = "DateTimeFiled"
 937         self.dic["char"] = "StringFiled"
 938         self.dic["varchar"] = "StringFiled"
 939         self.dic["tinyblob"] = "StringFiled" 
 940         self.dic["tinytext"] = "StringFiled"
 941         self.dic["blob"] = "StringFiled"
 942         self.dic["text"] = "StringFiled"
 943         self.dic["mediumblob"] = "BinaryFiled"
 944         self.dic["mediumtext"] = "StringFiled"
 945         self.dic["longblob"] = "BinaryFiled"
 946         self.dic["longtext"] = "StringFiled"       
 947 
 948     def close(self) :
 949         self.pm.close()
 950 
 951     # 根据数据库生成实体
 952     def generate_db(self,db_name) :
 953         ts = self.pm.query_list(PTable,"table_schema = %s",[db_name])
 954 
 955         for t in ts:
 956             self.generate_table(t.table_name,t.table_comment)
 957 
 958     def generate_table(self,table_name,memoto) : 
 959         
 960         cls_name = self.get_class_name(table_name)
 961 
 962         self.out_put( "")
 963         self.out_put(  "#%s" % memoto)
 964         self.out_put(  "class %s(Persistable) : " % cls_name)
 965         self.out_put(  "" )
 966         self.out_put(  "    __table_name__ = '%s'" % table_name )
 967         self.out_put(  "" )
 968         
 969         cs = self.pm.query_list(PColumn,"table_name = %s ",[table_name] )
 970 
 971         for c in cs:
 972             item = self.generate_column(c)
 973             self.out_put(  item )
 974 
 975     def out_put(self,txt) :
 976         print txt
 977     
 978     def get_class_name(self,table_name) :
 979         cls_name = table_name
 980         splits = cls_name.split("_")
 981         if len(splits) > 1 :
 982             items = []
 983             
 984             is_first = True
 985             for item in splits :
 986                 if is_first :
 987                     is_first = False
 988                     continue
 989                     
 990                 items.append(item[0].upper()+item[1:len(item)]) 
 991 
 992             cls_name = "".join(items)
 993 
 994         return cls_name
 995 
 996     def generate_column(self,c) : 
 997         
 998 
 999         properties = []
1000 
1001         if c.character_maximum_length != None :
1002             properties.append("size = %d" % c.character_maximum_length)
1003         
1004         if c.data_type == "decimal" :
1005             properties.append("size = %d" % c.numeric_precision)
1006             properties.append("precision = %d" % c.numeric_scale)
1007 
1008         if c.column_key == "PRI":
1009             properties.append("is_primary_key=True")
1010 
1011         if c.extra == 'auto_increment' :
1012             properties.append( "is_auto = True" )
1013 
1014         item = "    %s = %s( %s )" % (c.column_name.lower(),self.dic[c.data_type],",".join(properties))
1015         item = item.ljust(60)
1016 
1017         if c.column_comment != None :
1018             item = item +"# "+c.column_comment
1019         return item
1020 
1021     #把数据库表结构生成到ged的columns表中
1022     def ged_db(self) :
1023         
1024         db = Persister()
1025         db.open(configuration.host,configuration.port,"ged",configuration.user,configuration.pwd)
1026         ds = db.query_list(dbtype)
1027 
1028         columns = []
1029 
1030         for d in ds :
1031             self.ged_fields(d,columns)
1032 
1033         for column in columns :
1034             db.add(column)
1035 
1036         db.commit()
1037         db.close()
1038 
1039     def ged_fields(self,d,columns) :
1040         db = Persister()
1041         db.open(d.host,d.port,"information_schema",d.user,d.passwd)
1042         cs = db.query_list(PColumn,"table_schema = %s",[d.db] )
1043 
1044         for c in cs :
1045             gedc = GedColumn()
1046             gedc.dbtype=d.code
1047             gedc.column_comment=c.column_comment
1048             gedc.column_key=c.column_key
1049             gedc.column_name=c.column_name
1050             gedc.data_type=c.data_type
1051             gedc.extra=c.extra
1052             gedc.table_name=c.table_name
1053             gedc.table_schema=c.table_schema
1054             gedc.character_maximum_length=c.character_maximum_length
1055 
1056             columns.append(gedc)
1057 
1058         db.close()
1059         
1060 #########################################################################  
View Code

三,test_orm.py

  1 # !/usr/bin/python
  2 # -*- coding: UTF-8 -*-
  3 
  4 import sys
  5 from dao.configuration import *
  6 from dao.orm import *
  7 import startup
  8 import logging
  9 import unittest
 10 
 11 ##############################################################
class Customer(BizEntity) :
    '客户'
    # Customer entity (docstring '客户' = "customer") stored in ns_customer.
    # No fields are declared here; the ``code``/``name`` values set by the
    # tests presumably come from BizEntity — confirm.

    __table_name__ = "ns_customer"
 16     
 17 
class Product(BizEntity) :
    '产品'
    # Product entity (docstring '产品' = "product") stored in ns_product.
    # Like Customer, it declares no fields of its own; ``code``/``name``
    # presumably come from BizEntity — confirm.

    __table_name__ = "ns_product"
 22     
 23 
class OrderItem(Entity) :
    '订单明细'
    # Sales-order line (docstring '订单明细' = "order detail") stored in
    # ns_order_item.

    __table_name__ = "ns_order_item"

    # Line figures; nothing in this class derives amount from quantity/price.
    quantity = DecimalFiled()
    price = DecimalFiled()
    amount = DecimalFiled()

    # Link to the ordered Product through the product_id foreign key.
    product_id = IntField()
    product = Reference(foreign_key = "product_id",header="产品",reference_type = Product)

    # Owning SalesOrder; SalesOrder.items uses this column as its foreign key.
    order_id = IntField()
 37 
class SalesOrder(Entity) :
    '销售订单'
    # Sales order header (docstring '销售订单' = "sales order") stored in
    # ns_order.

    __table_name__ = "ns_order"

    code = StringFiled()
    quantity = DecimalFiled()
    price = DecimalFiled()
    amount = DecimalFiled()

    # Link to the ordering Customer through the customer_id foreign key.
    customer_id = IntField()
    customer = Reference(foreign_key = "customer_id",header="客户",reference_type = Customer)

    # Child OrderItem rows whose order_id points back to this order;
    # test_delete expects none of them to remain after the order is deleted.
    items = Subs(foreign_key = "order_id",header="订单明细",sub_type = OrderItem)
 52 
class Wdbtype(Persistable) :
    # Registered database connection stored in ns_dbtype, written in the
    # style EntityGenerator emits. It declares its own string primary key
    # (``code``), which test_customer_primary_key exercises.

    __table_name__ = 'ns_dbtype'

    code = StringFiled( size = 50,is_primary_key=True )     # primary key
    name = StringFiled( size = 50 )                         # display name
    host = StringFiled( size = 50 )                         # database server host
    port = IntField(  )                                     # database server port
    user = StringFiled( size = 50 )                         # login user
    passwd = StringFiled( size = 50 )                       # login password
    db = StringFiled( size = 50 )                           # database (schema) name
    charset = StringFiled( size = 50 )                      # connection character set
 65 
 66 ##############################################################
 67 
 68 
# Module author metadata.
__author__ = 'xufangbo'
 70 
 71 class orm_test(unittest.TestCase) :
 72     
 73     item_size = 10
 74     order = None
 75     
 76     def setUp(self):
 77 
 78         self.set_db()
 79         self.test_create()
 80 
 81     def tearDown(self) :
 82         
 83         db = Db()
 84         db.open()
 85 
 86         clss = [SalesOrder,OrderItem,Customer,Product,Wdbtype]
 87 
 88         for cls in clss :
 89             db.drop_table(cls)
 90 
 91         db.commit()
 92         db.close() 
 93 
 94     def test_byid(self) :
 95         
 96         order = self.order
 97         
 98         pm = Persister()
 99         pm.open()        
100         
101         pm.byid(order)
102 
103         self.assertEqual(order.code,"DD001")
104         self.assertEqual(len(order.items),self.item_size)
105         self.assertIsNotNone(order.create_time)
106         self.assertIsNotNone(order.creator)
107 
108         self.assertIsNotNone(order.customer)
109         self.assertIsNotNone(order.items[0].product)
110 
111         pm.commit()
112         pm.close()        
113     
114     def test_persist(self) :
115 
116         pm = Persister()
117         pm.open()
118 
119         order = self.order
120 
121         order.code="dd002"
122         order.to_persist()
123 
124         order.items[1].to_delete()
125         order.items[2].price = 18
126         order.items[2].to_persist()
127 
128         pm.save(order)
129 
130         self.assertIsNotNone(order.update_time)
131         self.assertIsNotNone(order.updator)    
132 
133         pm.byid(order)
134 
135         self.assertEqual(len(order.items),self.item_size-1)
136         self.assertEqual(order.code,"dd002")
137         self.assertIsNotNone(order.update_time)
138         self.assertIsNotNone(order.updator)
139 
140         pm.commit()
141         pm.close()
142 
143     def test_delete(self) :
144     
145         pm = Persister()
146         pm.open()
147 
148         order = self.order
149 
150         order.to_delete()
151         pm.save(order)
152 
153         o = pm.byid(order)
154         self.assertIsNone(o)
155 
156         pm.commit()
157         pm.close()       
158 
159         #=============================
160         dao = DataAccess()
161         dao.open()
162         
163         sql = 'select count(0) from %s where order_id = %s' % (EntityManager.get_meta(OrderItem).table_name,'%s')
164         pars = [order.id]
165         count = int(dao.executeScalar(sql,pars))
166 
167         self.assertEqual(count,0)
168 
169         dao.close()            
170 
171     def test_create(self) :
172         
173         pm = Persister()
174         pm.open()
175 
176         customer = Customer()
177         customer.code = "C001"
178         customer.name="北京千舟科技发展有限公司"
179         pm.save(customer)
180 
181         product = Product()
182         product.code = "P001"
183         product.name="电商ERP标准版4.0"
184         pm.save(product)
185 
186         order = SalesOrder();
187         order.code="DD001"
188         order.quantity=2
189         order.price =2 
190         order.amount=4.6
191         order.customer = customer
192         order.items = []
193 
194         for i in range(self.item_size) :
195             item = OrderItem()
196             item.quantity=i+1
197             item.price =i+1
198             item.amount= (i+1)*(i+1)
199             item.product = product
200             order.items.append(item)            
201 
202         pm.save(order)
203 
204         self.order = order
205 
206         self.assertIsNotNone(order.id)
207         self.assertIsNotNone(order.create_time)
208         self.assertIsNotNone(order.creator)
209 
210         pm.commit()
211         pm.close()
212 
213     def test_query(self) :
214         
215         pm = Persister()
216         pm.open()
217 
218         orders = pm.query_list(SalesOrder)
219         count1 = len(orders)
220 
221         customer = Customer()
222         customer.code = "C003"
223         customer.name="北京千舟科技发展有限公司"
224         pm.save(customer)
225 
226         product = Product()
227         product.code = "P003"
228         product.name="电商ERP标准版4.0"
229         pm.save(product)
230 
231         for i in range(self.item_size) :
232             order = SalesOrder();
233             order.code="DD001"
234             order.quantity=2
235             order.price =2 
236             order.amount=4.6
237             order.customer = customer
238             order.items = []
239 
240             pm.save(order)
241 
242         orders = pm.query_list(SalesOrder)
243         count2 = len(orders)
244 
245         self.assertEquals( count1 +self.item_size,count2 ) 
246 
247         orders = pm.query_list(SalesOrder,"customer.code = %s",["C003"])
248         count3 = len(orders)
249         self.assertEqual(count3,self.item_size)
250 
251         pm.commit()
252         pm.close()
253     
254     def test_customer_primary_key(self) :
255         pm = Persister()
256         pm.open()
257 
258         dt = Wdbtype(code='zl44',name='专利',host='mysql',port=3310,db='patent',user='root',passwd='mysql123',charset='utf8')
259 
260         pm.add(dt)
261         pm.delete(dt)
262 
263         pm.commit()
264         pm.close()        
265 
266     def set_db(self) :
267         
268         db = Db()
269         db.open()
270 
271         clss = [SalesOrder,OrderItem,Customer,Product,Wdbtype]
272 
273         for cls in clss :
274             db.drop_table(cls)
275             db.create_table(cls)
276 
277         db.commit()
278         db.close() 
279 
class EntityGeneratorTest(unittest.TestCase):
    """Smoke test: generate entity source for every table of the ``wolf`` database."""

    def test_generator(self):
        generator = EntityGenerator()
        generator.open()
        try:
            # Output is printed through EntityGenerator.out_put; nothing is
            # asserted beyond the absence of errors.
            generator.generate_db("wolf")
        finally:
            # Close the metadata connection even if generation fails.
            generator.close()
288 
289 
# Run every test case in this module when executed as a script.
if __name__ == '__main__':
    unittest.main()
View Code

 

posted on 2018-02-23 21:19  Netsharp  阅读(352)  评论(0)  编辑  收藏  举报

导航