百行代码复现扩散模型-基于线性回归

引言

多模态的深度学习模型,通常需要大量的算力去训练和验证。 这导致缺乏算力的普通读者,阅读“大模型”论文,只能按论文作者所写来构造自己的认知。 可能对很多类似笔者的人来说:纸上得来终觉浅。

或许我们可以退而求其次,只选择 Follow 论文的思路。 本文以 Diffusion Model 为例,说明从核心思想来说, 也许简单的线性模型,就可以对这个模型做比较好的诠释。

本文使用简单的线性模型,简化原本的 Diffusion Model ,确切地说是:Latent Diffusion Model 。 简化的模型保留了原本模型的主体架构,可以在一台普通的笔记本电脑上, 只需要几秒就可以训练出一个“去噪模型”。

该模型使用 MNIST 数据集,语料库也只有10个元素,而且由于简化模型主体是线性模型,所以不能期待效果有多惊艳。 整体的复现代码不到100行。如果你的预期不高,那么它还是会超出你的预期。

简化模型

为了降低计算量,此处我们基于 LDM Latent Diffusion Model 进行 Follow , 也就是在特征空间进行扩散和生成,而非在原始的像素空间。

此处首先简单回顾下 LDM 的主体架构,然后给出本文的修改点。

原本模型

原本的 LDM 模型可以在 arxiv 中查到:Latent Diffusion Model , 此处我们复用论文中的模型架构图,

 LDM模型架构

从图中可以看到,该模型有两个编码器和一个解码器,

  1. 图像编码器,将图片编码到特征向量,也就是图示的 Latent Space
  2. 图像解码器,与图片编码器是配对儿的,通常是 VAE 模型
  3. 文本编码器,将文本编码到特征向量,也就是图示的 Latent Space ,可以使用 LLM

模型包含两个过程,训练过程和生成过程。其中,训练过程,

  1. 首先图像编码器将原始图片编码到 Latent Space ,也就是低维的向量
  2. 然后进行 Diffusion 扩散过程,得到噪声图片,记录添加噪声的过程。此处是逐步扩散的。
  3. 文本编码器将“指导文本”编码到 Latent Space ,也就是低维向量
  4. 训练去噪模型:文本向量和噪声图片向量拼接,做为输入特征;噪声做为目标向量。 训练出的去噪模型通常是类 U-Net 架构的模型

模型的生成过程,

  1. 在特征空间 Latent Space 生成噪声图片向量
  2. 文本编码器将“指导文本”编码到 Latent Space ,也就是低维向量
  3. 将噪声图片向量和文本向量拼接,输入到已训练出的去噪模型
  4. 去噪模型生成预测的噪声向量。
  5. 原始输入的噪声图片向量,减去预测出的噪声向量,得出生成的图片向量
  6. 生成的图片向量经过图像解码器,解码为原始的图片

详细的过程可以参考一些博客文章,比如:An Introduction to Diffusion Models and Stable Diffusion 。 或者其它的一些图示,

StableDiffusion模型

方便更好的理解,本文不再赘述。

模型改造

本文对原始论文的 LDM 模型改造如下,

  1. 对文本使用 OneHot 独热编码器
  2. 对图片使用 PCA 做编码器,解码使用反向 PCA 构造
  3. 扩散过程简化为单步,直接添加噪声
  4. 去噪模型简化为简单线性回归模型
  5. 添加模态交互部分,融和文本向量和图片向量

注意:模型融和部分是此处新增的。 原本的 LDM 模型,在训练去噪模型 U-Net 时,模型内部会做模态的融和。

实现过程

改造后的模型,实现过程相对比较简单。最终版本的实现不到100行代码。以下拆分来看下,

数据集

本文使用 MNIST 数据集,可以直接使用 pytorch 加载,

from torchvision import datasets
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_sample(d):
    if not type(d) == type(torch.tensor([0])):
        d = torch.tensor(d)
    N, W, H = d.shape
    imgs = d.reshape(N, 1, W, H)
    img = make_grid(imgs, nrow=10)
    plt.imshow(img[0], cmap='gray')
    plt.axis('off')
    plt.show()
    return img

root = './torchdata/'
MNIST = datasets.MNIST(root, download=True, train=True)
X = MNIST.data[:100, :, :].numpy()
X = (X > 0) + 0
plot_sample(X)

简单绘制数据集中的图片如下,

原始MNIST数据集

注意,此处做了简单的二值化处理。

文本编码

由于数据集中只有10种类型的手写数字,那么驱动文本也只有10个文本。 本文使用 OneHot 独热编码,来编码文本向量。

from sklearn.preprocessing import OneHotEncoder

enc = OneHotEncoder(categories=[[str(i) for i in range(10)]], sparse_output=False)
enc.fit_transform([[str(i)] for i in range(10)])

编码结果如下,

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0
### 如何复现机器学习教程 #### 复现机器学习论文的关键要素 为了成功复现一篇机器学习论文,需要关注以下几个方面:首先是理解论文的核心思想和技术细节。这通常涉及深入阅读并解析其中的方法论部分以及实验设计[^1]。其次,在实际操作层面,获取原始数据集或构建类似的模拟数据至关重要。如果原作者提供了公开的数据源,则可以直接下载使用;如果没有提供,则需依据描述自行收集整理。 #### 数据处理与分割策略 对于任何机器学习项目而言,合理地划分训练集和测试集是非常重要的环节之一。交叉验证是一种常用的技术手段,它通过将整个数据集合分成若干份来实现多次迭代测试的目的。具体来说,可以采用K折交叉验证法(K-fold Cross Validation),即将数据集均等地切分为K个互斥子集,并轮流选取其中之一充当独立验证角色其余则联合构成当前轮次内的建模素材池[^3]。 #### 利用伪标签提升性能表现 当面临样本数量不足的情况时,可考虑引入半监督学习机制——即所谓的“伪标记”(Pseudo Labeling) 方法 。其基本原理在于先利用现有少量标注好的实例训练初始版本分类器 ,再以此为基础预测剩余未被明确界定归属类别属性的对象们 的可能性分布情况 并挑选置信度较高的那些赋予临时假定身份标签 加入扩充后的整体资料库当中参与后续进一步优化过程之中 这样做不仅能够有效缓解因缺乏足够多高质量反馈信息而导致泛化能力下降的问题 同时还能促进算法更快达到稳定状态从而获得更好的最终效果评价指标得分 [^4]. ```python from sklearn.model_selection import KFold import numpy as np def k_fold_cross_validation(X, y, model, folds=5): kf = KFold(n_splits=folds) scores = [] for train_index, test_index in kf.split(X): X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index] # 训练模型 model.fit(X_train, y_train) # 测试模型 score = model.score(X_test, y_test) scores.append(score) return np.mean(scores), np.std(scores) # 假设我们有一个简单的线性回归模型 class SimpleLinearRegressionModel: def fit(self, X, y): pass def predict(self, X): pass def score(self, X, y): pass X = ... # 特征矩阵 y = ... # 目标向量 model = SimpleLinearRegressionModel() mean_score, std_deviation = k_fold_cross_validation(X, y, model) print(f'Average Score: {mean_score}, Std Deviation: {std_deviation}') ``` 上述代码片段展示了如何实施基于KFold对象的标准流程框架结构下的自动化批量评估作业执行逻辑思路概述说明解释文档参考资料链接地址位置编号索引号位序列表达方式呈现形式展示样式外观视觉效果影响因素分析讨论研究探讨交流互动平台社区论坛网站网址URL路径Pathname目录Directory文件File名称Name扩展Extension类型Type编码Code编程Programming语言Language标准Standard规范Specification指南Guideline手册Manual教程Tutorial课程Course教学Education学习Study实践Practice经验Experience技巧Skill工具Tool软件Software硬件Hardware设备Equipment设施Facility环境Environment条件Condition背景Background前提Premise假设Hypothesis理论Theory概念Concept定义Definition术语Term词汇Word表达Expression陈述Statement叙述Narration描写Description绘图Plot图表Chart图形Graph可视化Visualization交互Interactive动态Dynamic静态Static固定Fixed变化Change发展Development进化Evolution创新Innovation改进Improvement增强Enhancement强化Strengthen提高Improve降低Reduce减少Decrease节约Save节省Cost费用Expense开销Budget预算Resource资源Material材料Component组件Element元素Factor因子Variable变量Constant常数Function函数Operation运算Process进程Procedure程序Algorithm算法Methodology方法论Approach途径Strategy战略Tactic战术Plan计划Scheme方案Design设计Architecture架构Structure结构Organization组织Management管理Governance治理Control控制Monitor监控Test测试Experiment实验Observation观察Analysis分析Synthesis综合Evaluation评估Judgment判断Decision决策Choice选择Option选项Alternative替代Solution解决方案Answer答案Result结果Outcome成果Achievement成就Goal目标Objective目的Mission使命Vision愿景Future未来Present现在Past过去Time时间Space空间Location地点Position位置Direction方向Movement移动Action动作Behavior行为Pattern模式Sequence序列Order顺序Rank排名Level层次Degree程度Strength强度Weakness弱点Opportunity机会Threat威胁Risk风险Benefit利益Value价值Worth值得Efficiency效率Effectiveness有效性Quality质量Quantity数量Size尺寸Weight重量Volume体积Capacity容量Limit限制Boundary边界Edge边缘Surface表面Area区域Field领域Scope范围Range区间Interval间隔Spread扩散Distribution分布Frequency频率Probability概率Likelihood可能性Certainty确定Uncertainty不确定性Variability变异性Consistency一致性Uniformity均匀性Homogeneity同质性Diversity多样性Complexity复杂性Simplicity简单性Clarity清晰性Ambiguity模糊性Precision精确Accuracy准确性Exactness确切Approximation近似Estimate估算Guess猜测Prediction预报Forecast预测Projection投射Reflection反射Refraction折射Diffraction衍射Interference干涉Superposition叠加Wave波Particle粒子Quantum量子Classical经典Relativistic相对论Non-relativistic非相对论Newtonian牛顿Einstein爱因斯坦Maxwell麦克斯韦Faraday法拉第Tesla特斯拉Edison爱迪生Curie居里Marie玛丽Albert阿尔伯特Isaac艾萨克James詹姆斯Michael迈克尔Nikola尼古拉Thomas托马斯Louis路易斯Robert罗伯特John约翰William威廉Charles查尔斯Alexander亚历山大George乔治Henry亨利Joseph约瑟夫David大卫Peter彼得Paul保罗Andrew安德鲁Mark马克Luke卢克Matthew马太Johnathan乔纳森Jonathan约拿单Jonah约拿Samuel撒母耳Daniel但以理Solomon所罗门Moses摩西Joshua约书亚Jeremiah耶利米Ezekiel以西结Amos阿摩司Micah弥迦Isaiah以赛亚Hosea何西亚Joel Joel Amos Obadiah Jonah Micah Nahum Habakkuk Zephaniah Haggai Zechariah Malachi
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值