import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions (used here for timesteps).

    Maps a 1-D tensor ``x`` of shape ``(B,)`` to a ``(B, dim)`` tensor whose
    first ``dim // 2`` columns are ``sin(x * f_k)`` and next ``dim // 2``
    columns are ``cos(x * f_k)``, with frequencies
    ``f_k = 10000 ** (-k / (dim // 2 - 1))`` for ``k = 0 .. dim // 2 - 1``.

    Args:
        dim: Output embedding width. An odd width is zero-padded by one
            trailing column so the output always has exactly ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (shape ``(B,)``) into a ``(B, dim)`` tensor."""
        half_dim = self.dim // 2
        # Guard the denominator: dim in (2, 3) gives half_dim == 1, which
        # previously divided by zero.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        args = x[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * half_dim columns; pad the last one so the
            # width matches layers sized with ``dim`` (e.g. nn.Linear(dim, ...)).
            emb = F.pad(emb, (0, 1))
        return emb
class SelfAttention(nn.Module):
    """Spatial self-attention over all H*W positions with a gated residual.

    Query/key are 1x1 projections to a reduced width and value is a 1x1
    projection at full width. Every position attends to every other position
    of the same map. The attended features are scaled by the learnable
    scalar ``gamma`` (initialised to 0, so the block starts as an identity)
    and added back to the input.

    Note: the attention matrix is ``(B, H*W, H*W)``, so memory grows
    quadratically with the number of spatial positions.

    Args:
        channels: Number of input/output channels ``C``.
        reduction: Channel reduction factor for the query/key projections.
            The projected width is ``max(1, channels // reduction)``, so small
            ``channels`` never produce a zero-width (degenerate) projection.
    """

    def __init__(self, channels, reduction=8):
        super(SelfAttention, self).__init__()
        self.channels = channels
        inner = max(1, channels // reduction)
        self.query = nn.Conv2d(channels, inner, kernel_size=1)
        self.key = nn.Conv2d(channels, inner, kernel_size=1)
        self.value = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attend(x) + x`` for ``x`` of shape ``(B, C, H, W)``."""
        B, C, H, W = x.size()
        proj_query = self.query(x).view(B, -1, H * W).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key(x).view(B, -1, H * W)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # each row sums to 1 over key positions
        proj_value = self.value(x).view(B, -1, H * W)  # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B x C x N
        out = out.view(B, C, H, W)
        return self.gamma * out + x
class single_conv(nn.Module):
    """Size-preserving 5x5 conv + ReLU with a residual shortcut, then ReLU.

    The shortcut is a 1x1 convolution when ``in_ch != out_ch`` and an
    identity otherwise, so ``forward`` computes
    ``relu(relu(conv5x5(x)) + shortcut(x))``. Padding 2 on the 5x5 kernel
    keeps the spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 5, padding=2),
            nn.ReLU(inplace=True),
        )
        # Project the shortcut only when the channel count changes.
        self.residual = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        shortcut = self.residual(x)
        main = self.conv(x)
        return F.relu(main + shortcut)
class up(nn.Module):
    """Decoder step: 2x transposed-conv upsample, skip concat, 3x3 conv + ReLU.

    ``x1`` (``in_ch`` channels) is upsampled to ``in_ch // 2`` channels and
    padded to the spatial size of the skip tensor ``x2``. The two tensors are
    concatenated along channels and fused to ``in_ch // 2`` channels. The
    fuse conv takes ``in_ch`` inputs, so ``x2`` must carry
    ``in_ch - in_ch // 2`` channels.
    """

    def __init__(self, in_ch):
        super(up, self).__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        self.conv_concat = nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Spatial mismatch between the skip tensor and the upsampled one
        # (NCHW layout); split it as evenly as possible between the two sides.
        dh = x2.shape[2] - upsampled.shape[2]
        dw = x2.shape[3] - upsampled.shape[3]
        top, left = dh // 2, dw // 2
        upsampled = F.pad(upsampled, (left, dw - left, top, dh - top))
        fused = self.conv_concat(torch.cat((upsampled, x2), dim=1))
        return self.relu(fused)
# diff
class down(nn.Module):
    """Multi-dilation stride-2 downsampling block.

    Runs one 3x3, stride-2 convolution (+ ReLU) per entry of
    ``dilation_rates`` in parallel. The branch outputs are concatenated along
    channels and fused back to ``in_channels`` with a 3x3 convolution + ReLU.
    Setting ``padding == dilation`` gives every branch the same output size,
    ``floor((H - 1) / 2) + 1``, so the branches can be concatenated.

    Args:
        in_channels: Channels of both the input and the output.
        dilation_rates: Dilations of the parallel branches (non-empty). The
            fuse conv is sized from ``len(dilation_rates)``, so any number of
            branches works.

    Raises:
        ValueError: If ``dilation_rates`` is empty.
    """

    def __init__(self, in_channels, dilation_rates=(1, 2, 3)):
        super(down, self).__init__()
        dilation_rates = tuple(dilation_rates)
        if not dilation_rates:
            raise ValueError("dilation_rates must contain at least one rate")
        self.conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2,
                          padding=d, dilation=d, bias=False),
                nn.ReLU(inplace=True),
            )
            for d in dilation_rates
        ])
        # Was hard-coded to `in_channels * 3`, which broke any other number
        # of dilation branches with a channel mismatch.
        self.conv_2 = nn.Conv2d(in_channels * len(dilation_rates), in_channels, 3, 1, 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        outputs = [conv(x) for conv in self.conv]
        out = torch.cat(outputs, dim=1)
        out = self.conv_2(out)
        return self.relu(out)
class outconv(nn.Module):
    """Final 1x1 convolution projecting ``in_ch`` feature maps to ``out_ch``."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
class adjust_net(nn.Module):
    """Predict per-channel ``(gamma, beta)`` modulation from a conditioning input.

    The network has three stages of conv 3x3 + ReLU + 2x average pool, with
    widths ``middle_channels``, ``2 * middle_channels`` and
    ``4 * middle_channels``. A final 1x1 conv produces ``2 * out_channels``
    channels. The result is globally average-pooled to ``1x1`` and split in
    half along channels. Because of the three 2x poolings, the input's H and
    W must be at least 8.

    Args:
        out_channels: Channels of each returned tensor.
        middle_channels: Base width of the conv stack.
        in_channels: Channels of the conditioning input. Defaults to 2, the
            value that was previously hard-coded.

    Returns:
        ``(gamma, beta)``, each of shape ``(B, out_channels, 1, 1)``.
    """

    def __init__(self, out_channels=64, middle_channels=32, in_channels=2):
        super(adjust_net, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(2),
            nn.Conv2d(middle_channels, middle_channels * 2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(2),
            nn.Conv2d(middle_channels * 2, middle_channels * 4, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(2),
            nn.Conv2d(middle_channels * 4, out_channels * 2, 1, padding=0),
        )

    def forward(self, x):
        out = F.adaptive_avg_pool2d(self.model(x), (1, 1))
        half = out.shape[1] // 2
        # First half of the channels is gamma, second half is beta.
        return out[:, :half], out[:, half:]
class UNet(nn.Module):
    """Time-conditioned U-Net with self-attention and optional modulation.

    Encoder: ``inc`` (full resolution, 64 ch) -> ``down1`` (1/2 res) ->
    ``conv1`` (128 ch) -> ``down2`` (1/4 res) -> ``conv2`` (256 ch).
    Decoder: ``up1`` (skip from ``conv1``, 128 ch, 1/2 res) -> ``conv3`` ->
    ``up2`` (skip from ``inc``, 64 ch, full res) -> ``conv4`` -> ``outc``.

    After each of the four resolution changes, the features pass through a
    ``SelfAttention`` and are then shifted by a per-stage projection of the
    time embedding. When ``adjust`` is true, that shift is additionally
    modulated by ``(gamma, beta)`` predicted from ``x_adjust``.

    Args:
        in_channels: Channels of the main input ``x``.
        out_channels: Channels of the output.
    """

    def __init__(self, in_channels=2, out_channels=1):
        super(UNet, self).__init__()
        dim = 32  # width of the time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

        # Encoder, full resolution.
        self.inc = nn.Sequential(
            single_conv(in_channels, 64),
            single_conv(64, 64),
        )

        # Encoder, 1/2 resolution.
        self.down1 = down(64)
        self.attn1 = SelfAttention(64)
        self.mlp1 = nn.Sequential(nn.GELU(), nn.Linear(dim, 64))
        self.adjust1 = adjust_net(64)
        self.conv1 = nn.Sequential(
            single_conv(64, 128),
            single_conv(128, 128),
            single_conv(128, 128),
        )

        # Bottleneck, 1/4 resolution.
        self.down2 = down(128)
        self.attn2 = SelfAttention(128)
        self.mlp2 = nn.Sequential(nn.GELU(), nn.Linear(dim, 128))
        self.adjust2 = adjust_net(128)
        self.conv2 = nn.Sequential(
            single_conv(128, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
        )

        # Decoder, back to 1/2 resolution (skip connection from conv1).
        self.up1 = up(256)
        self.attn3 = SelfAttention(128)
        self.mlp3 = nn.Sequential(nn.GELU(), nn.Linear(dim, 128))
        self.adjust3 = adjust_net(128)
        self.conv3 = nn.Sequential(
            single_conv(128, 128),
            single_conv(128, 128),
            single_conv(128, 128),
        )

        # Decoder, back to full resolution (skip connection from inc).
        self.up2 = up(128)
        self.attn4 = SelfAttention(64)
        self.mlp4 = nn.Sequential(nn.GELU(), nn.Linear(dim, 64))
        self.adjust4 = adjust_net(64)
        self.conv4 = nn.Sequential(
            single_conv(64, 64),
            single_conv(64, 64),
        )

        self.outc = outconv(64, out_channels)

    def _inject_condition(self, feat, time_emb, mlp, adjust_module, x_adjust, adjust):
        """Add the time embedding, projected by ``mlp``, to ``feat``.

        Without adjustment this returns ``feat + cond``. With adjustment it
        returns ``feat + gamma * cond + beta``, where
        ``(gamma, beta) = adjust_module(x_adjust)``. ``cond`` is reshaped to
        ``(B, C, 1, 1)`` so it broadcasts over the spatial dimensions.
        """
        cond = rearrange(mlp(time_emb), 'b c -> b c 1 1')
        if adjust:
            gamma, beta = adjust_module(x_adjust)
            return feat + gamma * cond + beta
        return feat + cond

    def forward(self, x, t, x_adjust, adjust):
        """Run the U-Net.

        Args:
            x: Main input of shape ``(B, in_channels, H, W)``.
            t: 1-D tensor of timesteps, shape ``(B,)``.
            x_adjust: Input to the ``adjust_net`` modules. Their first conv
                expects 2 channels. It is only used when ``adjust`` is true.
            adjust: Whether to apply the ``(gamma, beta)`` modulation.

        Returns:
            Tensor of shape ``(B, out_channels, H, W)``.
        """
        inx = self.inc(x)
        time_emb = self.time_mlp(t)

        down1 = self.attn1(self.down1(inx))
        down1 = self._inject_condition(down1, time_emb, self.mlp1, self.adjust1, x_adjust, adjust)
        conv1 = self.conv1(down1)

        down2 = self.attn2(self.down2(conv1))
        down2 = self._inject_condition(down2, time_emb, self.mlp2, self.adjust2, x_adjust, adjust)
        conv2 = self.conv2(down2)

        up1 = self.attn3(self.up1(conv2, conv1))
        up1 = self._inject_condition(up1, time_emb, self.mlp3, self.adjust3, x_adjust, adjust)
        conv3 = self.conv3(up1)

        up2 = self.attn4(self.up2(conv3, inx))
        up2 = self._inject_condition(up2, time_emb, self.mlp4, self.adjust4, x_adjust, adjust)
        conv4 = self.conv4(up2)

        return self.outc(conv4)
class Network(nn.Module):
    """Residual wrapper around ``UNet``.

    The UNet output is added to a reference tensor. That reference is
    channel 1 of ``x`` (kept as a single-channel dimension) when ``context``
    is true, and the whole of ``x`` otherwise. The UNet's adjustment input is
    ``y`` and ``x_end`` concatenated along dim 1.
    """

    def __init__(self, in_channels=3, out_channels=1, context=True):
        super(Network, self).__init__()
        self.unet = UNet(in_channels=in_channels, out_channels=out_channels)
        self.context = context

    def forward(self, x, t, y, x_end, adjust=True):
        base = x[:, 1:2] if self.context else x
        conditioning = torch.cat((y, x_end), dim=1)
        residual = self.unet(x, t, conditioning, adjust=adjust)
        return residual + base
# WeightNet of the one-shot learning framework
class WeightNet(nn.Module):
    """Learned convex combination of ``weight_num`` input channels.

    Holds one logit per channel. All logits start equal, so the initial mix
    is a uniform average. ``forward`` softmax-normalises the logits over
    dim 1 and returns the weighted channel sum together with the weights.
    """

    def __init__(self, weight_num=10):
        super(WeightNet, self).__init__()
        self.weights = nn.Parameter(torch.ones(1, weight_num, 1, 1) / weight_num)

    def forward(self, x):
        mix = torch.softmax(self.weights, dim=1)
        combined = (mix * x).sum(dim=1, keepdim=True)
        return combined, mix
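# ---------------------------------------------------------------------------
# Residue from the web page this code was copied from. It is not part of the
# program and is kept as a comment (translated to English) so the module
# stays importable:
#   Title: corediff2
#   "Latest recommended article published 2025-12-02 20:45:02"
#   "Deploy and run a model image you are interested in"
#   "Images related to this article that you may be interested in:"
#     PyTorch 2.5 (tags: PyTorch, Cuda) -- "PyTorch is an open-source Python
#     machine-learning library based on the Torch library, implemented
#     underneath in C++, and used in AI fields such as computer vision and
#     natural language processing."
#   2215
#   "Comments have been collapsed. Why were they collapsed?"
# ---------------------------------------------------------------------------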



