一、数据集制作
主要内容是Dataset类的使用
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms
# Image preprocessing pipeline applied to every input image by
# MyDataset.__getitem__; it currently consists of a single ToTensor step.
transform = transforms.Compose([transforms.ToTensor()])
class MyDataset(Dataset):
    """Segmentation dataset that pairs every mask with its same-named image.

    Files are read from two sub-directories of the dataset root ``p``:

    * ``p/SegmentationClass/<name>`` -- the mask, opened with
      ``keep_image_size_open``
    * ``p/JPEGImages/<name>`` -- the input image, opened with
      ``keep_image_size_open_rgb``

    The image is looked up under exactly the same file name as the mask.
    """

    def __init__(self, p):
        """Index the mask files found in ``p/SegmentationClass``.

        Args:
            p: Root directory of the dataset.
        """
        self.path = p
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping stable across runs and machines.
        self.name = sorted(os.listdir(os.path.join(p, 'SegmentationClass')))

    def __len__(self):
        """Return the number of mask files that were indexed."""
        return len(self.name)

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            A tuple ``(image, mask)`` where ``image`` is the result of the
            module-level ``transform`` applied to the input image and ``mask``
            is the mask's pixel array wrapped in a ``torch.Tensor`` (callers
            in this project cast it with ``.long()`` before use).
        """
        segment_name = self.name[index]  # e.g. xx.png
        segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
        image_path = os.path.join(self.path, 'JPEGImages', segment_name)
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open_rgb(image_path)
        return transform(image), torch.Tensor(np.array(segment_image))
if __name__ == '__main__':
    # Smoke test: load the first sample of the dataset under 'data' and print
    # the tensor shapes, the one-hot mask shape and the dataset's file index.
    from torch.nn.functional import one_hot

    data_total = MyDataset('data')
    # Load the sample once instead of re-reading both files for every print.
    image, mask = data_total[0]
    print(image.shape)
    print(mask.shape)
    out = one_hot(mask.long())
    print(out.shape)
    print(data_total.name)
    print(data_total.path)
二、train训练(包括net)
DataLoader 的使用、UNet 网络的使用
import os
import tqdm
from torch import nn, optim
import torch
from torch.utils.data import DataLoader
from data import *
from net import *
from torchvision.utils import save_image
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Locations used by the training script below.
save_path = 'train_image'  # directory receiving the per-step preview images
data_path = r'data'  # dataset root handed to MyDataset
weight_path = 'params/unet.pth'  # checkpoint file that is loaded and saved
if __name__ == '__main__':
    # Training entry point: builds the data loader and the UNet, resumes from
    # `weight_path` when that file exists, then trains for epochs 1..199.
    # A preview image is written after every step and the weights are saved
    # every 20th epoch.
    num_classes = 2 + 1  # +1 because the background also counts as a class
    data_loader = DataLoader(MyDataset(data_path), batch_size=1, shuffle=True)
    net = UNet(num_classes).to(device)

    # Make sure the output locations exist before the first write; otherwise
    # save_image / torch.save fail on a fresh checkout.
    os.makedirs(save_path, exist_ok=True)
    weight_dir = os.path.dirname(weight_path)
    if weight_dir:
        os.makedirs(weight_dir, exist_ok=True)

    if os.path.exists(weight_path):
        # map_location lets a checkpoint written on a GPU machine be loaded
        # on whatever device was selected for this run.
        net.load_state_dict(torch.load(weight_path, map_location=device))
        print('successful load weight!')
    else:
        print('not successful load weight')

    opt = optim.Adam(net.parameters())
    loss_fun = nn.CrossEntropyLoss()

    for epoch in range(1, 200):
        for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)):
            image, segment_image = image.to(device), segment_image.to(device)

            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image.long())

            opt.zero_grad()
            train_loss.backward()
            opt.step()

            print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            # Preview: ground-truth mask and predicted mask of the first
            # sample in the batch, stacked into one image file per step.
            _segment_image = torch.unsqueeze(segment_image[0], 0) * 255
            _out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0) * 255
            img = torch.stack([_segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.png')

        if epoch % 20 == 0:
            torch.save(net.state_dict(), weight_path)
            print('save successfully!')
net
import torch
from torch import nn
from torch.nn import functional as F
class Conv_Block(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm, Dropout2d and LeakyReLU.

    Both convolutions use stride 1 and reflect padding of 1, so the spatial
    size is preserved; only the channel count changes from ``in_channel`` to
    ``out_channel``. The convolutions carry no bias term.
    """

    def __init__(self, in_channel, out_channel):
        """Build the block.

        Args:
            in_channel: Number of channels of the input feature map.
            out_channel: Number of channels produced by both convolutions.
        """
        super().__init__()
        # Submodule order and the attribute name `layer` determine the
        # state_dict keys, so they are kept exactly as-is.
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting feature map."""
        return self.layer(x)
class DownSample(nn.Module):
    """Downsampling step: a stride-2 3x3 convolution, BatchNorm and LeakyReLU.

    The channel count is unchanged. The convolution uses reflect padding of 1
    and no bias term.
    """

    def __init__(self, channel):
        """Build the layer.

        Args:
            channel: Number of input (and output) channels.
        """
        super().__init__()
        # The attribute name `layer` is part of the state_dict keys.
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, 3, 2, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        """Return the downsampled feature map for ``x``."""
        return self.layer(x)
class UpSample(nn.Module):
    """Upsampling step with a skip connection.

    The input is upsampled by a factor of 2 (nearest neighbour), reduced to
    half its channels with a 1x1 convolution, and concatenated with
    ``feature_map`` along the channel dimension.
    """

    def __init__(self, channel):
        """Build the layer.

        Args:
            channel: Number of channels of the input ``x``; the 1x1
                convolution outputs ``channel // 2`` channels.
        """
        super().__init__()
        # The attribute name `layer` is part of the state_dict keys.
        self.layer = nn.Conv2d(channel, channel // 2, 1, 1)

    def forward(self, x, feature_map):
        """Upsample ``x`` and join it with the encoder's ``feature_map``.

        Args:
            x: Feature map to upsample.
            feature_map: Skip-connection tensor; it must match the upsampled
                ``x`` in every dimension except channels for the
                concatenation to succeed.

        Returns:
            The upsampled, channel-reduced ``x`` followed by ``feature_map``
            along dimension 1.
        """
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim=1)
class UNet(nn.Module):
    """U-Net for 3-channel inputs that outputs one score map per class.

    Encoder: four Conv_Block/DownSample stages (64 -> 128 -> 256 -> 512
    channels) followed by a 1024-channel bottleneck. Decoder: four
    UpSample/Conv_Block stages that each concatenate the matching encoder
    output. A final 3x3 convolution maps the 64 decoder channels to
    ``num_classes`` channels.
    """

    def __init__(self, num_classes):
        """Build the network.

        Args:
            num_classes: Number of output channels (one per class).
        """
        super().__init__()
        # Attribute names below are part of the state_dict keys; renaming
        # them would invalidate existing checkpoints.
        # Encoder
        self.c1 = Conv_Block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        # Bottleneck
        self.c5 = Conv_Block(512, 1024)
        # Decoder
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        # Per-pixel class scores
        self.out = nn.Conv2d(64, num_classes, 3, 1, 1)

    def forward(self, x):
        """Run the network on ``x`` and return the per-class score maps.

        The input is downsampled four times with stride 2 and upsampled four
        times by a factor of 2, so its height and width should be divisible
        by 16 for the skip connections to line up.
        """
        # Encoder path; every stage output is kept for its skip connection.
        enc1 = self.c1(x)
        enc2 = self.c2(self.d1(enc1))
        enc3 = self.c3(self.d2(enc2))
        enc4 = self.c4(self.d3(enc3))
        bottleneck = self.c5(self.d4(enc4))

        # Decoder path, pairing each upsample with the matching encoder stage.
        dec1 = self.c6(self.u1(bottleneck, enc4))
        dec2 = self.c7(self.u2(dec1, enc3))
        dec3 = self.c8(self.u3(dec2, enc2))
        dec4 = self.c9(self.u4(dec3, enc1))
        return self.out(dec4)
if __name__ == '__main__':
    # Smoke test: push a random batch through the network and print the
    # output shape.
    x = torch.randn(2, 3, 256, 256)
    # UNet requires num_classes; calling it without arguments raises a
    # TypeError. 3 matches the value used by the training and test scripts.
    net = UNet(3)
    print(net(x).shape)
utils
from PIL import Image
def keep_image_size_open(path, size=(256, 256)):
    """Open an image, pad it to a square 'P'-mode canvas and resize it.

    The image is pasted into the top-left corner of a new palette-mode canvas
    whose side equals the longer edge of the image, so the aspect ratio is
    preserved and the padding ends up on the right/bottom. The square canvas
    is then resized to ``size``.

    Args:
        path: Path of the image file to open.
        size: Target ``(width, height)`` of the returned image.

    Returns:
        The padded and resized ``PIL.Image.Image`` in mode 'P'.
    """
    # The context manager closes the underlying file once the pixels have
    # been copied onto the canvas.
    with Image.open(path) as img:
        side = max(img.size)
        mask = Image.new('P', (side, side))
        mask.paste(img, (0, 0))
    return mask.resize(size)
def keep_image_size_open_rgb(path, size=(256, 256)):
    """Open an image, pad it to a square 'RGB' canvas and resize it.

    The image is pasted into the top-left corner of a new RGB canvas whose
    side equals the longer edge of the image, so the aspect ratio is preserved
    and the padding ends up on the right/bottom. The square canvas is then
    resized to ``size``.

    Args:
        path: Path of the image file to open.
        size: Target ``(width, height)`` of the returned image.

    Returns:
        The padded and resized ``PIL.Image.Image`` in mode 'RGB'.
    """
    # The context manager closes the underlying file once the pixels have
    # been copied onto the canvas.
    with Image.open(path) as img:
        side = max(img.size)
        canvas = Image.new('RGB', (side, side))
        canvas.paste(img, (0, 0))
    return canvas.resize(size)
三、test
import os
import cv2
import numpy as np
import torch
from net import *
from utils import *
from data import *
from torchvision.utils import save_image
from PIL import Image
# Inference script: load the trained UNet, segment one image whose path is
# typed in by the user, save the class-index mask and show it in a window.

# Use the GPU when available instead of requiring CUDA unconditionally.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = UNet(3).to(device)
weights = 'params/unet.pth'
if os.path.exists(weights):
    # map_location lets a checkpoint written on a GPU machine be loaded on
    # whatever device was selected above.
    net.load_state_dict(torch.load(weights, map_location=device))
    print('successfully')
else:
    print('no loading')

_input = input('please input JPEGImages path:')
img = keep_image_size_open_rgb(_input)
img_data = transform(img).to(device)
img_data = torch.unsqueeze(img_data, dim=0)  # add the batch dimension

net.eval()
with torch.no_grad():  # inference only, no autograd graph is needed
    out = net(img_data)

# Class index per pixel; the result keeps the batch dimension of size 1.
out = torch.argmax(out, dim=1)
print(set(out.reshape(-1).tolist()))

# Move the leading dimension last to get an (H, W, 1) array. Class indices
# are small, so store them explicitly as uint8 for OpenCV.
out = out.permute((1, 2, 0)).cpu().numpy().astype(np.uint8)

# cv2.imwrite does not create missing directories.
os.makedirs('result', exist_ok=True)
cv2.imwrite('result/result.png', out)
cv2.imshow('out', out * 255.0)
cv2.waitKey(0)