1. 训练和采样流程

2. 无条件实现
import torch, time, os
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F
class ResidualConvBlock(nn.Module):
    """Standard ResNet-style convolutional block.

    The block is two stacked stages of ``Conv2d -> BatchNorm2d -> GELU``.
    Both convolutions use kernel size 3, stride 1 and padding 1, so the
    spatial size of the input is preserved and only the channel count changes
    from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by both convolution stages.
        is_res: If True, ``forward`` adds a skip connection and rescales the
            sum; if False, the block is a plain two-stage convolution stack.
    """

    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        # The raw input can only be added to the output when the channel
        # counts match; forward() uses this flag to pick the skip source.
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both convolution stages, optionally with a residual sum.

        Args:
            x: Input feature map with ``in_channels`` channels on dim 1.

        Returns:
            The output of the second stage when ``is_res`` is False.
            Otherwise ``(skip + second_stage) / 1.414``, where ``skip`` is
            ``x`` if the channel counts match and the first-stage output if
            they do not.
        """
        first = self.conv1(x)
        second = self.conv2(first)
        if not self.is_res:
            return second

        # When the channel count changed, x cannot be added directly, so the
        # first-stage output (already at out_channels) serves as the skip.
        skip = x if self.same_channels else first
        # 1.414 approximates sqrt(2); presumably it keeps the scale of the
        # two-term sum comparable to a single branch. The literal is kept
        # unchanged so numerical outputs stay identical.
        return (skip + second) / 1.414
class UnetDown(nn.Module):
    """Process and downscale the image feature maps.

    Runs a non-residual ``ResidualConvBlock`` that maps ``in_channels`` to
    ``out_channels`` and then applies ``MaxPool2d(2)``.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count after the convolution block.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.model = nn.Sequential(
            ResidualConvBlock(in_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the convolved and max-pooled feature map for ``x``."""
        return self.model(x)
class UnetUp(nn.Module):
    """Process and upscale the image feature maps.

    ``forward`` concatenates the incoming tensor with a skip tensor along
    dim 1, so ``in_channels`` must equal the combined channel count of the
    two. The result goes through a ``ConvTranspose2d`` with kernel size 2 and
    stride 2, followed by two non-residual ``ResidualConvBlock`` layers.

    Args:
        in_channels: Channel count of the concatenated ``(x, skip)`` tensor.
        out_channels: Channel count of the upsampled output.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Layers are created in this exact order so that parameter
        # initialisation under a fixed seed is unchanged.
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Concatenate ``x`` with ``skip`` on dim 1 and upsample the result.

        Args:
            x: Feature map coming from the previous stage.
            skip: Feature map to concatenate with ``x``; every dimension
                except dim 1 must match for ``torch.cat`` to succeed.

        Returns:
            The upsampled feature map with ``out_channels`` channels.
        """
        merged = torch.cat((x, skip), 1)
        return self.model(merged)
class EmbedFC(nn.Module):
    """Generic small fully connected network for embedding things.

    The network is ``Linear(input_dim, emb_dim) -> GELU ->
    Linear(emb_dim, emb_dim)``. Any input is first flattened to shape
    ``(-1, input_dim)``.

    Args:
        input_dim: Size of the last dimension after flattening.
        emb_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim: int, emb_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` to ``(-1, input_dim)`` and embed it.

        Args:
            x: Tensor whose element count is a multiple of ``input_dim``.

        Returns:
            Tensor of shape ``(x.numel() // input_dim, emb_dim)``.
        """
        # reshape() returns the same view as view() for contiguous input but
        # also accepts non-contiguous tensors (e.g. transposed or sliced),
        # on which view() raises a RuntimeError.
        flat = x.reshape(-1, self.input_dim)
        return self.model(flat)
class Unet(nn.Module):
def __init__(self, in_channels, n_feat=256):
    """Build the U-Net submodules.

    Args:
        in_channels: Channel count of the input image; the final
            convolution maps back to this many channels.
        n_feat: Base feature width. Must be divisible by 8, because the
            GroupNorm layers below use 8 groups.
    """
    super().__init__()
    self.in_channels = in_channels
    self.n_feat = n_feat

    # Submodules are created in the original order so that parameter
    # initialisation under a fixed seed and state_dict keys are unchanged.
    self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

    # Two downsampling stages: n_feat -> n_feat -> 2 * n_feat channels.
    self.down1 = UnetDown(n_feat, n_feat)
    self.down2 = UnetDown(n_feat, 2 * n_feat)

    # Pool with a 7x7 window, then apply GELU.
    # NOTE(review): presumably the map is exactly 7x7 here (e.g. a 28x28
    # input halved twice), which would reduce it to 1x1 - confirm.
    self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

    # Embeddings of a single scalar, one per decoder width.
    # NOTE(review): presumably the scalar is the diffusion timestep -
    # confirm against forward().
    self.timeembed1 = EmbedFC(1, 2 * n_feat)
    self.timeembed2 = EmbedFC(1, 1 * n_feat)

    # Transposed conv with kernel 7 / stride 7 mirrors the AvgPool2d(7)
    # above. This variant takes 2 * n_feat input channels; a variant that
    # concatenates time and context embeddings would need 6 * n_feat.
    self.up0 = nn.Sequential(
        nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7),
        nn.GroupNorm(8, 2 * n_feat),
        nn.ReLU(),
    )

    # UnetUp concatenates its two inputs on the channel dim, hence the
    # doubled input widths.
    # NOTE(review): the skip sources are chosen in forward() - confirm.
    self.up1 = UnetUp(4 * n_feat, n_feat)
    self.up2 = UnetUp(2 * n_feat, n_feat)

    # Output head: 2 * n_feat -> n_feat -> in_channels, spatial size kept
    # by the 3x3 / stride 1 / padding 1 convolutions.
    self.out = nn.Sequential(
        nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
        nn.GroupNorm(8, n_feat),
        nn.ReLU(),
        nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
    )
def forward(self, x, t):
'''
输入加噪图像和对应的时间step,预测反向噪声的正态分布
: