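# 程序结构基本与下面这篇文章相似，有什么问题也可以参考这篇文章：
# https://blog.youkuaiyun.com/haohulala/article/details/107660273
# (The program structure closely follows the article above; refer to it if anything is unclear.)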
import torch
from torch import nn
import torch.nn.functional as f
import torchvision
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.models as models
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime
# Class names of the segmentation labels. (An earlier comment called these
# "VOC" labels, but names such as Bicyclist, LaneMkgsDriv and SUVPickupTruck
# are the 32-class CamVid label set, matching the CamVid ROOT used below.)
classes = [
    "Animal", "Archway", "Bicyclist", "Bridge", "Building", "Car",
    "CartLuggagePram", "Child", "Column_Pole", "Fence", "LaneMkgsDriv",
    "LaneMkgsNonDriv", "Misc_Text", "MotorcycleScooter", "OtherMoving",
    "ParkingBlock", "Pedestrian", "Road", "RoadShoulder", "Sidewalk",
    "SignSymbol", "Sky", "SUVPickupTruck", "TrafficCone", "TrafficLight",
    "Train", "Tree", "Truck_Bus", "Tunnel", "VegetationMisc", "Void", "Wall",
]
# Colour for each label: colormap[i] is the colour of classes[i] (e.g. "Void"
# -> [0, 0, 0]), so both lists must keep the same length and order.
# Values look like 8-bit RGB triples.
colormap = [
    [64, 128, 64], [192, 0, 128], [0, 128, 192], [0, 128, 64], [128, 0, 0], [64, 0, 128],
    [64, 0, 192], [192, 128, 64], [192, 192, 128], [64, 64, 128], [128, 0, 192], [192, 0, 64],
    [128, 128, 64], [192, 0, 192], [128, 64, 64], [64, 192, 128], [64, 64, 0], [128, 64, 128],
    [128, 128, 192], [0, 0, 192], [192, 128, 128], [128, 128, 128], [64, 128, 192], [0, 0, 64],
    [0, 64, 64], [192, 64, 128], [128, 128, 0], [192, 128, 192], [64, 0, 64], [192, 192, 0],
    [0, 0, 0], [64, 192, 0],
]
num_classes = len(classes)
# Fail fast if the two tables drift apart (e.g. a class added without a colour);
# a silent mismatch would mis-colour every class after the gap.
if len(colormap) != num_classes:
    raise ValueError(
        f"colormap has {len(colormap)} entries but there are {num_classes} classes"
    )
print(num_classes)
print(len(colormap))
# Output:
# 32
# 32
# Paths listed in the split index files already start with "/" (e.g.
# "/SegNet/CamVid/train/..."), so they are appended to data_root by plain
# string concatenation rather than os.path.join.
data_root = "./data"
# Directory that holds train.txt / val.txt / test.txt.
ROOT = data_root + "/SegNet/CamVid"
# Read the dataset file lists.
def _read_split_file(filename):
    """Parse one split index file into ``(image_paths, label_paths)``.

    The file holds whitespace-separated tokens alternating image path, label
    path, image path, label path, ... Each token is prefixed with the
    module-level ``data_root`` by string concatenation: the tokens start with
    "/", and ``os.path.join`` would discard the prefix for such paths.

    Raises:
        ValueError: if the file has an odd number of tokens, i.e. an image
            without a matching label (which would misalign the two lists).
    """
    with open(filename, "r", encoding="utf-8") as fh:
        tokens = fh.read().split()
    if len(tokens) % 2:
        raise ValueError(
            f"{filename}: expected image/label path pairs, "
            f"got an odd number of entries ({len(tokens)})"
        )
    images = [data_root + tok for tok in tokens[0::2]]
    labels = [data_root + tok for tok in tokens[1::2]]
    return images, labels


def read_image(mode="train", val=False):
    """Return the image and label path lists for one dataset split.

    Args:
        mode: "train", "test" or "val"; selects ``ROOT + "/<mode>.txt"``.
        val: only used when ``mode == "train"``. If true, the validation
            split (``ROOT + "/val.txt"``) is appended to the training lists.

    Returns:
        ``(data, label)``: two lists of path strings where ``data[i]`` is
        paired with ``label[i]``.

    Raises:
        ValueError: for an unknown ``mode`` or a malformed split file.
    """
    split_files = {"train": "/train.txt", "test": "/test.txt", "val": "/val.txt"}
    if mode not in split_files:
        # Previously this only printed a warning and then crashed with
        # UnboundLocalError on the never-assigned ``filename``.
        raise ValueError(
            f"unknown mode {mode!r}; expected 'train', 'test' or 'val'"
        )
    data, label = _read_split_file(ROOT + split_files[mode])
    if val and mode == "train":
        # Fold the validation split into the training data.
        extra_data, extra_label = _read_split_file(ROOT + "/val.txt")
        data += extra_data
        label += extra_label
    print(mode + ":读取了" + str(len(data)) + "张图片")
    print(mode + ":读取了" + str(len(label)) + "张图片的标签")
    return data, label
# Load the training split and show the first ten image / label paths.
data, label = read_image(mode="train")
print(data[0:10], label[0:10])
# Output:
# train:读取了367张图片
# train:读取了367张图片的标签
# ['./data/SegNet/CamVid/train/0001TP_006690.png', './data/SegNet/CamVid/train/0001TP_006720.png', './data/SegNet/CamVid/train/0001TP_006750.png', './data/SegNet/CamVid/train/0001TP_006780.png', './data/SegNet/CamVid/train/0001TP_006810.png', './data/SegNet/CamVid/train/0001TP_006840.png', './data/SegNet/CamVid/train/0001TP_006870.png', './data/SegNet/CamVid/train/0001TP_006900.png', './data/SegNet/CamVid/train/0001TP_006930.png', './data/SegNet/CamVid/train/0001TP_006960.png'] ['./data/SegNet/CamVid/trainannot/0001TP_006690.png', './data/SegNet/CamVid/trainannot/0001TP_006720.png', './data/SegNet/CamVid/trainannot/0001TP_006750.png', './data/SegNet/CamVid/trainannot/0001TP_006780.png', './data/SegNet/CamVid/trainannot/0001TP_006810.png', './data/SegNet/CamVid/trainannot/0001TP_006840.png', './data/SegNet/CamVid/trainannot/0001TP_006870.png', './data/SegNet/CamVid/trainannot/0001TP_006900.png', './data/SegNet/CamVid/trainannot/0001TP_006930.png', './data/SegNet/CamVid/trainannot/0001TP_006960.png']
# Inspect the first training pair: input image on the left, label on the right.
im = Image.open(data[0])
lab = Image.open(label[0])
plt.subplot(1, 2, 1)
plt.imshow(im)
plt.subplot(1, 2, 2)
plt.imshow(lab)

# Label: PIL image -> numpy array -> torch tensor, printing the shape after
# each conversion (the recorded output shows a single-channel (360, 480) map).
lab = np.array(lab)
print(lab.shape)
lab = torch.from_numpy(lab)
print(lab.shape)

# Image: ToTensor returns a channel-first (C, H, W) tensor.
im = tfs.ToTensor()(im)
print(im.shape)
# Output:
# (360, 480)
# torch.Size([360, 480])
# torch.Size([3, 360, 480])
size = 224  # default crop edge length in pixels


def crop(data, label, height=size, width=size, st_x=50, st_y=50):
    """Cut the same window out of an image and its label map.

    Args:
        data: input image; must provide ``.crop(box)`` and ``.size`` as
            ``(width, height)``, like ``PIL.Image.Image``.
        label: label image, cropped with the identical box so pixels stay
            aligned with ``data``.
        height, width: window size (default ``size``, bound when this
            function is defined).
        st_x, st_y: top-left corner of the window. Both default to 50, the
            offset that was previously hard-coded.

    Returns:
        ``(data, label)`` cropped to ``(st_x, st_y, st_x + width, st_y + height)``.

    Raises:
        ValueError: if the window does not fit inside either image. Pillow
            would otherwise silently pad the out-of-bounds area with zeros,
            which in an index-valued label map reads as real class-0 pixels.
    """
    box = (st_x, st_y, st_x + width, st_y + height)
    for name, img in (("data", data), ("label", label)):
        img_w, img_h = img.size
        if st_x < 0 or st_y < 0 or box[2] > img_w or box[3] > img_h:
            raise ValueError(
                f"crop box {box} does not fit inside {name} of size {img.size}"
            )
    return data.crop(box), label.crop(box)
im = Image.open(data[0])
lab = Image.open(label