廖雪峰python教程day3-编写ORM

本文深入浅出地介绍了ORM的工作原理及其实现细节,包括Field模块、元类MetaClass和基类Model的应用。通过实战演示了如何使用Python进行数据库操作,如创建、查询、更新和删除等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

首先要明确:
ORM的编写较为复杂,但编写完成后使用接口进行调用则显得非常简单。并且ORM编写模式基本为
-Field模块
-元类MetaClass
-基类Model

有着较为固定的写法,没必要重复造轮子,能复用尽量复用。重要的是要理解元类这块硬骨头的妙用。
基本思路:
(收集数据;对这些数据进行分类,识别(相对应数据库),生成SQL语句;最后,连接数据库,并执行SQL语句进行操作。)

  • User类负责收集数据,并尝试归类出这些数据对应数据库表的映射关系,类如对应表的字段(包含名字、类型、是否为表的主键、默认值)等;
  • 它的基类负责执行操作,比如数据库的存储、读取,查找等操作;
  • 它的元类负责分类、整理收集的数据并以此创建一些类属性(如SQL语句)供基类作为参数。

  1. 在一个Web App中,所有的数据,包括用户信息,用户发布的日志,评论都放在数据库中,本次实战使用MySQL作为数据库。
  2. Web App中,有许多地方都要用到数据库,访问数据要创建数据库连接,创建游标对象,执行SQL语句,然后要处理异常,清理资源等。
  3. 首先,要封装数据库的SELECT,INSERT,UPDATE,DELETE语句
  4. 其次,由于Web框架使用了基于asyncio的aiohttp,这是基于协程的异步模型。Web App框架采用异步IO编程,aiomysql为MySQL数据库提供了异步IO的驱动。
    一旦决定使用异步,则系统每一层都必须是异步,“开弓没有回头箭”。
    一步异步,步步异步
    期待代码
# 创建实例:
user = User(id=123, name='Michael')
# 存入数据库:
user.insert()
# 查询所有User对象:
users = User.findAll()

完整代码

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio, logging

import aiomysql


def log(sql, args=()):
    logging.info('SQL: %s' % sql)

async def create_pool(loop, **kw):
    logging.info('create database connection pool...')
    global __pool
    __pool = await aiomysql.create_pool(
        host=kw.get('host', 'localhost'),
        port=kw.get('port', 3306),
        user=kw['user'],
        password=kw['password'],
        db=kw['db'],
        charset=kw.get('charset', 'utf8'),
        autocommit=kw.get('autocommit', True),
        maxsize=kw.get('maxsize', 10),
        minsize=kw.get('minsize', 1),
        loop=loop
    )
#单独封装select,其他insert,update,delete一并封装,理由如下:
#使用Cursor对象执行insert,update,delete语句时,执行结果由rowcount返回影响的行数,就可以拿到执行结果。
#使用Cursor对象执行select语句时,通过featchall()可以拿到结果集。结果集是一个list,每个元素都是一个tuple,对应一行记录。
async def select(sql, args, size=None):
    log(sql, args)
    global __pool
    async with __pool.get() as conn:#打开pool的方法:with await __pool as conn: 
        # 创建一个结果为字典的游标
        async with conn.cursor(aiomysql.DictCursor) as cur:
            # 执行sql语句,将sql语句中的'?'替换成'%s'
            await cur.execute(sql.replace('?', '%s'), args or ())
            # 如果指定了数量,就返回指定数量的记录,如果没有,就返回所有记录
            if size:
                rs = await cur.fetchmany(size)
            else:
                rs = await cur.fetchall()
        logging.info('rows returned: %s' % len(rs))
        return rs#返回的结果集

async def execute(sql, args, autocommit=True):
    log(sql)
    async with __pool.get() as conn:
        if not autocommit:
            await conn.begin()
        try:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql.replace('?', '%s'), args)
                # 获取操作的记录数
                affected = cur.rowcount
            if not autocommit:
                await conn.commit()
        except BaseException as e:
            if not autocommit:
                await conn.rollback()#数据回滚
            raise
        return affected
		
#用于输出元类中创建sql_insert语句中的占位符,计算需要拼接多少个占位符
def create_args_string(num):
    L = []
    for n in range(num):
        L.append('?')
    return ', '.join(L)

class Field(object):

    def __init__(self, name, column_type, primary_key, default):
        self.name = name
        self.column_type = column_type
        self.primary_key = primary_key
        self.default = default

    def __str__(self):
        return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)

class StringField(Field):

    def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
        super().__init__(name, ddl, primary_key, default)

class BooleanField(Field):

    def __init__(self, name=None, default=False):
        super().__init__(name, 'boolean', False, default)

class IntegerField(Field):

    def __init__(self, name=None, primary_key=False, default=0):
        super().__init__(name, 'bigint', primary_key, default)

class FloatField(Field):

    def __init__(self, name=None, primary_key=False, default=0.0):
        super().__init__(name, 'real', primary_key, default)

class TextField(Field):

    def __init__(self, name=None, default=None):
        super().__init__(name, 'text', False, default)

        
#定义Model的metaclass元类
#所有的元类都继承自type
#ModelMetaclass元类定义了所有Model基类(继承ModelMetaclass)的子类实现的操作

# -*-ModelMetaclass:为一个数据库表映射成一个封装的类做准备
# 读取具体子类(eg:user)的映射信息
#创造类的时候,排除对Model类的修改
#在当前类中查找所有的类属性(attrs),如果找到Field属性,就保存在__mappings__的dict里,
#同时从类属性中删除Field(防止实例属性覆盖类的同名属性)
#__table__保存数据库表名
class ModelMetaclass(type):
    # 调用__init__方法前会调用__new__方法
    def __new__(cls, name, bases, attrs):
    # cls:当前准备创建的类的对象,name:类的名称,
    # bases:类继承的父类集合,attrs:类的方法集合
        if name=='Model':
            return type.__new__(cls, name, bases, attrs)
        # 获取table名称,如果未设置,tableName就是类的名字
        tableName = attrs.get('__table__', None) or name
        logging.info('found model: %s (table: %s)' % (name, tableName))
        # 获取所有的Field(类属性)和主键名:
		mappings = dict() #保存映射关系
        fields = [] #保存除主键外的属性
        primaryKey = None
        # key是列名,value是field的子类
        for k, v in attrs.items():
            if isinstance(v, Field):
                logging.info('  found mapping: %s ==> %s' % (k, v))
                mappings[k] = v
                if v.primary_key:
                    # 找到主键:
                    if primaryKey:
                        raise StandardError('Duplicate primary key for field: %s' % k)
                    primaryKey = k#此列设为列表的主键
                else:
                    #非主键,一律放在fields  
                    fields.append(k)
        if not primaryKey:# 如果遍历了所有属性都没有找到主键,则主键没定义
            raise StandardError('Primary key not found.')
        #从类属性中删除Field属性
        for k in mappings.keys():
            attrs.pop(k)#从类属性中删除Field属性,否则,容易造成运行时错误(实例的属性会遮盖类的同名属性)
        
        # 保存非主键属性为字符串列表形式
        # 将非主键属性变成`id`,`name`这种形式(带反引号)
        # repr函数和反引号的功能一致:取得对象的规范字符串表示
        # 将fields中属性名以`属性名`的方式装饰起来
        escaped_fields = list(map(lambda f: '`%s`' % f, fields))
        attrs['__mappings__'] = mappings # 保存属性和列的映射关系
        attrs['__table__'] = tableName# 保存表名
        attrs['__primary_key__'] = primaryKey # 主键属性名
        attrs['__fields__'] = fields # 除主键外的属性名
        
        # 构造默认的SELECT, INSERT, UPDATE和DELETE语句
        attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
        attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
        attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
        attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
        return type.__new__(cls, name, bases, attrs)


#定义ORM所有映射的基类:Model
#Model类的任意子类可以映射一个数据库表
#Model类可以看做是对所有数据库表操作的基本定义的映射
#基于字典查询形式
#Model从dict继承,拥有字典的所有功能,同时实现特殊方法__getattr__和__setattr__,能够实现属性操作
#实现数据库操作的所有方法,定义为class方法,所有继承自Model都具有数据库操作方法
        
        
class Model(dict, metaclass=ModelMetaclass):
    #也可在此处写  __metaclass__ = ModelMetaclass,与参数处效果相同

    def __init__(self, **kw):
        super(Model, self).__init__(**kw)

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(r"'Model' object has no attribute '%s'" % key)

    def __setattr__(self, key, value):
        self[key] = value

    def getValue(self, key):
		# 返回对象的属性,如果没有对应属性,则会调用__getattr__
		#直接调回内置函数,注意这里没有下划符,注意这里None的用处,是为了当user没有赋值数据时,返回None,调用于update
        return getattr(self, key, None)

    def getValueOrDefault(self, key):
        value = getattr(self, key, None)
        if value is None:
            field = self.__mappings__[key]
            if field.default is not None:
                value = field.default() if callable(field.default) else field.default
                logging.debug('using default value for %s: %s' % (key, str(value)))
				# 将默认值设置进行
                setattr(self, key, value)
        return value
	 # 类方法第一个参数为cls,而实例方法第一个参数为self
    @classmethod
	    #这里可以使用User.findAll()是因为:用@classmethod修饰了Model类里面的findAll()
        #一般来说,要使用某个类的方法,需要先实例化一个对象再调用方法
        #而使用@staticmethod或@classmethod,就可以不需要实例化,直接类名.方法名()来调用
		#申明是类方法:有类变量cls传入,cls可以做一些相关的处理
		#有子类继承时,调用该方法,传入的类变量cls是子类,而非父类
    async def findAll(cls, where=None, args=None, **kw):
        ' find objects by where clause. '
        sql = [cls.__select__]
        # 如果where查询条件存在
        if where:
            sql.append('where')# 添加where关键字
            sql.append(where)  # 拼接where查询条件
        if args is None:
            args = []
        orderBy = kw.get('orderBy', None)# 获取kw里面的orderby查询条件
        if orderBy:
            sql.append('order by')# 拼接orderBy字符串
            sql.append(orderBy)# 拼接orderBy查询条件
        limit = kw.get('limit', None)# 获取limit查询条件
        if limit is not None:
            sql.append('limit')
            if isinstance(limit, int):# 如果limit是int类型
                sql.append('?')# sql拼接一个占位符
                args.append(limit)# 将limit添加进参数列表,之所以添加参数列表之后再进行整合是为了防止sql注入
            elif isinstance(limit, tuple) and len(limit) == 2:# 如果limit是一个tuple类型并且长度是2
                sql.append('?, ?')# sql语句拼接两个占位符
                args.extend(limit)# 将limit添加进参数列表
            else:
                raise ValueError('Invalid limit value: %s' % str(limit))
        rs = await select(' '.join(sql), args) # 将args参数列表注入sql语句之后,传递给select函数进行查询并返回查询结果
        return [cls(**r) for r in rs]

    @classmethod
    #查询某个字段的数量
    async def findNumber(cls, selectField, where=None, args=None):
        ' find number by select and where. '
		 # 将列名重命名为_num
        sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
        if where:
            sql.append('where')
            sql.append(where)
			# 限制结果数为1
        rs = await select(' '.join(sql), args, 1)
        if len(rs) == 0:
            return None
        return rs[0]['_num_']

    @classmethod
    
    async def find(cls, pk):
        ' find object by primary key. '
        rs = await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
        if len(rs) == 0:
            return None
        return cls(**rs[0])
		#返回一条记录,以dict的形式返回,因为cls的父类继承了dict类


    async def save(self):
		# 获取所有value
        args = list(map(self.getValueOrDefault, self.__fields__))
        args.append(self.getValueOrDefault(self.__primary_key__))
        rows = await execute(self.__insert__, args)
        if rows != 1:
            logging.warn('failed to insert record: affected rows: %s' % rows)

    async def update(self):
        args = list(map(self.getValue, self.__fields__))
        args.append(self.getValue(self.__primary_key__))
        rows = await execute(self.__update__, args)
        if rows != 1:
            logging.warn('failed to update by primary key: affected rows: %s' % rows)

    async def remove(self):
        args = [self.getValue(self.__primary_key__)]
        rows = await execute(self.__delete__, args)
        if rows != 1:
            logging.warn('failed to remove by primary key: affected rows: %s' % rows)

最大的收获是理解过程中查找了很多资料和大神对相关知识的讲解,与君共飨之!
廖雪峰python教程:Day 3 - 编写ORM
廖雪峰python教程:使用元类
关于metaclass,我原以为我是懂的
oop - What are metaclasses in Python? - Stack Overflow(专指e-satis的回答)
深刻理解Python中的元类(metaclass)(上述回答的翻译)
python-进阶-元类在ORM上的应用详解
python-学习-ORM中遇到的 mapping 详解并再总结字典dict
两句话掌握 Python 最难知识点——元类(固然是标题党!但也是很有意思的联想。)
理解 Python super(看似和orm无关,其实super()是全篇下来我最困惑的地方)
MySQL
aiomysql
Things to Know About Python Super
廖雪峰python教程实战 Day 3 - 编写ORM
Python3廖雪峰实战项目:重难点ORM
Python廖雪峰实战web开发(Day3-编写ORM)
MySQL Connector/Python Developer Guide
SQL 教程
python sql cursor用法
Python 数据库的Connection、Cursor两大对象
记录 廖雪峰老师 实战 学习到 Day10的bug 以及解决方案
讨论-RequestHandler

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值