使用PyTorch实现U-Net图像分割:从理论到实践

引言

图像分割是计算机视觉领域的重要任务之一,而U-Net因其独特的架构在医学图像分割等领域表现出色。本文将详细介绍如何使用PyTorch实现一个完整的U-Net图像分割模型,包括数据准备、网络构建、训练和测试的全过程。

1. 数据集准备

首先,我们需要准备数据集并实现数据加载功能。我们创建了一个自定义的Dataset类来处理图像和对应的分割掩码。

from torch.utils.data import Dataset
import os
from utils import *
from torchvision import transforms

transform=transforms.Compose([
    transforms.ToTensor()
])

class MyDataset(Dataset):
    def __init__(self,path):
        super().__init__()
        self.path=path
        self.name=os.listdir(os.path.join(self.path,'SegmentationClass'))

    def __len__(self):
        return len(self.name)

    def __getitem__(self, index):
        segment_name=self.name[index]
        segment_path=os.path.join(self.path,'SegmentationClass',segment_name)
        image_path=os.path.join(self.path,'JPEGImages',segment_name.replace('png', 'jpg'))
        segment_image=keep_image_size_open(segment_path)
        image=keep_image_size_open(image_path)
        return transform(image),transform(segment_image)
    
if __name__=='__main__':
    dataset=MyDataset('D:\\python_text\\python\\pytorch深度学习实战\\U-Net图像分割\\VOC2012')
    print(dataset[0][0].shape)
    print(dataset[0][1].shape)

2. 工具函数

我们还需要一些辅助函数来处理图像大小,确保所有输入图像具有相同的尺寸。

from PIL import Image

def keep_image_size_open(path,size=(256,256)):
    img=Image.open(path)
    temp=max(img.size)
    mask=Image.new('RGB',(temp,temp),(0,0,0)) # 创建边长为temp的正方形黑色背景
    mask.paste(img,(0,0)) # 把img粘贴到mask的(0, 0)位置(左上角)
    mask=mask.resize(size) # 对mask进行resize
    return mask

3. U-Net网络实现

U-Net的核心是其编码器-解码器结构,以及跳跃连接。以下是完整的实现:

from torch import nn
from torch.nn import functional as F
import torch

class Conv_Block(nn.Module):
    def __init__(self,in_channels,out_channels):
        super().__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(in_channels,out_channels,3,1,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channels),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels,out_channels,3,1,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(out_channels),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
        )

    def forward(self,x):
        return self.layer(x)
    
class DownSample(nn.Module):
    def __init__(self,channels):
        super().__init__()
        self.layer=nn.Sequential(
            nn.Conv2d(channels,channels,3,2,1,padding_mode='reflect',bias=False),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(),
        )

    def forward(self,x):
        return self.layer(x)
    
class UpSample(nn.Module):
    def __init__(self,channels):
        super().__init__()
        self.layer=nn.Conv2d(channels,channels//2,1,1)

    def forward(self,x,feature_map):
        up=F.interpolate(x,scale_factor=2,mode='nearest')
        x=self.layer(up)
        return torch.cat((x,feature_map),dim=1)
    
class UNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.c1=Conv_Block(3,64)
        self.d1=DownSample(64)
        self.c2=Conv_Block(64,128)
        self.d2=DownSample(128)
        self.c3=Conv_Block(128,256)
        self.d3=DownSample(256)
        self.c4=Conv_Block(256,512)
        self.d4=DownSample(512)
        self.c5=Conv_Block(512,1024)
        self.u1=UpSample(1024)
        self.c6=Conv_Block(1024,512)
        self.u2=UpSample(512)
        self.c7=Conv_Block(512,256)
        self.u3=UpSample(256)
        self.c8=Conv_Block(256,128)
        self.u4=UpSample(128)
        self.c9=Conv_Block(128,64)
        self.out=nn.Conv2d(64,3,3,1,1)
        self.sigmoid=nn.Sigmoid()

    def forward(self,x):
        x1=self.c1(x)
        x2=self.c2(self.d1(x1))
        x3=self.c3(self.d2(x2))
        x4=self.c4(self.d3(x3))
        x5=self.c5(self.d4(x4))
        x=self.c6(self.u1(x5,x4))
        x=self.c7(self.u2(x,x3))
        x=self.c8(self.u3(x,x2))
        x=self.c9(self.u4(x,x1))
        x=self.sigmoid(self.out(x))
        return x

if __name__=='__main__':
    x=torch.rand(2,3,256,256)
    net=UNet()
    print(net(x).shape)

4. 训练过程

训练脚本负责加载数据、定义损失函数和优化器,并执行训练循环。

import torch
from torch import nn,optim
from torch.utils.data import DataLoader
from net import *
from dataset import *
from torchvision.utils import save_image

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path=r'D:\python_text\python\pytorch深度学习实战\U-Net图像分割\src\params\unet.pth'
data_path=r'D:\python_text\python\pytorch深度学习实战\U-Net图像分割\VOC2012'
save_image_path=r'D:\python_text\python\pytorch深度学习实战\U-Net图像分割\src\train_image'

if __name__=='__main__':
    data_loader=DataLoader(MyDataset(data_path),batch_size=10,shuffle=True)
    net=UNet().to(device)
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path))
        print('load weight success')
    else:
        print('not load weight')

    optimizer=optim.Adam(net.parameters(),lr=0.001)
    loss=nn.BCELoss().to(device)

    for epoch in range(10):
        for i,(image,segment) in enumerate(data_loader):
            image,segment=image.to(device),segment.to(device)
            output=net(image)
            l=loss(output,segment)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

            if i%5==0:
                print(epoch+1,i,l.item())
            
            if i%50==0:
                torch.save(net.state_dict(),weight_path)

            _image=image[0]
            _segment=segment[0]
            _output=output[0]

            img=torch.stack([_image,_segment,_output],dim=0)
            save_image(img,f'{save_image_path}/{i}.png')

5. 测试模型

训练完成后,我们可以使用测试脚本来评估模型在单个图像上的表现。

from net import *
import os
import torch
from utils import *
from dataset import *
from torchvision.utils import save_image

net=UNet().cuda()

weight_path=r'D:\python_text\python\pytorch深度学习实战\U-Net图像分割\src\params\unet.pth'
if os.path.exists(weight_path):
    net.load_state_dict(torch.load(weight_path))
    print('load weight success')
else:
    print('not load weight')

input_path=r'D:\python_text\python\pytorch深度学习实战\U-Net图像分割\VOC2012\JPEGImages\2007_000027.jpg'
image=keep_image_size_open(input_path)
img=transform(image).cuda()
img=torch.unsqueeze(img,dim=0)
output=net(img)

save_image(output,'D:\\python_text\\python\\pytorch深度学习实战\\U-Net图像分割\\src\\result\\result.png')

总结

本文完整展示了使用PyTorch实现U-Net图像分割的全过程,包括:

  1. 数据集准备和预处理

  2. U-Net网络架构的实现

  3. 训练过程和优化策略

  4. 模型测试和结果保存

U-Net因其独特的编码器-解码器结构和跳跃连接,在图像分割任务中表现出色。通过本文的实现,读者可以深入理解U-Net的工作原理,并将其应用于自己的图像分割任务中。

希望这篇文章对你理解和实现U-Net有所帮助!如果有任何问题,欢迎在评论区留言讨论。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值