首先要明确:
ORM的编写较为复杂,但编写完成后使用接口进行调用则显得非常简单。并且ORM编写模式基本为
-Field模块
-元类MetaClass
-基类Model
有着较为固定的写法,没必要重复造轮子,能复用尽量复用。重要的是要理解元类这块硬骨头的妙用。
基本思路:
(收集数据;对这些数据进行分类,识别(相对应数据库),生成SQL语句;最后,连接数据库,并执行SQL语句进行操作。)
- User类负责收集数据,并尝试归类出这些数据对应数据库表的映射关系,类如对应表的字段(包含名字、类型、是否为表的主键、默认值)等;
- 它的基类负责执行操作,比如数据库的存储、读取,查找等操作;
- 它的元类负责分类、整理收集的数据并以此创建一些类属性(如SQL语句)供基类作为参数。
即
- 在一个Web App中,所有的数据,包括用户信息,用户发布的日志,评论都放在数据库中,本次实战使用MySQL作为数据库。
- Web App中,有许多地方都要用到数据库,访问数据要创建数据库连接,创建游标对象,执行SQL语句,然后要处理异常,清理资源等。
- 首先,要封装数据库的SELECT,INSERT,UPDATE,DELETE语句
- 其次,由于Web框架使用了基于asyncio的aiohttp,这是基于协程的异步模型。Web App框架采用异步IO编程,aiomysql为MySQL数据库提供了异步IO的驱动。
一旦决定使用异步,则系统每一层都必须是异步,“开弓没有回头箭”。
一步异步,步步异步
期待代码
# 创建实例:
user = User(id=123, name='Michael')
# 存入数据库:
user.insert()
# 查询所有User对象:
users = User.findAll()
完整代码
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio, logging
import aiomysql
def log(sql, args=()):
logging.info('SQL: %s' % sql)
async def create_pool(loop, **kw):
logging.info('create database connection pool...')
global __pool
__pool = await aiomysql.create_pool(
host=kw.get('host', 'localhost'),
port=kw.get('port', 3306),
user=kw['user'],
password=kw['password'],
db=kw['db'],
charset=kw.get('charset', 'utf8'),
autocommit=kw.get('autocommit', True),
maxsize=kw.get('maxsize', 10),
minsize=kw.get('minsize', 1),
loop=loop
)
#单独封装select,其他insert,update,delete一并封装,理由如下:
#使用Cursor对象执行insert,update,delete语句时,执行结果由rowcount返回影响的行数,就可以拿到执行结果。
#使用Cursor对象执行select语句时,通过featchall()可以拿到结果集。结果集是一个list,每个元素都是一个tuple,对应一行记录。
async def select(sql, args, size=None):
log(sql, args)
global __pool
async with __pool.get() as conn:#打开pool的方法:with await __pool as conn:
# 创建一个结果为字典的游标
async with conn.cursor(aiomysql.DictCursor) as cur:
# 执行sql语句,将sql语句中的'?'替换成'%s'
await cur.execute(sql.replace('?', '%s'), args or ())
# 如果指定了数量,就返回指定数量的记录,如果没有,就返回所有记录
if size:
rs = await cur.fetchmany(size)
else:
rs = await cur.fetchall()
logging.info('rows returned: %s' % len(rs))
return rs#返回的结果集
async def execute(sql, args, autocommit=True):
log(sql)
async with __pool.get() as conn:
if not autocommit:
await conn.begin()
try:
async with conn.cursor(aiomysql.DictCursor) as cur:
await cur.execute(sql.replace('?', '%s'), args)
# 获取操作的记录数
affected = cur.rowcount
if not autocommit:
await conn.commit()
except BaseException as e:
if not autocommit:
await conn.rollback()#数据回滚
raise
return affected
#用于输出元类中创建sql_insert语句中的占位符,计算需要拼接多少个占位符
def create_args_string(num):
L = []
for n in range(num):
L.append('?')
return ', '.join(L)
class Field(object):
def __init__(self, name, column_type, primary_key, default):
self.name = name
self.column_type = column_type
self.primary_key = primary_key
self.default = default
def __str__(self):
return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)
class StringField(Field):
def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
super().__init__(name, ddl, primary_key, default)
class BooleanField(Field):
def __init__(self, name=None, default=False):
super().__init__(name, 'boolean', False, default)
class IntegerField(Field):
def __init__(self, name=None, primary_key=False, default=0):
super().__init__(name, 'bigint', primary_key, default)
class FloatField(Field):
def __init__(self, name=None, primary_key=False, default=0.0):
super().__init__(name, 'real', primary_key, default)
class TextField(Field):
def __init__(self, name=None, default=None):
super().__init__(name, 'text', False, default)
#定义Model的metaclass元类
#所有的元类都继承自type
#ModelMetaclass元类定义了所有Model基类(继承ModelMetaclass)的子类实现的操作
# -*-ModelMetaclass:为一个数据库表映射成一个封装的类做准备
# 读取具体子类(eg:user)的映射信息
#创造类的时候,排除对Model类的修改
#在当前类中查找所有的类属性(attrs),如果找到Field属性,就保存在__mappings__的dict里,
#同时从类属性中删除Field(防止实例属性覆盖类的同名属性)
#__table__保存数据库表名
class ModelMetaclass(type):
# 调用__init__方法前会调用__new__方法
def __new__(cls, name, bases, attrs):
# cls:当前准备创建的类的对象,name:类的名称,
# bases:类继承的父类集合,attrs:类的方法集合
if name=='Model':
return type.__new__(cls, name, bases, attrs)
# 获取table名称,如果未设置,tableName就是类的名字
tableName = attrs.get('__table__', None) or name
logging.info('found model: %s (table: %s)' % (name, tableName))
# 获取所有的Field(类属性)和主键名:
mappings = dict() #保存映射关系
fields = [] #保存除主键外的属性
primaryKey = None
# key是列名,value是field的子类
for k, v in attrs.items():
if isinstance(v, Field):
logging.info(' found mapping: %s ==> %s' % (k, v))
mappings[k] = v
if v.primary_key:
# 找到主键:
if primaryKey:
raise StandardError('Duplicate primary key for field: %s' % k)
primaryKey = k#此列设为列表的主键
else:
#非主键,一律放在fields
fields.append(k)
if not primaryKey:# 如果遍历了所有属性都没有找到主键,则主键没定义
raise StandardError('Primary key not found.')
#从类属性中删除Field属性
for k in mappings.keys():
attrs.pop(k)#从类属性中删除Field属性,否则,容易造成运行时错误(实例的属性会遮盖类的同名属性)
# 保存非主键属性为字符串列表形式
# 将非主键属性变成`id`,`name`这种形式(带反引号)
# repr函数和反引号的功能一致:取得对象的规范字符串表示
# 将fields中属性名以`属性名`的方式装饰起来
escaped_fields = list(map(lambda f: '`%s`' % f, fields))
attrs['__mappings__'] = mappings # 保存属性和列的映射关系
attrs['__table__'] = tableName# 保存表名
attrs['__primary_key__'] = primaryKey # 主键属性名
attrs['__fields__'] = fields # 除主键外的属性名
# 构造默认的SELECT, INSERT, UPDATE和DELETE语句
attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
return type.__new__(cls, name, bases, attrs)
#定义ORM所有映射的基类:Model
#Model类的任意子类可以映射一个数据库表
#Model类可以看做是对所有数据库表操作的基本定义的映射
#基于字典查询形式
#Model从dict继承,拥有字典的所有功能,同时实现特殊方法__getattr__和__setattr__,能够实现属性操作
#实现数据库操作的所有方法,定义为class方法,所有继承自Model都具有数据库操作方法
class Model(dict, metaclass=ModelMetaclass):
#也可在此处写 __metaclass__ = ModelMetaclass,与参数处效果相同
def __init__(self, **kw):
super(Model, self).__init__(**kw)
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Model' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
self[key] = value
def getValue(self, key):
# 返回对象的属性,如果没有对应属性,则会调用__getattr__
#直接调回内置函数,注意这里没有下划符,注意这里None的用处,是为了当user没有赋值数据时,返回None,调用于update
return getattr(self, key, None)
def getValueOrDefault(self, key):
value = getattr(self, key, None)
if value is None:
field = self.__mappings__[key]
if field.default is not None:
value = field.default() if callable(field.default) else field.default
logging.debug('using default value for %s: %s' % (key, str(value)))
# 将默认值设置进行
setattr(self, key, value)
return value
# 类方法第一个参数为cls,而实例方法第一个参数为self
@classmethod
#这里可以使用User.findAll()是因为:用@classmethod修饰了Model类里面的findAll()
#一般来说,要使用某个类的方法,需要先实例化一个对象再调用方法
#而使用@staticmethod或@classmethod,就可以不需要实例化,直接类名.方法名()来调用
#申明是类方法:有类变量cls传入,cls可以做一些相关的处理
#有子类继承时,调用该方法,传入的类变量cls是子类,而非父类
async def findAll(cls, where=None, args=None, **kw):
' find objects by where clause. '
sql = [cls.__select__]
# 如果where查询条件存在
if where:
sql.append('where')# 添加where关键字
sql.append(where) # 拼接where查询条件
if args is None:
args = []
orderBy = kw.get('orderBy', None)# 获取kw里面的orderby查询条件
if orderBy:
sql.append('order by')# 拼接orderBy字符串
sql.append(orderBy)# 拼接orderBy查询条件
limit = kw.get('limit', None)# 获取limit查询条件
if limit is not None:
sql.append('limit')
if isinstance(limit, int):# 如果limit是int类型
sql.append('?')# sql拼接一个占位符
args.append(limit)# 将limit添加进参数列表,之所以添加参数列表之后再进行整合是为了防止sql注入
elif isinstance(limit, tuple) and len(limit) == 2:# 如果limit是一个tuple类型并且长度是2
sql.append('?, ?')# sql语句拼接两个占位符
args.extend(limit)# 将limit添加进参数列表
else:
raise ValueError('Invalid limit value: %s' % str(limit))
rs = await select(' '.join(sql), args) # 将args参数列表注入sql语句之后,传递给select函数进行查询并返回查询结果
return [cls(**r) for r in rs]
@classmethod
#查询某个字段的数量
async def findNumber(cls, selectField, where=None, args=None):
' find number by select and where. '
# 将列名重命名为_num
sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
if where:
sql.append('where')
sql.append(where)
# 限制结果数为1
rs = await select(' '.join(sql), args, 1)
if len(rs) == 0:
return None
return rs[0]['_num_']
@classmethod
async def find(cls, pk):
' find object by primary key. '
rs = await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
if len(rs) == 0:
return None
return cls(**rs[0])
#返回一条记录,以dict的形式返回,因为cls的父类继承了dict类
async def save(self):
# 获取所有value
args = list(map(self.getValueOrDefault, self.__fields__))
args.append(self.getValueOrDefault(self.__primary_key__))
rows = await execute(self.__insert__, args)
if rows != 1:
logging.warn('failed to insert record: affected rows: %s' % rows)
async def update(self):
args = list(map(self.getValue, self.__fields__))
args.append(self.getValue(self.__primary_key__))
rows = await execute(self.__update__, args)
if rows != 1:
logging.warn('failed to update by primary key: affected rows: %s' % rows)
async def remove(self):
args = [self.getValue(self.__primary_key__)]
rows = await execute(self.__delete__, args)
if rows != 1:
logging.warn('failed to remove by primary key: affected rows: %s' % rows)
最大的收获是理解过程中查找了很多资料和大神对相关知识的讲解,与君共飨之!
廖雪峰python教程:Day 3 - 编写ORM
廖雪峰python教程:使用元类
关于metaclass,我原以为我是懂的
oop - What are metaclasses in Python? - Stack Overflow(专指e-satis的回答)
深刻理解Python中的元类(metaclass)(上述回答的翻译)
python-进阶-元类在ORM上的应用详解
python-学习-ORM中遇到的 mapping 详解并再总结字典dict
两句话掌握 Python 最难知识点——元类(固然是标题党!但也是很有意思的联想。)
理解 Python super(看似和orm无关,其实super()是全篇下来我最困惑的地方)
MySQL
aiomysql
Things to Know About Python Super
廖雪峰python教程实战 Day 3 - 编写ORM
Python3廖雪峰实战项目:重难点ORM
Python廖雪峰实战web开发(Day3-编写ORM)
MySQL Connector/Python Developer Guide
SQL 教程
python sql cursor用法
Python 数据库的Connection、Cursor两大对象
记录 廖雪峰老师 实战 学习到 Day10的bug 以及解决方案
讨论-RequestHandler