您的位置:首页 > 编程语言 > Python开发

Python项目实战:个人博客(2):搭建orm框架

2016-04-06 22:35 706 查看

这一部分的代码相对比较复杂,主要用来方便我们进行数据库的操作.

需要注意几点:首先’?’这个占位符的作用是什么?我们在传入sql语句时需要动态的替换值,但是一些值不是字符串的类型,所以不能直接用%s来替换,我们希望用一个?来表示需要替换的地方,然后在专门用一个函数来把?和传入参数替换.

其次,最后的用户接口类里面Field类型的属性都被存在字典的keys和values里面了,直接用dir()是看不到的,不属于类的直接属性了,而是隐藏在属性里面的属性.这是用gettattr做到的.

===============================================

另外,getValueorDefault的用处,如果该key直接存在在keys属性里面,就返回它。否则,我们需要在mappings里面寻找这个key,并且获得对应Filed中存储的默认值.如果这个默认值不是None,比如IntgerField的默认值一般是0,我们就要返回对应的默认值.注意了,keys是从dict继承而来的,它对应的是该实例的属性,而Field则是对应的该类的属性.Field对应的是除主键外的一行的属性.我们创建一个User实例的时候不一定会将一行所有的属性都当作参数输入的,可能只是对应某几个属性而已.总结一下,这个函数的意思是:该实例没有的属性,就用Filed和主键类型提前定义好的默认值来返回.

================================================

还有为什么要删除类的Filed类型的属性,将它存储到mappings里面?有两个好处,第一方便我们操作,将这些自定义属性和类继承的属性分离.我们进行sql操作时我们知道这个表对应数据库的字段的属性在什么地方.第二,防止和实例属性冲突:请看这段代码:

class test(dict):
id = 9
def __getattr(self,key):
try:
return self[key]
except KeyError:
raise  AttributeError(r'''Dict' object has no object %s'''%key)

a= test(id = 10)
print(a.id)//输出9
print(a['id'])//输出10


我们使用获取id的值时,首先寻找实例属性,但是由于继承的是dict,传入的id被放在了keys和values这两个属性里面,所以找不到实例属性id(尽管实际上它应该是实例属性),然后解释器会去找类属性,这样类属性就掩盖了实例属性(找到了类属性就不会调用getattr方法了).所以应该将类属性删除并放入mappings里面.

import asyncio,logging
logging.basicConfig(level = logging.DEBUG)
import aiomysql

@asyncio.coroutine
def create_pool(loop,**kw):
logging.info('create database connection pool...')
global __pool
__pool = aiomysql.create_pool(
host = kw.get('host','localhost'),
port = kw.get('port',3306),
user = kw['user'],
password = kw['password'],
db = kw['db'],
charset = kw.get('charset','utf8'),
autocommit = kw.get('auotcommit',True),
maxsize = kw.get('maxsize',10),#连接池最多10条连接,默认是10
minsize = kw.get('minsize',1),#连接池最少1条连接,默认是1
loop = loop
)
@asyncio.coroutine
def select(sql,args,size = None):
logging.info(sql,args)
global __pool
with (yield from __pool) as conn:#从连接池获取一个connect
cur = yield from conn.cursor(aiomysql.DictCursor)#获取一个cursor,通过aiomysql.DictCursor获取到的cursor在返回结果时会返回一个字典格式
#把sql语句的'?'替换为'%s',并把args的值填充到相应的位置补充成完整的可执行sql语句并执行mysql中得占位符是'?'
yield from cur.execute(sql.replace('?', '%s'), args or ())
#如果有要求的返回行数,则取要求的行数,如果没有,则全部取出
if size:
rs = yield from cur.fetchmany(size)
else:
rs = yield from cur.fetchall()
yield from cur.close()      #关闭cursor
logging.info('rows returned : %s' % len(rs))#显示获取数据的长度
return rs#返回获取的数据
@asyncio.coroutine
def execute(sql, args):
logging.info(sql)
with (yield from __pool) as conn:
try:
cur = yield from conn.cursor()#获取cursor
yield from cur.execute(sql.replace('?', '%s'), args)#执行sql语句
affected = cur.rowcount#返回影响的行数
yield from cur.close()#关闭cursor
except BaseException as e:
raise
return affected
#构造sql语句参数字符串,最后返回的字符串会以','分割多个'?',如 num==2,则会返回 '?, ?'
def create_args_string(num):
L = []
for n in range(num):
L.append('?')
return ', '.join(L)
#用于标识model里每个成员变量的类
#name:名字 column_type:值类型 primary_key:是否primary_key default:默认值(value)
class Field(object):
def __init__(self, name, column_type,  primary_key, default):
self.name = name
self.column_type = column_type
self.primary_key = primary_key
self.default = default
#直接打印对象的实现方法
def __str__(self):
return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)
#string类型的默认设定,默认不是主键
class StringField(Field):
def __init__(self, name = None, primary_key = False, default = None, ddl = 'varchar(100)'):
super().__init__(name, ddl, primary_key, default)
#bool类型的默认设定
class BooleanField(Field):
def __init__(self, name = None, default = False):
super().__init__(name, 'boolern', False, default)
#integer类型的默认设定
class IntegerField(Field):
def __init__(self, name=None, primary_key=False, default=0):
super().__init__(name, 'bigint', primary_key, default)
#float类型的默认设定
class FloatField(Field):
def __init__(self, name=None, primary_key=False, default=0.0):
super().__init__(name, 'real', primary_key, default)
#text类型的默认设定
class TextField(Field):
def __init__(self, name=None, default=None):
super().__init__(name, 'text', False, default)
#mappings字典,存放属性名和属性对应的Field实例
#Fields存放除了主键以外的属性名字
class ModelMetalclass(type):
def __new__(cls, name, bases, attrs):
#如果当前类是Model类,不做任何修改,因为不会用model类来映射数据库表
logging.info(name)
if name == 'Model':
return type.__new__(cls, name, bases, attrs)

#获取table(数据库表的名字)名称:attrs的'__table__'键对应的value,如果为空的话则用name(当前处理的类的名字)
tableName = attrs.get('__table__', None) or name
logging.info('found model : %s (table: %s)' % (name, tableName))
#获取所有的Field和主键名
mappings = dict()   #mappings字典,存放所有Field键值对,属性名:value
fields = []         #fields数组,存放除了主键以外的属性名
primaryKey = None   #主键
for k, v in attrs.items():
if isinstance(v, Field):        #如果value是数据库列的映射
logging.info('found mapping : %s ==> %s' % (k, v))
mappings[k] = v             #把符合要求的放到mappings里
if v.primary_key:           #如果当前Field是主键,则记录下来
if primaryKey:#如果已经找到过主键了,那么抛出错误
raise RuntimeError('Duplicate primary key for field:%s' % k)
primaryKey = k          #记录主键,单独的一个属性
else:
fields.append(k)        #不是主键的话把key值放到fields里

if not primaryKey:#如果最后没有主键,抛出错误
raise RuntimeError('Primary key not found')
#把attrs里面的属性删除,防止和实例的属性冲突
for k in mappings.keys():
attrs.pop(k)

escaped_fields = list(map(lambda f : '`%s`' % f, fields))   #把fields的值全部加了个 '',这样名字打印出来就变成了'xxxxx'
attrs['__mappings__'] = mappings        #保存属性和列的映射关系
attrs['__table__'] = tableName          #表名
attrs['__primary_key__'] = primaryKey   #主键属性名
attrs['__fields__'] = fields            # 除主键外的属性名
# 构造默认的SELECT, INSERT, UPDATE和DELETE语句:
attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
#构造类
return type.__new__(cls, name, bases, attrs)

#Model从dict继承,所以具备所有dict的功能,同时又实现了特殊方法__getattr__()和__setattr__(),因此又可以像引用普通字段那样写:
class Model(dict,metaclass = ModelMetalclass):
def __init__(self, **kw):
super().__init__(**kw)
def __getattr__(self, key):#重写访问属性的方法,没有属性和key一样则抛错
try:
return self[key]
except KeyError:
raise AttributeError(r"'Model' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
self[key] = value
def getValue(self, key):
return getattr(self, key, None)
#访问某个key,如果value是None,则去mappings获取default值
def getValueOrDefault(self, key):
value = self.getValue(key)#访问某个key
print('Getvvvvvvvvvvvvvvvvvvvvvvvalue',key,value)
if value is None:#该key不存在
field = self.__mappings__[key]#从mappings里面寻找默认值
if field.default is not None:
value = field.default() if callable(field.default) else field.default
logging.debug('using default value for %s: %s' % (key, str(value)))
setattr(self, key, value)#给这个实例添加属性,将类的属性转变成为实例的属性
print('set key %s : value: %s '%(key,value))
return value#只有mappings里面的default默认值且对应的key不存在才返回None
@classmethod#将该方法定义为类方法,所有子类都可以用
@asyncio.coroutine
#根据主键查找pk的值,取第一条
def find(cls, pk):
'find object by primary key'
rs = yield from select('%s where `%s`= ?' % (cls.__select__, cls.__primary_key__), [pk], 1)
if len(rs) == 0:#返回0,没有查到任何东西
return None
return cls(**rs[0])#返回cls类的一个实例,初始化的参数是rs[0]
@classmethod
@asyncio.coroutine
def findAll(cls, where=None, args=None, **kw):
' find objects by where clause. '
sql = [cls.__select__]
if where:
sql.append('where')
sql.append(where)
if args is None:
args = []
orderBy = kw.get('orderBy', None)
if orderBy:
sql.append('order by')
sql.append(orderBy)
limit = kw.get('limit', None)
if limit is not None:
sql.append('limit')
if isinstance(limit, int):
sql.append('?')
args.append(limit)
elif isinstance(limit, tuple) and len(limit) == 2:
sql.append('?, ?')
args.extend(limit)
else:
raise ValueError('Invalid limit value: %s' % str(limit))
logging.info('sql = %s' % sql)
rs = yield from select(' '.join(sql), args)
return [cls(**r) for r in rs]
@classmethod
@asyncio.coroutine
def findNumber(cls, selectField, where=None, args=None):
' find number by select and where. '
sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
if where:
sql.append('where')
sql.append(where)
rs = yield from select(' '.join(sql), args, 1)
if len(rs) == 0:
return None
return rs[0]['_num_']
#更新条目数据
@asyncio.coroutine
def update(self):
args = list(map(self.getValue, self.__fields__))
args.append(self.getValue(self.__primary_key__))
rows = yield from execute(self.__update__, args)
if rows != 1:
logging.warn('failed to update by primary key: affected rows: %s' % rows)
#根据主键的值删除条目
@asyncio.coroutine
def remove(self):
args = [self.getValue(self.__primary_key__)]
rows = yield from execute(self.__delete__, args)
if rows != 1:
logging.warn('failed to remove by primary key: affected rows: %s' % rows)
#添加实例方法,因为只有实例(一行)才能保存
#根据当前类的属性,往相关table里插入一条数据
@asyncio.coroutine
def save(self):
print('save start')
args = list(map(self.getValueOrDefault, self.__fields__))#args是所有的类的属性值(value)
args.append(self.getValueOrDefault(self.__primary_key__))#添加主键对应的值
rows = yield from execute(self.__insert__, args)
if rows != 1:
logging.warn('failed to insert record: affected rows: %s' % rows)

###User类
#class User(Model):
#   __table__ = 'users'
#   id = IntegerField(primary_key=True)
#   name = StringField()
#u = User(id = 1993,name = 'ssss')
#m = User(id = 1994,name = 'mmm')
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: