PyTorch U-Net + res_block(残差块)用于图像序列预测

该博客介绍了使用PyTorch实现的Residual U-Net网络结构,该模型结合了U-Net的上下采样路径和残差块。网络适用于图像分类和预测任务,从TensorFlow中移植而来。博主还探讨了TensorFlow中activation=linear是否等同于leakyrelu(negative_slope=1)的问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


def transfer(x):
    """Merge the two leading axes of a 5-D tensor into one.

    Args:
        x: Tensor with exactly five dimensions. The last three are carried
            over unchanged; the first two are folded together.

    Returns:
        A tensor of shape ``(-1, d2, d3, d4)`` containing the same elements,
        where ``d2..d4`` are the last three sizes of ``x`` and the leading
        size is inferred by ``reshape``.

    Raises:
        ValueError: If ``x`` does not have exactly five dimensions, because
            the shape unpacking fails.
    """
    # Only the trailing three sizes are needed; the leading two are collapsed.
    _, _, channels, rows, cols = x.shape
    return x.reshape(-1, channels, rows, cols)


class basic_block(nn.Module):
    """Residual block: two 3x3 convolutions plus a 1x1-convolution shortcut.

    The main path is Conv3x3 -> BatchNorm -> LeakyReLU -> Conv3x3 and the
    shortcut is a single 1x1 convolution. Every convolution uses stride 1
    with ``padding='same'``, so the spatial size of the input is preserved
    and only the channel count changes from ``ch_in`` to ``ch_out``.

    Args:
        ch_in: Number of channels of the input tensor.
        ch_out: Number of channels produced by both paths.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Main path; built first so parameter registration order is
        # conv.* followed by residual.*.
        main_layers = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding='same'),
            nn.BatchNorm2d(ch_out),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding='same'),
        ]
        self.conv = nn.Sequential(*main_layers)
        # Shortcut: 1x1 convolution that matches the channel count.
        shortcut_layers = [nn.Conv2d(ch_in, ch_out, kernel_size=1, padding='same')]
        self.residual = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """Return the element-wise sum of the main path and the shortcut."""
        main_out = self.conv(x)
        shortcut_out = self.residual(x)
        return main_out + shortcut_out


class down_block(nn.Module):
    """Down-sampling residual stage.

    Main path: BatchNorm -> LeakyReLU(0.1) -> 2x2 max-pool (stride 2) ->
    BatchNorm -> LeakyReLU(0.1) -> 3x3 convolution (``padding='same'``).
    Shortcut: a 1x1 convolution with stride 2.

    Args:
        ch_in: Number of channels of the input tensor.
        ch_out: Number of channels produced by both paths.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Sequential(
            nn.BatchNorm2d(ch_in),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(ch_in),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding='same'),
        )
        # NOTE(review): the pooled path and this strided shortcut appear to
        # yield different sizes when height/width is odd -- confirm that
        # callers always pass even spatial sizes.
        self.residual = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=2),
        )

    def forward(self, x):
        """Run both paths on ``x``.

        Returns:
            A pair ``(summed, pooled)``: ``pooled`` is the main-path output
            on its own and ``summed`` is ``pooled`` plus the shortcut output.
        """
        pooled = self.conv(x)
        summed = pooled + self.residual(x)
        return summed, pooled


class up_block(nn.Module):
    """Up-sampling residual stage that fuses two feature maps.

    ``forward`` concatenates its two inputs along the channel axis, so
    ``ch_in`` is the channel count *after* concatenation.

    Main path: bilinear x2 upsample (``align_corners=True``) -> BatchNorm ->
    LeakyReLU(0.1) -> 3x3 conv -> BatchNorm -> LeakyReLU(0.1) -> 3x3 conv.
    Shortcut: a 2x2 transposed convolution with stride 2.

    Args:
        ch_in: Number of channels of the concatenated input.
        ch_out: Number of channels produced by both paths.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        main_layers = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.BatchNorm2d(ch_in),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding='same'),
            nn.BatchNorm2d(ch_out),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, padding='same'),
        )
        self.conv = nn.Sequential(*main_layers)
        # Learned x2 upsampling used as the shortcut branch.
        self.residual = nn.Sequential(
            nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride=2),
        )

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on dim 1 and return main path + shortcut."""
        fused = torch.cat((x, y), dim=1)
        upsampled = self.conv(fused)
        return upsampled + self.residual(fused)


class res_U_Net(nn.Module):
    """Residual U-Net: a stem, seven down stages, a centre block, seven up stages.

    Each down stage returns a pair; the first element feeds the next stage
    and the second is kept as the skip tensor for the mirrored up stage.
    A final 3x3 convolution maps the 32 remaining channels to ``output_ch``.

    NOTE(review): with seven stride-2 stages the input height/width
    presumably needs to be divisible by 128 -- confirm against callers.

    Args:
        img_ch: Number of channels of the input tensor.
        output_ch: Number of channels of the output tensor.
    """

    def __init__(self, img_ch=5, output_ch=5):
        super().__init__()
        self.basicx1 = basic_block(img_ch, 32)

        # (attribute name, ch_in, ch_out); registered in this exact order.
        encoder_spec = (
            ("downx1", 32, 32),
            ("downx2", 32, 64),
            ("downx4", 64, 128),
            ("downx16", 128, 256),
            ("downx32", 256, 512),
            ("downx64", 512, 512),
            ("downx128", 512, 1024),
        )
        for attr, ch_in, ch_out in encoder_spec:
            setattr(self, attr, down_block(ch_in, ch_out))

        self.centerx128 = basic_block(1024, 1024)

        # ch_in of every up stage = channels of the previous output plus the
        # channels of the skip tensor it is concatenated with.
        decoder_spec = (
            ("upx64", 2048, 1024),
            ("upx32", 1536, 512),
            ("upx16", 1024, 512),
            ("upx8", 768, 256),
            ("upx4", 384, 128),
            ("upx2", 192, 64),
            ("upx1", 96, 32),
        )
        for attr, ch_in, ch_out in decoder_spec:
            setattr(self, attr, up_block(ch_in, ch_out))

        self.conv = nn.Conv2d(32, output_ch, kernel_size=3, padding='same')

    def forward(self, x):
        """Run the encoder, centre block and decoder and return the final conv output."""
        features = self.basicx1(x)

        encoder = (self.downx1, self.downx2, self.downx4, self.downx16,
                   self.downx32, self.downx64, self.downx128)
        skips = []
        for stage in encoder:
            features, skip = stage(features)
            skips.append(skip)

        features = self.centerx128(features)

        decoder = (self.upx64, self.upx32, self.upx16, self.upx8,
                   self.upx4, self.upx2, self.upx1)
        for stage in decoder:
            # Skips are consumed deepest-first, mirroring the encoder order.
            features = stage(features, skips.pop())

        return self.conv(features)


import torchsummary

# Print a per-layer summary of the default 5-in / 5-out network for
# 5 x 256 x 256 inputs with a batch size of 2, evaluated on the CPU.
torchsummary.summary(
    res_U_Net(img_ch=5, output_ch=5),
    input_size=[(5, 256, 256)],
    batch_size=2,
    device="cpu",
)

U-Net + 残差网络,可用于分类或预测,从 TensorFlow 改写而来。

另外还有一点疑惑:TensorFlow 中的 activation='linear' 是否等同于 LeakyReLU(negative_slope=1)?

# 1. First import the required PyTorch modules.
import torch
import torch.nn as nn
import torch.nn.functional as F


# 2. Define the attention module. It can draw on information of different
#    dimensions from the original U-Net convolution layers, giving every
#    pixel a more accurate localisation.
class AttentionBlock(nn.Module):
    """Channel-preserving gating block with an optional learned residual.

    ``W`` is Conv1x1 -> BatchNorm -> ReLU -> Conv1x1 -> BatchNorm -> ReLU,
    squeezing to ``gate_channels`` and expanding back to ``in_channels``.

    Args:
        in_channels: Channel count of the input (and of the output).
        gate_channels: Channel count of the intermediate representation.
        use_res: If true, return ``x + gamma * W(x)``; otherwise ``W(x)``.
            ``gamma`` is a learnable scalar initialised to zero, so a fresh
            block with ``use_res=True`` returns its input unchanged.
    """

    def __init__(self, in_channels, gate_channels, use_res=True):
        super(AttentionBlock, self).__init__()
        self.use_res = use_res
        self.in_channels = in_channels
        self.W = nn.Sequential(
            nn.Conv2d(in_channels, gate_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(gate_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(gate_channels, in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply the gate; ``x`` must have ``in_channels`` channels on dim 1."""
        assert x.size()[1] == self.in_channels
        Wx = self.W(x)
        if self.use_res:
            out = x + self.gamma * Wx
        else:
            out = Wx
        return out


def _double_conv(ch_in, ch_out):
    """Return the layers Conv3x3-BN-ReLU applied twice (ch_in -> ch_out -> ch_out)."""
    return [
        nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(ch_out),
        nn.ReLU(inplace=True),
        nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(ch_out),
        nn.ReLU(inplace=True),
    ]


def _up_stage(ch_in, ch_out):
    """Return a x2 transposed-conv upsample followed by BN-ReLU and a double conv."""
    return nn.Sequential(
        nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride=2),
        nn.BatchNorm2d(ch_out),
        nn.ReLU(inplace=True),
        *_double_conv(ch_out, ch_out),
    )


# 3. Define the U-Net network structure.
class UNet(nn.Module):
    """U-Net with attention blocks applied to the fused decoder features.

    The encoder halves the resolution four times (``enc2``..``enc5`` each
    start with a 2x2 max-pool), so the input height and width must be
    divisible by 16. The output has the same spatial size as the input.

    Args:
        in_channels: Channel count of the input image.
        out_channels: Channel count of the output map.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder. enc1 keeps the input resolution; every later stage pools
        # first and then doubles the channel count.
        self.enc1 = nn.Sequential(*_double_conv(in_channels, 64))
        self.enc2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2), *_double_conv(64, 128))
        self.enc3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2), *_double_conv(128, 256))
        self.enc4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2), *_double_conv(256, 512))
        self.enc5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2), *_double_conv(512, 1024))

        # Decoder. dec5..dec2 each upsample by 2; their input channel count
        # is that of the previous fused (concatenated) tensor.
        self.dec5 = _up_stage(1024, 512)
        self.dec4 = _up_stage(1024, 256)
        self.dec3 = _up_stage(512, 128)
        self.dec2 = _up_stage(256, 64)
        # dec1 receives features that are already at full resolution, so it
        # must not upsample again: a 3x3 convolution replaces the stride-2
        # transposed convolution that would have doubled the output size.
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            *_double_conv(64, 64),
        )

        # Auxiliary attention modules.
        self.att1 = AttentionBlock(64, 64)
        self.att2 = AttentionBlock(128, 64)
        self.att3 = AttentionBlock(256, 64)
        self.att4 = AttentionBlock(512, 64)

        # Final convolution (output layer).
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return a map with ``out_channels`` channels at the input resolution."""
        # Encoder. The pooling lives inside enc2..enc5, so the stages are
        # chained directly; pooling here as well would downsample by 4 per
        # stage and break the skip concatenations below.
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)

        # Decoder: upsample, concatenate with the matching encoder feature,
        # then apply attention to the fused tensor.
        dec5 = torch.cat((enc4, self.dec5(enc5)), dim=1)
        dec4 = self.att4(torch.cat((enc3, self.dec4(dec5)), dim=1))
        dec3 = self.att3(torch.cat((enc2, self.dec3(dec4)), dim=1))
        dec2 = self.att2(torch.cat((enc1, self.dec2(dec3)), dim=1))
        dec1 = self.att1(self.dec1(dec2))
        out = self.out(dec1)
        return out


# 4. Instantiate the model and start training. Regular training and testing
#    code can be used to train and evaluate this U-Net structure for
#    detecting and segmenting various targets.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值