pytorch版Unet实现医学图像分割

目录

1 model

2 dataset

3 main

4 下载资源


1 model

import torch.nn as nn
import torch
from torch import autograd

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        return self.conv(input)


class Unet(nn.Module):
    def __init__(self,in_ch,out_ch):
        super(Unet, self).__init__()

        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64,out_ch, 1)

    def forward(self,x):
        c1=self.conv1(x)
        p1=self.pool1(c1)
        c2=self.conv2(p1)
        p2=self.pool2(c2)
        c3=self.conv3(p2)
        p3=self.pool3(c3)
        c4=self.conv4(p3)
        p4=self.pool4(c4)
        c5=self.conv5(p4)
        up_6= self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6=self.conv6(merge6)
        up_7=self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7=self.conv7(merge7)
        up_8=self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8=self.conv8(merge8)
        up_9=self.up9(c8)
        merge9=torch.cat([up_9,c1],dim=1)
        c9=self.conv9(merge9)
        c10=self.conv10(c9)
        out = nn.Sigmoid()(c10)
        return out

2 dataset

import torch.utils.data as data
import PIL.Image as Image
import os


def make_dataset(root):
    imgs=[]
    n=len(os.listdir(root))//2
    for i in range(n):
        img=os.path.join(root,"%03d.png"%i)
        mask=os.path.join(root,"%03d_mask.png"%i)
        imgs.append((img,mask))
    return imgs


class LiverDataset(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None):
        imgs = make_dataset(root)
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)

3 main

import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset


# 是否使用cuda

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 把多个步骤整合到一起, channel=(channel-mean)/std, 因为是分别对三个通道处理
x_transforms = transforms.Compose([
    transforms.ToTensor(),  # -> [0,1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # ->[-1,1]
])

# mask只需要转换为tensor
y_transforms = transforms.ToTensor()

# 参数解析器,用来解析从终端读取的命令
parse = argparse.ArgumentParser()


def train_model(model, criterion, optimizer, dataload, num_epochs=20):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model


# 训练模型
def train():
    model = Unet(3, 1).to(device)
    batch_size = args.batch_size
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset("data/train",transform=x_transforms,target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, dataloaders)


# 显示模型的输出结果
def test():
    model = Unet(3, 1)
    model.load_state_dict(torch.load(args.ckp,map_location='cpu'))
    liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        for x, _ in dataloaders:
            y=model(x)
            img_y=torch.squeeze(y).numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
        plt.show()


parse = argparse.ArgumentParser()
# parse.add_argument("action", type=str, help="train or test")
parse.add_argument("--batch_size", type=int, default=1)
parse.add_argument("--ckp", type=str, help="the path of model weight file")
args = parse.parse_args()

# train
#train()

# test()
args.ckp = "weights_19.pth"
test()

测试结果:

4 下载资源

 只需关注下面的公众号,回复unet,即可免费获取源码权重和对应的数据集

<think>好的,我现在需要帮助用户找到在PyTorch实现半监督图像分割UNet代码。首先,用户的问题是关于UNet的半监督学习在PyTorch中的实现,所以我要先明确用户的需求。用户可能已经了解UNet的基本结构,现在想扩展到半监督场景,这意味着需要处理有标签和无标签的数据。 接下来,我应该回顾一下UNet的基本实现UNet的架构包括编码器和解码器,带有跳跃连接,这在医学图像分割中很常见。PyTorch中通常用卷积层和上采样层来构建。然后,半监督学习的关键在于如何利用无标签数据。常见的方法包括自训练(self-training)、一致性正则化(consistency regularization)或者协同训练(co-training)。 根据用户的引用资料,UNet++和UNet3+提到了结合分类和分割任务,这可能是一种策略。比如,在训练分割网络的同时,添加一个分类分支来减少假阳性。不过用户现在需要的是半监督方法,所以可能需要结合这些思想,比如用分类任务来辅助半监督训练。 接下来,我需要考虑代码的结构。用户可能需要一个UNet的基础实现,然后添加半监督的组件。例如,在训练循环中同时处理有标签和无标签的数据批次,对无标签数据生成伪标签,或者应用一致性损失,比如Mean Teacher方法。 然后,我应该查找现有的开源项目或教程。用户提供的引用2提到了Mamba-UNet的GitHub仓库,虽然这可能不直接相关,但可以看看是否有半监督的实现。另外,GitHub上搜索关键词如“UNet PyTorch semi-supervised”可能会找到相关项目。例如,有一个项目可能结合了UNet和Mean Teacher,或者使用对抗训练进行半监督。 另外,半监督常用的技术包括伪标签(pseudo-labeling)、一致性训练(如在不同增强下的预测一致性)、以及熵最小化。在代码实现中,这些方法需要修改损失函数,处理无标签数据,并可能设计双分支网络。 现在,我需要确保提供的代码示例正确,并符合用户的要求。例如,展示一个UNet的基础类,然后添加半监督的部分,比如在forward函数中处理有标签和无标签数据,或者在损失计算中加入一致性损失。 同时,用户可能需要参考现有的论文或框架,比如FixMatch或Mean Teacher在分割任务中的应用。例如,在Mean Teacher中,学生模型和教师模型的参数通过指数移动平均更新,并对无标签数据施加一致性约束。 最后,整理这些思路,给出一个清晰的步骤,包括数据加载、模型定义、损失函数、训练循环中的半监督处理,以及参考实现和文献,确保用户能够根据指引找到合适的资源。</think>在PyTorch实现UNet的半监督图像分割,核心在于结合有标签数据和无标签数据的协同训练。以下是实现思路和代码框架: ### 1. 基础UNet模型定义 ```python import torch import torch.nn as nn class DoubleConv(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ) def forward(self, x): return self.conv(x) class UNet(nn.Module): def __init__(self, in_ch=3, out_ch=1): super().__init__() # 编码器 self.down1 = DoubleConv(in_ch, 64) # ...完整UNet结构定义... # 解码器 self.up1 = nn.ConvTranspose2d(1024, 512, 2, stride=2) # ... def forward(self, x): # 标准UNet前向传播 return x ``` ### 2. 半监督训练策略 推荐采用**一致性正则化**方法[^1]: ```python class SemiSupervisedLoss(nn.Module): def __init__(self, supervised_weight=1.0, consistency_weight=0.1): super().__init__() self.supervised_loss = nn.BCEWithLogitsLoss() self.consistency_loss = nn.MSELoss() self.weights = (supervised_weight, consistency_weight) def forward(self, labeled_pred, labeled_target, unlabeled_pred1, unlabeled_pred2): sup_loss = self.supervised_loss(labeled_pred, labeled_target) cons_loss = self.consistency_loss(unlabeled_pred1.sigmoid(), unlabeled_pred2.detach().sigmoid()) return self.weights[0]*sup_loss + self.weights[1]*cons_loss ``` ### 3. 训练流程关键代码 ```python def train_step(model, labeled_data, unlabeled_data, optimizer): labeled_imgs, labels = labeled_data unlabeled_imgs = unlabeled_data[0] # 有监督分支 supervised_pred = model(labeled_imgs) # 无监督分支(添加不同数据增强) aug1 = weak_augmentation(unlabeled_imgs) aug2 = strong_augmentation(unlabeled_imgs) pred1 = model(aug1) pred2 = model(aug2) # 计算联合损失 loss = criterion(supervised_pred, labels, pred1, pred2) optimizer.zero_grad() loss.backward() optimizer.step() return loss ``` ### 实现建议 1. **数据加载**:需同时加载有标签和无标签数据集 2. **数据增强**: - 有标签数据:常规增强(旋转/翻转) - 无标签数据:强增强(颜色抖动、高斯模糊)[^2] 3. **动态权重调整**:随着训练过程逐步增加无监督损失权重 ### 相关资源推荐 1. GitHub参考实现: - [Semi-Supervised Medical Image Segmentation](https://github.com/HiLab-git/SSL4MIS) - [Uncertainty-aware Self-ensembling Model](https://github.com/yulequan/UA-MT) 2. 重要论文方法: - Mean Teacher框架 - Cross-Consistency Training[^2] - Uncertainty-aware方案(适合医学图像
评论 93
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值