基础分割模型 U-Net

数据集 Carvana: https://www.kaggle.com/competitions/carvana-image-masking-challenge/data

import os
import collections

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import PIL
import PIL.Image
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

class DoubleConv(nn.Module):
	"""(3x3 Conv -> BatchNorm -> ReLU) applied twice; H and W are preserved.

	Args:
		in_channels: channels of the incoming feature map.
		out_channels: channels of the produced feature map.
		mid_channels: channels between the two convolutions; any falsy value
			(None or 0) falls back to ``out_channels``.
	"""

	def __init__(self, in_channels, out_channels, mid_channels=None):
		super().__init__()
		hidden = mid_channels if mid_channels else out_channels
		layers = []
		for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
			# No conv bias: the following BatchNorm has its own shift term.
			layers += [
				nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
				nn.BatchNorm2d(c_out),
				nn.ReLU(),
			]
		self.double_conv = nn.Sequential(*layers)

	def forward(self, x):
		return self.double_conv(x)

class Down(nn.Module):
	"""Encoder stage: 2x2 max-pool (H and W halved, rounding down) then a DoubleConv."""

	def __init__(self, in_channels, out_channels):
		super().__init__()
		pool = nn.MaxPool2d(2)
		convs = DoubleConv(in_channels, out_channels)
		self.maxpool_conv = nn.Sequential(pool, convs)

	def forward(self, x):
		return self.maxpool_conv(x)

class Up(nn.Module):
	"""Decoder stage: upsample the deeper map, concatenate the skip connection, DoubleConv.

	Args:
		in_channels: channel count of the concatenated tensor fed to the
			DoubleConv (skip-connection channels + upsampled channels).
		out_channels: channels produced by the DoubleConv.
		bilinear: if True, upsample with parameter-free bilinear interpolation
			(channel count unchanged); otherwise use a learned ConvTranspose2d
			that outputs ``in_channels // 2`` channels.
	"""

	def __init__(self, in_channels, out_channels, bilinear=True):
		super().__init__()
		if bilinear:
			self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
			self.conv = DoubleConv(in_channels, out_channels, in_channels//2)
		else:
			self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
			self.conv = DoubleConv(in_channels, out_channels)

	@staticmethod
	def _pad_to_match(x, ref):
		"""Zero-pad the last two dims of ``x`` so they equal those of ``ref``.

		F.pad lists padding starting from the LAST dimension:
		[W_left, W_right, H_top, H_bottom]. A negative difference crops.
		"""
		diff_h = ref.size(-2) - x.size(-2)
		diff_w = ref.size(-1) - x.size(-1)
		pads = [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
		return F.pad(x, pads)

	def forward(self, x1, x2):
		"""Fuse ``x1`` (deeper map, upsampled here) with skip connection ``x2``."""
		x1 = self.up(x1)
		# Odd sizes on the encoder path (MaxPool floors) leave the upsampled
		# map smaller than the skip connection; pad each axis independently.
		x1 = self._pad_to_match(x1, x2)
		x = torch.cat([x2, x1], dim=1)
		return self.conv(x)

class OutConv(nn.Module):
	"""Final 1x1 convolution mapping feature channels to per-pixel logits."""

	def __init__(self, in_channels, out_channels):
		super().__init__()
		self.conv = nn.Conv2d(
			in_channels=in_channels,
			out_channels=out_channels,
			kernel_size=(1, 1),
		)

	def forward(self, x):
		logits = self.conv(x)
		return logits

class UNet(nn.Module):
	"""U-Net: four Down stages, four Up stages with skip connections, 1x1 head.

	Args:
		n_channels: channels of the input image (e.g. 3 for RGB).
		n_classes: channels of the output logit map.
		bilinear: use bilinear upsampling instead of transposed convolutions.
			Bilinear upsampling does not reduce channels, so the bottleneck and
			decoder widths are divided by ``factor`` to keep each concatenation
			equal to the corresponding Up's ``in_channels``.
	"""

	def __init__(self, n_channels, n_classes, bilinear=True):
		super().__init__()
		self.n_channels = n_channels
		self.n_classes = n_classes
		self.bilinear = bilinear

		self.inc = DoubleConv(n_channels, 64)
		self.down1 = Down(64, 128)
		self.down2 = Down(128, 256)
		self.down3 = Down(256, 512)
		factor = 2 if bilinear else 1
		self.down4 = Down(512, 1024//factor)
		self.up1 = Up(1024, 512//factor, bilinear)
		self.up2 = Up(512, 256//factor, bilinear)
		self.up3 = Up(256, 128//factor, bilinear)
		self.up4 = Up(128, 64, bilinear)
		self.outc = OutConv(64, n_classes)

	def forward(self, x):
		"""Return raw logits (no sigmoid/softmax) for a 4-D input batch."""
		# Encoder: keep every level's output for the skip connections.
		x1 = self.inc(x)
		x2 = self.down1(x1)
		x3 = self.down2(x2)
		x4 = self.down3(x3)
		x5 = self.down4(x4)
		# Decoder: upsample and fuse with the matching encoder level.
		x = self.up1(x5, x4)
		x = self.up2(x, x3)
		x = self.up3(x, x2)
		x = self.up4(x, x1)
		logits = self.outc(x)
		return logits

class CarvanaDataset(Dataset):
	"""Carvana image/mask pairs read from ``base_dir/train`` and ``base_dir/train_masks``.

	Args:
		base_dir: dataset root holding the ``train`` and ``train_masks`` folders
			(with or without a trailing path separator).
		idx_list: indices into the sorted list of image filenames in ``train``;
			lets train and validation datasets share one folder.
		mode: "train" returns ``(image, mask)``; any other value returns only
			the image (still read from the ``train`` folder).
		transform: optional callable applied to the image and, in train mode,
			to the mask as well.
	"""

	def __init__(self, base_dir, idx_list, mode="train", transform=None):
		super().__init__()
		self.base_dir = base_dir
		self.idx_list = idx_list
		# os.path.join inserts the separator that plain "+" concatenation dropped.
		# Sorting makes index -> file deterministic: separate dataset instances
		# (train / val) index their own listing of the same folder.
		self.images = sorted(os.listdir(os.path.join(base_dir, "train")))
		self.masks = sorted(os.listdir(os.path.join(base_dir, "train_masks")))
		self.mode = mode
		self.transform = transform

	def __len__(self):
		return len(self.idx_list)

	def __getitem__(self, idx):
		image_file = self.images[self.idx_list[idx]]
		# Replace the 4-character extension: "xxx.jpg" -> "xxx_mask.gif".
		mask_file = image_file[:-4] + "_mask.gif"
		image = PIL.Image.open(os.path.join(self.base_dir, "train", image_file))
		if self.transform is not None:
			image = self.transform(image)
		if self.mode != "train":
			return image
		mask = PIL.Image.open(os.path.join(self.base_dir, "train_masks", mask_file))
		if self.transform is not None:
			mask = self.transform(mask)
		if not torch.is_tensor(mask):
			# No transform (or one that returns PIL): convert so the mask can be
			# binarised below instead of failing on PIL's missing .float().
			mask = torch.from_numpy(np.array(mask))
		# Every non-zero pixel (including values blended by resizing) is foreground.
		return image, (mask != 0).float()

def dice_coeff(pred, target):
	"""Dice coefficient used as the evaluation metric.

	The intersection and both sums are taken over the whole batch, so this is
	one pooled score, not the mean of per-sample Dice values. With sigmoid
	probabilities as ``pred`` (as the validation loop passes) it is a soft Dice.

	Args:
		pred: predictions whose first dimension is the batch.
		target: ground truth with the same batch size and element count as ``pred``.

	Returns:
		0-dim tensor; ``eps`` keeps the ratio defined (equal to 1) when both are all zeros.
	"""
	eps = 1e-4
	num = pred.size(0)
	# reshape (unlike view) also accepts non-contiguous tensors.
	m1 = pred.reshape(num, -1)
	m2 = target.reshape(num, -1)
	intersection = (m1 * m2).sum()
	return (2. * intersection + eps) / (m1.sum() + m2.sum() + eps)

class DiceLoss(nn.Module):
	"""Soft Dice loss on raw logits; a custom alternative to BCE for segmentation.

	Args:
		weight: unused; accepted only for signature parity with torch losses.
		size_average: unused; accepted only for signature parity.
	"""

	def __init__(self, weight=None, size_average=True):
		super().__init__()

	def forward(self, inputs, targets, smooth=1):
		"""Return ``1 - Dice`` pooled over every element of the batch.

		Args:
			inputs: raw logits; sigmoid is applied here.
			targets: mask with the same number of elements as ``inputs``.
			smooth: added to numerator and denominator so the ratio stays
				defined when both prediction and target are empty.
		"""
		inputs = torch.sigmoid(inputs).reshape(-1)
		targets = targets.reshape(-1)
		intersection = (inputs * targets).sum()
		dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
		return 1 - dice

if __name__ == "__main__":
	batch_size = 16
	num_workers = 4
	epochs = 10
	lr = 1e-3
	img_size = 256
	weight_decay = 1e-8
	interval = 50  # print the running training loss every `interval` batches
	# Fall back to CPU so the script still runs without a GPU.
	device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	base_dir = "./carvana"

	transform = transforms.Compose([
		transforms.Resize((img_size, img_size)),
		transforms.ToTensor(),
	])
	# 70/30 split of sample indices; both datasets read the same folder.
	train_idxs, val_idxs = train_test_split(
		range(len(os.listdir(os.path.join(base_dir, "train_masks")))),
		test_size=0.3,
	)
	train_data = CarvanaDataset(base_dir, train_idxs, transform=transform)
	val_data = CarvanaDataset(base_dir, val_idxs, transform=transform)
	train_loader = DataLoader(
		train_data, batch_size=batch_size, num_workers=num_workers, shuffle=True,
	)
	val_loader = DataLoader(
		val_data, batch_size=batch_size, num_workers=num_workers, shuffle=False,
	)
	# Sanity check: save the first channel of one image and its mask.
	image, mask = next(iter(train_loader))
	plt.imsave("tmp_check.jpg", image[0][0].numpy())
	plt.imsave("tmp_mask.jpg", mask[0][0].numpy(), cmap="gray")

	model = UNet(3, 1)
	model = model.to(device)

	criterion = nn.BCEWithLogitsLoss()
	# criterion = DiceLoss()  # alternative: the custom Dice loss defined above
	optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
	# Scheduler: multiply the learning rate by 0.8 after every epoch.
	scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

	def train(epoch, epochs, interval):
		"""Run one training epoch; return the mean per-sample training loss."""
		model.train()
		train_loss = 0
		for i, (data, mask) in enumerate(train_loader):
			data, mask = data.to(device), mask.to(device)
			optimizer.zero_grad()
			output = model(data)
			loss = criterion(output, mask)
			loss.backward()
			optimizer.step()
			# Weight by batch size so the final average is per sample even if
			# the last batch is smaller.
			train_loss += loss.item() * data.size(0)
			if (i+1) % interval == 0:
				print(
					"loader({}/{}): lr:{:.7f} \ttrain_loss: {:.4f}".format(
						i + 1,
						len(train_loader),
						optimizer.param_groups[0]["lr"],
						(train_loss / ((i+1)*train_loader.batch_size))
					),
				)
		train_loss = train_loss / len(train_loader.dataset)
		# Report the lr actually used this epoch; end="" keeps val metrics on the same line.
		current_lr = optimizer.param_groups[0]["lr"]
		print("Epoch({}/{}): lr:{:.7f} \ttrain_loss: {:.4f}".format(epoch, epochs, current_lr, train_loss), end="")
		return train_loss

	def val(epoch):
		"""Evaluate on the validation set; return the per-sample mean Dice score."""
		model.eval()
		val_loss = 0
		dice_score = 0
		with torch.no_grad():
			for data, mask in val_loader:
				data, mask = data.to(device), mask.to(device)
				output = model(data)
				loss = criterion(output, mask)
				val_loss += loss.item() * data.size(0)
				# Weight the batch's Dice by its size (the target mask itself
				# must not be scaled).
				batch_dice = dice_coeff(torch.sigmoid(output).cpu(), mask.cpu()).item()
				dice_score += batch_dice * data.size(0)
		val_loss = val_loss / len(val_loader.dataset)
		dice_score = dice_score / len(val_loader.dataset)
		print(" \tval_loss: {:.4f} \tdice_score: {:.4f}".format(val_loss, dice_score))
		return dice_score

	best_dice = 0
	for epoch in range(1, epochs + 1):
		train_loss = train(epoch, epochs, interval)
		dice_score = val(epoch)
		scheduler.step()  # decay the learning rate
		if dice_score > best_dice:
			best_dice = dice_score
			torch.save(model, "UNet_best.pth")
		torch.save(model, "UNet_last.pth")
	print("best dice score:", best_dice)

	# Modify a layer: copy the model and replace the head with a 5-class one.
	import copy
	model1 = copy.deepcopy(model)
	x = torch.rand(1, 3, 224, 224, device=device)
	with torch.no_grad():
		out = model(x)
		print(out.shape)
		# The new layer must live on the same device as the rest of the model.
		model1.outc = OutConv(64, 5).to(device)
		out1 = model1(x)
		print(out1.shape)

	# Save the whole (pickled) model.
	torch.save(model, "UNet.pth")
	# Save only the weights (state_dict).
	torch.save(model.state_dict(), "UNet2.pth")

	# Freeze the last layer so it receives no gradients (useful for fine-tuning).
	model.outc.conv.weight.requires_grad = False
	model.outc.conv.bias.requires_grad = False
	for layer, param in model.named_parameters():
		print(layer, "\t", param.requires_grad)

	# Layer-by-layer summary and parameter count (third-party torchinfo package).
	from torchinfo import summary
	summary(model, (1, 3, 224, 224))
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值