metadata的使用以及简单的orm模式

使用sqllite3和metadata简单的封装了个简单的orm

 

#!/usr/bim/python
#-*-coding: utf-8 -*-
 
import threading
import sqlite3
 
import sys
 
__module__ = sys.modules[__name__]
 
def setup_database(database):
    Mapper.initialize(database)
 
def bind_mapper(mapper_name, model_cls, mapper_cls = None):
 
    if mapper_cls is None:
        mapper_cls = Mapper
    mapper = mapper_cls(model_cls)
    setattr(__module__, mapper_name, mapper)
 
class FieldError(Exception):
 
    def __init__(self, message,  *args, **kwargs):
        Exception.__init__(self, message, *args, **kwargs)
 
 
class Field:
 
    def __init__(self, name, affinity=None, validator=None, required=False, default=None,  **kwargs):
 
        self.affinity = affinity
        self.validator = validator
        self.default = default
        self.required = required
        self.name = name
        self.kwargs = kwargs
 
    def get_value(self):
        if hasattr(self, "data"):
            return self.data
        else:
            raise FieldError("Field is not initilize")
 
    def set_value(self, value):
        self.value = self.process_formdata(value)
 
    value = property(get_value, set_value)
 
    def validate(self):
        if self.required and self.data is None:
            raise FieldError("Field is required!")
        if self.value == self.default:
            return
        if self.validator:
            self.validator(self)
 
    def process_formdata(self, value):
        if value or self.required == False:
            try:
                self.data = value
            except ValueError:
                self.data = None
                raise ValueError('Not a valid integer value')
 
    def _pre_validate(self):
        pass
 
    def __call__(self):
        if hasattr(self, "value"):
            return self.value
        else:
            raise FieldError("Filed is not initilize")
 
 
class IntegerField(Field):
 
    """
    A text field, except all input is coerced to an integer.  Erroneous input
    is ignored and will not be accepted as a value.
    """
 
    def __init__(self, name, affinity=None, validator=None, required=False, defalut=None,  **kwargs):
        Field.__init__(self, name, validator, required, defalut,  **kwargs)
 
    def process_formdata(self, value):
        if value:
            try:
                self.value = int(value)
            except ValueError:
                self.value = None
                raise ValueError('Not a valid integer value')
 
 
def with_metaclass(meta, bases=(object,)):
    return meta("NewBase", bases, {})
 
class ModelMeta(type):
 
    def __new__(metacls, cls_name, bases, attrs):
        fields = {}
        new_attrs = {}
        for k, v in attrs.iteritems():
            if isinstance(v, Field):
                fields[k] = v
            else:
                new_attrs[k] = v
 
        cls = type.__new__(metacls, cls_name, bases, new_attrs)
        cls.fields = cls.fields.copy()
        cls.fields.update(fields)
        return cls
 
class ModelMinix(object):
 
    fields = {}
    
    def __str__(self):
        return '<' + self.__class__.__name__ + \
            ' : {' + ", ".join(["%s=%s" % (field.name, getattr(self, column))
                                for column, field in self.fields.items()]) + '}>'
 
    def save(self):
        return self.__mapper__.save(self)
      
        
class Model(with_metaclass(ModelMeta, (ModelMinix,))):
 
    def __init__(self, **kwargs):
        for k in kwargs:
            try:
                if k in self.fields:
                    setattr(self, k, kwargs[k])
            except:
                raise ValueError("not found filed %s" % (k))
 
    def __json__(self):
        raise NotImplemented("subclass of Model must implement __json__")
 
    @classmethod
    def create_table(cls):
        raise NotImplemented("subclass of Model must implement create_table")
 
 
 
class Mapper:
 
    _local = threading.local()
 
    """Database Mapper"""
    def __init__(self, model_cls):
         self.model = model_cls
         self.model.__mapper__ = self
 
    @staticmethod
    def initialize(database):
        if not hasattr(Model, "__database__ "):
            Mapper.__database__ = database
 
    def execute(self, sql):
        with self.connect() as conn:
            try:
                cursor = conn.cursor()
                cursor.execute(sql)
                conn.commit()
            except sqlite3.Error, e:    
                print "SQLite Error: %s" % e.args[0]
                conn.rollback()
 
    @staticmethod
    def initialized():
        """Returns true if the class variable __database__ has been setted."""
        print Mapper.__database__
        return hasattr(Mapper, "__database__")
 
    @classmethod
    def connect(cls):
        """create a thread local connection if there isn't one yet"""
        # print('connect',cls)
        if not hasattr(cls._local, 'conn'):
            try:
                cls._local.conn = sqlite3.connect(cls.__database__)
                #cls._local.conn.execute('pragma integrity_check')
                cls._local.conn.row_factory = sqlite3.Row
            except sqlite3.Error, e:    
                print "Error %s:" % e.args[0]
        return cls._local.conn
 
 
    def create_table(self):
        sql = 'CREATE TABLE IF NOT EXISTS ' + self.model.__tablename__ + ' (' + \
            'name varchar(50),' + \
            'email varchar(20)' + ')'
 
        with self.connect() as conn:
            try:
                cursor = conn.cursor()
                cursor.execute(sql)
                conn.commit()
                print 'Create table %s'  % (self.model.__tablename__)
            except sqlite3.Error, e:    
                print "SQLite Error: %s" % e.args[0]
                conn.rollback() 
 
    def drop_table(self):
        sql = 'DROP TABLE IF EXISTS ' + self.model.__tablename__
        print 'Drop table ' + self.model.__tablename__
        self.execute(sql)
 
    def deleteby(self, paterns=None, **kwargs):
        dels = []
        vals = []
        for k, v in kwargs.iteritems():
            if k in self.model.fields:
                dels.append(k + '=?')
                vals.append(v)
        sql = 'DELETE FROM %s WHERE %s' % (self.model.__tablename__, ' AND '.join(dels))
        with self.connect() as conn:
            try:
                cursor = conn.cursor()
                cursor.execute(sql, vals)
                conn.commit()
                return True
            except sqlite3.Error, e:    
                print "SQLite Error: %s" % e.args[0]
                conn.rollback()
        return False
 
    def save(self, model):
        cols = model.fields.keys()
        vals = [getattr(model, c) for c in self.model.fields]
        sql = 'INSERT INTO ' + self.model.__tablename__ + \
            ' ('  + ', '.join(cols) + ')' + \
            ' VALUES (' + ', ' .join(['?'] * len(cols)) + ')'
        with self.connect() as conn:
            try:
                cursor = conn.cursor()
                cursor.execute(sql, vals)
                conn.commit()
                print 'save %s' % model
                return True
            except sqlite3.Error, e:    
                print "SQLite Error: %s" % e.args[0]
                conn.rollback()
        return False
 
 
import db
import unittest
import threading
 
class  ModelTest(unittest.TestCase):
    """docstring for  ModelTest"""
    
    def setUp(self):
        class User(db.Model):
            __tablename__ = "users"
            name = db.Field("name", "varchar(50)")
            email = db.Field("email", "varchar(20)", validator=lambda x: 7 < len(x) < 21)
    
        db.setup_database('conf/test.sqlite')
        db.bind_mapper('UserMapper', User)
        db.UserMapper.create_table()
        self.user = User(name='test', email='test@example.com')
 
    def tearDown(self):
        res = db.UserMapper.deleteby(name=self.user.name, email=self.user.email)
        #db.UserMapper.drop_table()
        self.assertEqual(True, res)
 
    def test_model_save(self):
        self.assertEqual(True, self.user.save())
 
    def test_mapper_save(self):
        self.assertEqual(True, db.UserMapper.save(self.user))
 
    def test_in_threads(self):
        n = 10
        ts = []
        for i in range(n):
            t = threading.Thread(target=self.user.save)
            ts.append(t)
        for t in ts:
            t.start()
        for t in ts:
            t.join()
 
 
if __name__ == "__main__":
    unittest.main()
 

 

posted @ 2014-12-29 13:07  媛媛画  阅读(517)  评论(0编辑  收藏  举报