分类实战:对图片进行分类
其中带标签的训练数据共有11类,每类280个,不带标签的训练数据共6786个。
上半是对带标签的数据进行训练验证,下半则主要介绍半监督学习及代码复盘。
半监督学习利用少量有标签和大量无标签的数据来训练模型。首先用初始的带标签数据训练模型;其次用模型对无标签数据进行预测,打上标签;然后在保证一定准确率和置信度的基础上,选择无标签数据加入带标签数据,循环往复直至停止。
1.数据处理
(1)在原数据集类中加入无标签数据的部分,初始化函数与读文件函数中,若为半监督模式,只读入x。
(2)定义一个新的数据集类以选择满足要求的无标签数据
no_label_set = food_Dataset(no_label_path, "semi") #无标签数据实例化
no_label_loader = DataLoader(no_label_set, batch_size=16, shuffle=False)
#shuffle=true可能导致打上的标签与图片不对应
class semi_Dataset(Dataset):
def __init__(self, no_label_loader, model, device, thres=0.99):
x, y = self.get_label(no_label_loader, model, device, thres) #调用获取标签函数
if x == []: #无符合要求数据
self.flag = False
else: #有符合要求的数据
self.flag = True
self.X = np.array(x) #将x转换成numpy数组
self.Y = torch.LongTensor(y) #将y转换为长整型张量(标签)
self.transform = train_transform #加入训练数据,故使用训练集的数据变换
def get_label(self, no_label_loader, model, device, thres):
model = model.to(device)
pred_prob = [] #每个样本的最大预测概率
labels = [] #每个样本的预测类别
x = []
y = []
soft = nn.Softmax()
with torch.no_grad():
for bat_x, _ in no_label_loader:
bat_x = bat_x.to(device)
pred = model(bat_x) #利用模型得到预测值
pred_soft = soft(pred) #应用softmax函数将原始输出转为概率分布
pred_max, pred_value = pred_soft.max(1)
#获取每个输入样本的最大概率值及对应的类别索引,1代表维度
pred_prob.extend(pred_max.cpu().numpy().tolist())
#放到cpu,转为矩阵再转为列表
labels.extend(pred_value.cpu().numpy().tolist())
for index, prob in enumerate(pred_prob):
if prob > thres: #若概率大于阈值,则返回原始的无标签数据及其被打上的标签
x.append(no_label_loader.dataset[index][1]) #调用到原始的无标签样本
y.append(labels[index])
return x, y
#获取标签函数额外接收thres阈值参数,返回所有无标签数据及其预测标签
def __getitem__(self, item):
return self.transform(self.X[item]), self.Y[item]
def __len__(self):
return len(self.X)
2.模型构建(与上半相同)
3.训练和验证
(1)由于新打上标签的数据是在训练过程中产生的,可能为空,所以需要定义一个加载器函数。
def get_semi_Loader(no_label_loader, model, device, thres):
semiset = semi_Dataset(no_label_loader, model, device, thres)
if semiset.flag == False: #如果有符合要求的数据则生成加载器,否则返回none
return None
else:
semi_loader = DataLoader(semiset, batch_size=16, shuffle=False)
return semi_loader
(2)在原训练验证过程中,当准确率大于一定值,且轮次数满足一定规律(每轮读浪费时间),读入semi_loader,若其不为空,则进行训练并打印准确率。同时训练和验证函数需引入新参数thres,便于加载器读入。
4.在实际配置中,阈值通常设为0.99,以保证打上标签的准确性;准确率通常设在0.6以上以保证可靠性。
另外,本次实战课程认识了随机种子,用以保证多次运行时产生相同的随机效果。
比如一个程序需要在1~5中随机取三次数,运行结果为第一次取2,第二次取1,第三次取4.若使用随机种子,该程序下一次运行时结果依然是第一次取2,第二次取1,第三次取4,若不使用则第一次可能取3...
def seed_everything(seed):
#随机种子,固定随机数,每次运行同一随机变量的值不变(可直接调用)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
seed_everything(0)