from backbone import ResNet12, ResNet18, WRN, ConvNet
from dpgn import DPGN
from utils import (set_logging_config, adjust_learning_rate, save_checkpoint, allocate_tensors, preprocessing,
initialize_nodes_edges, backbone_two_stage_initialization, one_hot_encode)
from dataloader import MiniImagenet, TieredImagenet, Cifar, CUB200, DataLoader  # the CUB and Omniglot datasets use NN
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import random
import logging  # emits run-time logs; the log level, save path and log file can be configured  # https://blog.youkuaiyun.com/u011159607/article/details/79985087
import argparse
import imp  # the module name is short for "import"; functionally the same as the import statement, but used programmatically
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
# Incorporate the node_self idea from the paper "Attentive" to build new point-graph features.
# Because the model is too deep, the resulting accuracy ends up being 1.
class DPGNTrainer(object):
def __init__(self, enc_module, gnn_module, data_loader, log, arg, config, best_step):
"""
The Trainer of DPGN model
:param enc_module: backbone network (Conv4, ResNet12, ResNet18, WRN)
:param gnn_module: DPGN model
:param data_loader: data loader
:param log: logger
:param arg: command line arguments
:param config: model configurations
:param best_step: starting step (step at best eval acc or 0 if starts from scratch)
"""
self.arg = arg
self.config = config
self.train_opt = config['train_config']
self.eval_opt = config['eval_config']
# initialize variables
self.tensors = allocate_tensors()
for key, tensor in self.tensors.items():
self.tensors[key] = tensor.to(self.arg.device)
# set backbone and DPGN
self.enc_module = enc_module.to(self.arg.device)
self.gnn_module = gnn_module.to(self.arg.device)
# set logger
self.log = log
# get data loader
self.data_loader = data_loader
# set parameters
self.module_params = list(self.enc_module.parameters()) + list(self.gnn_module.parameters())
# set optimizer
self.optimizer = optim.Adam(
params=self.module_params,
lr=self.train_opt['lr'],
weight_decay=self.train_opt['weight_decay'])
# set loss
self.edge_loss = nn.BCELoss(reduction='none')  # only for binary classification problems
self.pred_loss = nn.CrossEntropyLoss(reduction='none')
# initialize other global variables
self.global_step = best_step
self.best_step = best_step
self.val_acc = 0
self.test_acc = 0
def train(self):
"""
train function
:return: None
""" # 除前5行前5列之外都是1
#见utils.py105 5 10 维度251010 维度251010 的全一矩阵
num_supports, num_samples, query_edge_mask, evaluation_mask =
preprocessing(self.train_opt[‘num_ways’], # 5
self.train_opt[‘num_shots’], # 1
self.train_opt[‘num_queries’], # 1
self.train_opt[‘batch_size’], # 25
self.arg.device)
# main training loop, batch size is the number of tasks
for iteration, batch in enumerate(self.data_loader['train']()):
# each batch is a list of 4 elements: the 1st is a tensor of shape [1,25,5,3,84,84], the 2nd a tensor of shape [1,25,5] whose rows are [0,1,2,3,4]
# the 3rd is a tensor of shape [1,25,5,3,84,84], the 4th a tensor of shape [1,25,5] whose rows are [0,1,2,3,4]
# the four elements correspond to support_data, support_label, query_data, query_label
# init grad
self.optimizer.zero_grad()  # reset the gradients of the current parameters to zero; this is why the optimizer must be zeroed before use,
# because without clearing them the gradients would still carry information from the previous mini-batch, which is not what we want
# set current step
self.global_step = self.global_step + 1
# initialize nodes and edges for dual graph model
support_data, support_label, query_data, query_label, all_data, all_label_in_edge, node_feature_gd, \
edge_feature_gp, edge_feature_gd = initialize_nodes_edges(batch,
num_supports, # 5
self.tensors,
self.train_opt['batch_size'],
self.train_opt['num_queries'],
self.train_opt['num_ways'],
self.arg.device)
# print('X_S', support_label)
# print('X_Q ', query_label)
# support_data.shape [25,5,3,84,84], support_label.shape [25,5], each row is [0,1,2,3,4]
# query_data.shape [25,5,3,84,84], query_label.shape [25,5], each row is [0,1,2,3,4]
# all_data.shape [25,10,3,84,84], all_label_in_edge.shape [25,10,10]
# node_feature_gd.shape [25,10,5] , edge_feature_gd.shape[25,10,10]
# edge_feature_gp.shape [25,10,10]
# print('1 ', support_data.shape, ' label:', support_label, support_label.shape)
# print('2 ', query_data.shape, ' label:', query_label, query_label.shape)
# print('3 ', all_data, all_data.shape, ' label: ', all_label_in_edge, all_label_in_edge.shape)
# print('4 ', node_feature_gd, node_feature_gd.shape, ' edge_feature_gd: ', edge_feature_gd, edge_feature_gd.shape)
# print('5 ', edge_feature_gp, edge_feature_gp.shape)
# exit()
# print(all_label_in_edge)
# set as train mode
self.enc_module.train()
self.gnn_module.train()
# use backbone encode image
# when only doing inference with the network, wrap the call in "with torch.no_grad():" to avoid running out of memory
# with torch.no_grad():
# [25,10,128] [25,10,128] utils.py 201
last_layer_data, second_last_layer_data = backbone_two_stage_initialization(all_data, self.enc_module)
#one_hot
support_hot = one_hot_encode(self.train_opt['num_ways'], support_label.long(), self.arg.device)
query_hot = one_hot_encode(self.train_opt['num_ways'], query_label.long(), self.arg.device)
y_hot = torch.cat([support_hot, query_hot], dim=1)
# transpose
last_layer_data_transpose = torch.transpose(last_layer_data, 1, 2)
second_last_layer_data_transpose = torch.transpose(second_last_layer_data, 1, 2)
y_hot_transpose = torch.transpose(y_hot, 1, 2)
# inter-class and inter-sample relations
last_c_data = torch.softmax(torch.bmm(last_layer_data, last_layer_data_transpose), dim=-1)
second_c_data = torch.softmax(torch.bmm(second_last_layer_data, second_last_layer_data_transpose), dim=-1)
c_y = torch.softmax(torch.bmm(y_hot, y_hot_transpose), dim=-1)
# concatenate
last_c = torch.cat([last_c_data, c_y], dim=-1) #[25,10,20]
second_c = torch.cat([second_c_data, c_y], dim=-1) #[25,10,20]
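# Shape sketch for the two relation matrices built above (assuming the 5-way 1-shot setting
# annotated in this file: batch_size 25, 10 samples per task, embedding size 128):
# torch.bmm(last_layer_data, last_layer_data_transpose) multiplies [25,10,128] by [25,128,10]
# and gives pairwise sample similarities of shape [25,10,10]; c_y does the same for the one-hot
# labels ([25,10,5] x [25,5,10] -> [25,10,10]); concatenating each with c_y along the last
# dimension yields last_c and second_c of shape [25,10,20].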
# print(second_c.shape)
# downsample with a 1x1 convolution
layer_list = []
layer_list = layer_list + [nn.Conv2d(in_channels=num_samples * 2, out_channels=num_samples, kernel_size=1, bias=False),
nn.BatchNorm2d(num_features=num_samples),
nn.LeakyReLU()]
conv = nn.Sequential(*layer_list)
c_last_f = conv.cuda()(last_c.transpose(1, 2).unsqueeze(-1)).transpose(1, 2).squeeze(-1) #[25,10,10]
# print(c_last_f.shape)
c_second_f = conv.cuda()(second_c.transpose(1, 2).unsqueeze(-1)).squeeze(-1) #[25,10,10]
# enhance with self-attention
x_l_last = torch.bmm(c_last_f, last_layer_data) #[25,10,128]
x_l_second = torch.bmm(c_second_f, second_last_layer_data) #[25,10,128]
y_l_last = 0.5 * y_hot + 0.5 * torch.bmm(c_last_f, y_hot)  # [25,10,5]
y_l_second = 0.5 * y_hot + 0.5 * torch.bmm(c_second_f, y_hot) #[25,10,5]
# concatenate the final results and downsample with a 1x1 convolution
F_last = torch.cat([x_l_last, y_l_last], dim=-1) #[25,10,133]
F_second = torch.cat([x_l_second, y_l_second], dim=-1) #[25,10,133]
last_layer_data = nn.Conv2d(in_channels=133, out_channels=128, kernel_size=1, bias=False).cuda()(F_last.transpose(1, 2).unsqueeze(-1)).transpose(1, 2).squeeze(-1)  # [25,10,128]
second_last_layer_data = nn.Conv2d(in_channels=133, out_channels=128, kernel_size=1, bias=False).cuda()(F_second.transpose(1, 2).unsqueeze(-1)).transpose(1, 2).squeeze(-1)
# print(last_layer_data.shape)
# print(second_last_layer_data.shape)
# exit()
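# Note on the enhancement block above: the 1x1-conv/BatchNorm stack in `conv` and the two
# nn.Conv2d(133 -> 128) projections are re-created with fresh random weights on every
# iteration and are not part of self.module_params, so the optimizer never updates them.
# If they are meant to be learned, one possible sketch (names below are hypothetical, not
# from the original code) is to build them once in __init__ and reuse them here:
#   self.relation_conv = nn.Sequential(
#       nn.Conv2d(num_samples * 2, num_samples, kernel_size=1, bias=False),
#       nn.BatchNorm2d(num_samples), nn.LeakyReLU()).to(self.arg.device)
#   self.fuse_conv = nn.Conv2d(133, 128, kernel_size=1, bias=False).to(self.arg.device)
# and to append their parameters to self.module_params before the optimizer is constructed.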
# run the DPGN model
# these three outputs are lists, each containing 6 elements; every element is a tensor of shape [25,10,10]
point_similarity, node_similarity_l2, distribution_similarities = self.gnn_module(second_last_layer_data,
# [25,10,128]
last_layer_data,
# [25,10,128]
node_feature_gd,
# [25,10,5]
edge_feature_gd,
# [25,10,10]
edge_feature_gp) # [25,10,10]
# print('***************')
# print('point_similarity ', point_similarity[0], ' ', len(point_similarity), point_similarity[0].shape, point_similarity[1].shape)
# print('node_similarity_l2 ', node_similarity_l2[0], ' ', len(node_similarity_l2), node_similarity_l2[0].shape, node_similarity_l2[1].shape)
# print('distribution_similarities ', distribution_similarities[0], ' ', len(distribution_similarities), distribution_similarities[0].shape)
# exit()
# compute loss
# 1. a scalar tensor  2. a list of 6 elements, each a single-value tensor  3. a list of 6 elements, each a single-value tensor
total_loss, query_node_cls_acc_generations, query_edge_loss_generations = \
self.compute_train_loss_pred(all_label_in_edge, # [25,10,10]
point_similarity, # [25,10,10]
node_similarity_l2, # [25,10,10]
query_edge_mask,  # [25,10,10], all ones except the first 5 rows and 5 columns
evaluation_mask,  # [25,10,10], all-ones matrix
num_supports, # 5
support_label, # [25,5]
query_label, # [25,5]
distribution_similarities) # [25,10,10]
# print(total_loss, total_loss.shape)
# print(query_node_cls_acc_generations, query_node_cls_acc_generations[0].shape)
# print(query_edge_loss_generations, query_edge_loss_generations[0].shape)
# exit()
# # when back-propagating through a tensor, make sure it was created with requires_grad=True
# back propagation & update
# total_loss.requires_grad_().backward()
total_loss.backward()
self.optimizer.step()
# adjust learning rate
adjust_learning_rate(optimizers=[self.optimizer],
lr=self.train_opt['lr'],
iteration=self.global_step,
dec_lr_step=self.train_opt['dec_lr'])
# log training info
if self.global_step % self.arg.log_step == 0:  # log once every 100 iterations
self.log.info('step : {} train_edge_loss : {} node_acc : {}'.format(
self.global_step,
query_edge_loss_generations[-1],
query_node_cls_acc_generations[-1]))
# evaluation
if self.global_step % self.eval_opt['interval'] == 0:  # evaluate once every 1000 iterations
is_best = 0
test_acc = self.eval(partition='test')
if test_acc > self.test_acc:
is_best = 1
self.test_acc = test_acc
self.best_step = self.global_step
# log evaluation info
self.log.info('test_acc : {} step : {} '.format(test_acc, self.global_step))
self.log.info('test_best_acc : {} step : {}'.format(self.test_acc, self.best_step))
self.log.info('---------NEXT_EPOCH----------')
# save checkpoints (best and newest)
save_checkpoint({
'iteration': self.global_step,
'enc_module_state_dict': self.enc_module.state_dict(),
'gnn_module_state_dict': self.gnn_module.state_dict(),
'test_acc': self.test_acc,
'optimizer': self.optimizer.state_dict(),
}, is_best, os.path.join(self.arg.checkpoint_dir, self.arg.exp_name))
def eval(self, partition='test', log_flag=True):
"""
evaluation function
:param partition: which part of data is used
:param log_flag: if log the evaluation info
:return: mean classification accuracy of the query nodes
"""
num_supports, num_samples, query_edge_mask, evaluation_mask = preprocessing(
self.eval_opt['num_ways'],
self.eval_opt['num_shots'],
self.eval_opt['num_queries'],
self.eval_opt['batch_size'],
self.arg.device)
query_edge_loss_generations = []
query_node_cls_acc_generations = []
# main training loop, batch size is the number of tasks
for current_iteration, batch in enumerate(self.data_loader[partition]()):
# initialize nodes and edges for dual graph model
support_data, support_label, query_data, query_label, all_data, all_label_in_edge, node_feature_gd, \
edge_feature_gp, edge_feature_gd = initialize_nodes_edges(batch,
num_supports,
self.tensors,
self.eval_opt['batch_size'],
self.eval_opt['num_queries'],
self.eval_opt['num_ways'],
self.arg.device)
# print('test_S', support_label.shape)
# print('test_Q ', query_label.shape)
# if self.global_step %10 ==0:
# print(torch.cat([support_label, query_label], 1))
# print('eval', query_label)
# print(all_label_in_edge)
# set as eval mode
self.enc_module.eval()  # disables BatchNorm and Dropout updates; the framework freezes BN/Dropout and uses the values learned during training instead of batch statistics
self.gnn_module.eval()
# &&&& for models with many parameters, always wrap the forward call in "with torch.no_grad():";
# whenever memory problems appear, use "with torch.no_grad():"
with torch.no_grad():
last_layer_data, second_last_layer_data = backbone_two_stage_initialization(all_data, self.enc_module)
# one-hot encode the labels
support_hot = one_hot_encode(self.eval_opt['num_ways'], support_label.long(), self.arg.device)
query_hot = one_hot_encode(self.eval_opt['num_ways'], query_label.long(), self.arg.device)
y_hot = torch.cat([support_hot, query_hot], dim=1)
# transpose
last_layer_data_transpose = torch.transpose(last_layer_data, 1, 2)
second_last_layer_data_transpose = torch.transpose(second_last_layer_data, 1, 2)
y_hot_transpose = torch.transpose(y_hot, 1, 2)
# inter-class and inter-sample relations
last_c_data = torch.softmax(torch.bmm(last_layer_data, last_layer_data_transpose), dim=-1)
second_c_data = torch.softmax(torch.bmm(second_last_layer_data, second_last_layer_data_transpose),
dim=-1)
c_y = torch.softmax(torch.bmm(y_hot, y_hot_transpose), dim=-1)
# concatenate
last_c = torch.cat([last_c_data, c_y], dim=-1) # [25,10,20]
second_c = torch.cat([second_c_data, c_y], dim=-1) # [25,10,20]
# print(second_c.shape)
# downsample with a 1x1 convolution
layer_list = []
layer_list = layer_list + [
nn.Conv2d(in_channels=num_samples * 2, out_channels=num_samples, kernel_size=1, bias=False),
nn.BatchNorm2d(num_features=num_samples),
nn.LeakyReLU()]
conv = nn.Sequential(*layer_list)
c_last_f = conv.cuda()(last_c.transpose(1, 2).unsqueeze(-1)).transpose(1, 2).squeeze(-1) # [25,10,10]
# print(c_last_f.shape)
c_second_f = conv.cuda()(second_c.transpose(1, 2).unsqueeze(-1)).squeeze(-1) # [25,10,10]
# enhance with self-attention
x_l_last = torch.bmm(c_last_f, last_layer_data) # [25,10,128]
x_l_second = torch.bmm(c_second_f, second_last_layer_data) # [25,10,128]
y_l_last = 0.5 * y_hot + 0.5 * torch.bmm(c_last_f, y_hot) # [25,10,5]
y_l_second = 0.5 * y_hot + 0.5 * torch.bmm(c_second_f, y_hot) # [25,10,5]
# concatenate the final results and downsample with a 1x1 convolution
F_last = torch.cat([x_l_last, y_l_last], dim=-1) # [25,10,133]
F_second = torch.cat([x_l_second, y_l_second], dim=-1)
last_layer_data = nn.Conv2d(in_channels=133, out_channels=128, kernel_size=1, bias=False).cuda()(
F_last.transpose(1, 2).unsqueeze(-1)).transpose(1, 2).squeeze(-1)
second_last_layer_data = nn.Conv2d(in_channels=133, out_channels=128, kernel_size=1, bias=False).cuda()(
F_second.transpose(1, 2).unsqueeze(-1)).transpose(1, 2).squeeze(-1)
# run the DPGN model
point_similarity, _, _ = self.gnn_module(second_last_layer_data,
last_layer_data,
node_feature_gd,
edge_feature_gd,
edge_feature_gp)
# print('111', point_similarity[5][:, 5:, :5])
# @@@@@@ confusion matrix
# if self.global_step == 80:
# labels_name = ['0', '1', '2', '3', '4']
# y_true = list(query_label.flatten().cpu().numpy())
# # print(y_true)
# y_pred_marxi = [point[:, num_supports:, :num_supports] for point in point_similarity]  # query-set prediction matrices
# CM = []
#
# for i in y_pred_marxi:
# y_pred = list(torch.max(i, -1)[1].flatten().cpu().numpy())  # predicted query labels for each generation
# # print(y_pred)
# cm = confusion_matrix(y_true, y_pred)  # confusion matrix for each generation
# # print(cm)
# CM.append(cm)
# plt.figure(figsize=(8, 8), dpi=120)
# # plt.figure(figsize=(84, 84))
# for i in range(1, 7):
# p = plt.subplot(3, 3, int(i))
# cm = CM[i-1]
# # print(cm)
# cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # normalize
# # annotate with the numbers
# for row in range(len(cm)):
# for col in range(len(cm[row])):
# plt.text(row, col, cm[col][row], va='center', ha='center')
#
# # plt.imshow(cm, interpolation='nearest')
# plt.title('generation' + str(i))  # figure title
# # plt.colorbar()
# num_local = np.array(range(len(labels_name)))
# plt.xticks(num_local, labels_name, rotation=90)  # print the labels on the x axis
# plt.yticks(num_local, labels_name)  # print the labels on the y axis
# plt.ylabel('y' + str(i))
# plt.xlabel('x' + str(i))
# plt.imshow(cm, interpolation='nearest')
# plt.colorbar()  # per-subplot color bar
# plt.tight_layout()  # automatically adjust the spacing between subplots
# plt.savefig('/home/gzj/11/cmcm80.png', format='png')
# plt.show()
# exit()
# @@@@@
query_node_cls_acc_generations, query_edge_loss_generations = \
self.compute_eval_loss_pred(query_edge_loss_generations,
query_node_cls_acc_generations,
all_label_in_edge,
point_similarity,
query_edge_mask,
evaluation_mask,
num_supports,
support_label,
query_label)
# logging
if log_flag:
self.log.info('------------------------------------')
self.log.info('step : {} {}_edge_loss : {} {}_node_acc : {}'.format(
self.global_step, partition,
np.array(query_edge_loss_generations).mean(),
partition,
np.array(query_node_cls_acc_generations).mean()))
self.log.info('evaluation: total_count=%d, accuracy: mean=%.2f%%, std=%.2f%%, ci95=%.2f%%' %
(current_iteration,
np.array(query_node_cls_acc_generations).mean() * 100,
np.array(query_node_cls_acc_generations).std() * 100,
1.96 * np.array(query_node_cls_acc_generations).std()
/ np.sqrt(float(len(np.array(query_node_cls_acc_generations)))) * 100))
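# ci95 above is the half-width of a 95% confidence interval: 1.96 * std / sqrt(n), where n is
# the number of accuracy values accumulated in query_node_cls_acc_generations.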
self.log.info('------------------------------------')
return np.array(query_node_cls_acc_generations).mean()
def compute_train_loss_pred(self,
all_label_in_edge, # [25,10,10]
point_similarities, # [25,10,10]
node_similarities_l2, # [25,10,10]
query_edge_mask,  # [25,10,10], all ones except the first 5 rows and 5 columns
evaluation_mask,  # [25,10,10], all-ones matrix
num_supports,  # 5
support_label,  # [25,5], each row is [0,1,2,3,4]
query_label,  # [25,5], each row is [0,1,2,3,4]
distribution_similarities): # [25,10,10]
"""
compute the total loss, query classification loss and query classification accuracy
:param all_label_in_edge: ground truth label in edge form of point graph
:param point_similarities: prediction edges of point graph
:param node_similarities_l2: l2 norm of node similarities
:param query_edge_mask: mask for queries
:param evaluation_mask: mask for evaluation (for unsupervised setting)
:param num_supports: number of samples in support set
:param support_label: label of support set
:param query_label: label of query set
:param distribution_similarities: distribution-level similarities
:return: total loss
query classification accuracy
query classification loss
"""
# Point Loss
total_edge_loss_generations_instance = [  # gives the same result as self.edge_loss(point_similarity, all_label_in_edge)
self.edge_loss((1 - point_similarity), (1 - all_label_in_edge))
# edge loss of the point graph; with 6 generations the list has 6 elements, each a tensor of shape [25,10,10]
for point_similarity
in point_similarities]
# print('P ', total_edge_loss_generations_instance[0], total_edge_loss_generations_instance[0].shape)
# Distribution Loss
total_edge_loss_generations_distribution = [  # edge loss of the distribution graph; with 6 generations the list has 6 elements, each a tensor of shape [25,10,10]
self.edge_loss((1 - distribution_similarity), (1 - all_label_in_edge))
for distribution_similarity
in distribution_similarities]
# print('D ', total_edge_loss_generations_distribution[0], total_edge_loss_generations_distribution[0].shape)
# combine Point Loss and Distribution Loss  # Eq. 12 in the paper without the summation: point-graph loss + 0.1 * distribution-graph loss
distribution_loss_coeff = 0.1 if self.train_opt['loss_indicator'][2] != 0 else 0
total_edge_loss_generations = [  # total edge loss; with 6 generations the list has 6 elements, each a tensor of shape [25,10,10]
total_edge_loss_instance + distribution_loss_coeff * total_edge_loss_distribution
for (total_edge_loss_instance, total_edge_loss_distribution)
in zip(total_edge_loss_generations_instance, total_edge_loss_generations_distribution)]
# print('T ', total_edge_loss_generations[0], total_edge_loss_generations[0].shape)
pos_query_edge_loss_generations = [  # list of 6 elements, each a single-value tensor
torch.sum(total_edge_loss_generation * query_edge_mask * all_label_in_edge * evaluation_mask)
/ torch.sum(query_edge_mask * all_label_in_edge * evaluation_mask)
for total_edge_loss_generation
in total_edge_loss_generations]
# print('pos ', pos_query_edge_loss_generations[0], pos_query_edge_loss_generations[0].shape)
neg_query_edge_loss_generations = [  # list of 6 elements, each a single-value tensor
torch.sum(total_edge_loss_generation * query_edge_mask * (1 - all_label_in_edge) * evaluation_mask)
/ torch.sum(query_edge_mask * (1 - all_label_in_edge) * evaluation_mask)
for total_edge_loss_generation
in total_edge_loss_generations]
# print('neg ', neg_query_edge_loss_generations[0], neg_query_edge_loss_generations[0].shape)
# weighted edge loss for balancing pos/neg
query_edge_loss_generations = [  # list of 6 elements, each a single-value tensor
pos_query_edge_loss_generation + neg_query_edge_loss_generation
for (pos_query_edge_loss_generation, neg_query_edge_loss_generation)
in zip(pos_query_edge_loss_generations, neg_query_edge_loss_generations)]
# print('query_edge_loss ', query_edge_loss_generations[0], query_edge_loss_generations[0].shape)
# (normalized) l2 loss
# A = one_hot_encode(self.train_opt['num_ways'], support_label.long(), self.arg.device)
# B = support_label.long()  # cast to long integers
# print(B, '\n', A)
query_node_pred_generations_ = [  # list of 6 elements, each a tensor of shape [25,5,5], computed from the l2 node similarities
torch.bmm(node_similarity_l2[:, num_supports:, :num_supports],
one_hot_encode(self.train_opt['num_ways'], support_label.long(), self.arg.device))
for node_similarity_l2
in node_similarities_l2]
# print('query_node_pred_', query_node_pred_generations_[0], query_node_pred_generations_[0].shape)
# exit()
# prediction
query_node_pred_generations = [  # list of 6 elements, each a tensor of shape [25,5,5]
torch.bmm(point_similarity[:, num_supports:, :num_supports],
one_hot_encode(self.train_opt['num_ways'], support_label.long(), self.arg.device))
for point_similarity
in point_similarities]
# print('query_node_pred', query_node_pred_generations[0], query_node_pred_generations[0].shape)
# exit()
# print(query_node_pred_generations_[0].shape)
query_node_pred_loss = [  # Eq. 11 in the paper
self.pred_loss(query_node_pred_generation, query_label.long()).mean()
for query_node_pred_generation
in query_node_pred_generations_]
# print(' query_node_pred_loss ', query_node_pred_loss[0], query_node_pred_loss[0].shape)
# train accuracy
query_node_acc_generations = [  # torch.max returns two values: the maximum along the given dim and the corresponding index
torch.eq(torch.max(query_node_pred_generation, -1)[1], query_label.long()).float().mean()
# 2.py 8  # compare the argmax of each prediction row with the query labels (1 if equal, 0 otherwise), then take the mean
for query_node_pred_generation
in query_node_pred_generations]
# print(' query_node_acc_generations ', query_node_acc_generations[0], query_node_acc_generations[0].shape)
# total loss
total_loss_generations = [  # list of 6 elements, each a single-value tensor
query_edge_loss_generation + 0.1 * query_node_pred_loss_
for (query_edge_loss_generation, query_node_pred_loss_)
in zip(query_edge_loss_generations, query_node_pred_loss)]
# print(' total_loss_generations ', total_loss_generations[0], total_loss_generations[0].view(-1)*0.2, total_loss_generations[0].shape)
# compute total loss
total_loss = []
num_loss = 3 if self.train_opt['num_shots'] == 1 else 6
for l in range(num_loss - 1):
total_loss = total_loss + [total_loss_generations[l].view(-1) * 0.2]
total_loss += [total_loss_generations[-1].view(-1) * 1.0]
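# Weighting sketch: with num_shots == 1, num_loss is 3, so the losses of generations 0 and 1
# are scaled by 0.2 and the loss of the last generation (index -1, i.e. generation 5 when
# num_generation is 6) is scaled by 1.0; the intermediate generations 2-4 are skipped.
# total_loss below is then the mean of these weighted single-value tensors.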
# print('total_loss0 ', total_loss[0], total_loss[0].shape)
# print(torch.cat(total_loss, 0))
total_loss = torch.mean(torch.cat(total_loss, 0))
# print('total_loss ', total_loss)
# exit()
return total_loss, query_node_acc_generations, query_edge_loss_generations
# returns: 1. a scalar tensor  2. a list of 6 single-value tensors  3. a list of 6 single-value tensors
def compute_eval_loss_pred(self,
query_edge_losses,
query_node_accs,
all_label_in_edge,
point_similarities,
query_edge_mask,
evaluation_mask,
num_supports,
support_label,
query_label):
"""
compute the query classification loss and query classification accuracy
:param query_edge_losses: container for losses of queries' edges
:param query_node_accs: container for classification accuracy of queries
:param all_label_in_edge: ground truth label in edge form of point graph
:param point_similarities: prediction edges of point graph
:param query_edge_mask: mask for queries
:param evaluation_mask: mask for evaluation (for unsupervised setting)
:param num_supports: number of samples in support set
:param support_label: label of support set
:param query_label: label of query set
:return: query classification loss
query classification accuracy
"""
point_similarity = point_similarities[-1]
full_edge_loss = self.edge_loss(1 - point_similarity, 1 - all_label_in_edge)
pos_query_edge_loss = torch.sum(
full_edge_loss * query_edge_mask * all_label_in_edge * evaluation_mask) / torch.sum(
query_edge_mask * all_label_in_edge * evaluation_mask)
neg_query_edge_loss = torch.sum(
full_edge_loss * query_edge_mask * (1 - all_label_in_edge) * evaluation_mask) / torch.sum(
query_edge_mask * (1 - all_label_in_edge) * evaluation_mask)
# weighted loss for balancing pos/neg
query_edge_loss = pos_query_edge_loss + neg_query_edge_loss
# prediction
query_node_pred = torch.bmm(
point_similarity[:, num_supports:, :num_supports],
one_hot_encode(self.eval_opt['num_ways'], support_label.long(), self.arg.device))
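# Prediction rule: point_similarity[:, num_supports:, :num_supports] holds the query-to-support
# edge scores ([25,5,5] in the 5-way 1-shot setting); multiplying by the one-hot support labels
# accumulates, for each query, its edge scores per class, so the argmax taken below picks the
# class with the largest total similarity.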
# test accuracy
query_node_acc = torch.eq(torch.max(query_node_pred, -1)[1], query_label.long()).float().mean()
query_edge_losses += [query_edge_loss.item()]
query_node_accs += [query_node_acc.item()]
return query_node_accs, query_edge_losses
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda:0',
help='gpu device number of using')
parser.add_argument('--config', type=str,
default=os.path.join('.', 'config', '5way_1shot_resnet12_mini-imagenet.py'),
help='config file with parameters of the experiment. '
'It is assumed that the config file is placed under the directory ./config')
parser.add_argument('--checkpoint_dir', type=str, default=os.path.join('.', 'checkpoints'),
help='path that checkpoint will be saved and loaded. '
'It is assumed that the checkpoint file is placed under the directory ./checkpoints')
parser.add_argument('--num_gpu', type=int, default=1,
help='number of gpu')
parser.add_argument('--display_step', type=int, default=1,
help='display training information in how many step')
parser.add_argument('--log_step', type=int, default=1,
help='log information in how many steps')
parser.add_argument('--log_dir', type=str, default=os.path.join('.', 'logs'),
help='path that log will be saved. '
'It is assumed that the checkpoint file is placed under the directory ./logs')
parser.add_argument('--dataset_root', type=str, default='/home/gzj',
help='root directory of dataset')
parser.add_argument('--seed', type=int, default=222,
help='random seed')
parser.add_argument('--mode', type=str, default='train',
help='train or eval')
parser.add_argument('--fix', type=str, default='1646',
help='prefix added to the experiment name (used for naming logs and checkpoints)')
args_opt = parser.parse_args()
config_file = args_opt.config # ./config/5way_1shot_resnet12_mini-imagenet.py
# Set train and test datasets and the corresponding data loaders
config = imp.load_source("", config_file).config  # the first argument is the module name; what matters is the second argument, the path of the config file to load
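# Note: the imp module is deprecated since Python 3.4 and removed in Python 3.12. On newer
# interpreters an equivalent way to load the config file could look like the sketch below
# (the name 'exp_config' is arbitrary):
#   import importlib.util
#   spec = importlib.util.spec_from_file_location('exp_config', config_file)
#   cfg_module = importlib.util.module_from_spec(spec)
#   spec.loader.exec_module(cfg_module)
#   config = cfg_module.config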
print('config: ', config)
train_opt = config['train_config']
eval_opt = config['eval_config']
args_opt.exp_name = '{}_{}way_{}shot_{}_{}'.format(args_opt.fix, train_opt['num_ways'],
# 5way_1shot_resnet12_mini-imagenet
train_opt['num_shots'],
config['backbone'], # resnet12
config['dataset_name']) # mini_imagenet
train_opt['num_queries'] = 1
print('train_opt: ', train_opt)
print('eval_opt: ', eval_opt)
# eval_opt['num_queries'] = 1
# setting for 5-way 5-shot
eval_opt['num_queries'] = 1
set_logging_config(
os.path.join(args_opt.log_dir, args_opt.exp_name))  # creates ./logs/5way_1shot_resnet12_mini-imagenet/log.txt
# https://blog.youkuaiyun.com/u011159607/article/details/79985087
logger = logging.getLogger('main')  # initialize the logger
logger.info('running model name: {} '.format(args_opt.fix))
# Load the configuration params of the experiment
logger.info('Launching experiment from: {}'.format(config_file))
logger.info('Generated logs will be saved to: {}'.format(args_opt.log_dir))
logger.info('Generated checkpoints will be saved to: {}'.format(args_opt.checkpoint_dir))
print()
logger.info('-------------command line arguments-------------')
logger.info(args_opt)
print()
logger.info('-------------configs-------------')
logger.info(config)
# set random seed
np.random.seed(args_opt.seed)
torch.manual_seed(args_opt.seed)  # seed the CPU random number generator so results are deterministic; network initialization is then the same on every run
torch.cuda.manual_seed_all(args_opt.seed)  # set the random seed for all visible GPUs
random.seed(args_opt.seed)
torch.backends.cudnn.deterministic = True  # force cuDNN to use deterministic algorithms for reproducibility
torch.backends.cudnn.benchmark = False
# determine which dataset to use; the default here is mini-imagenet
# for the Omniglot dataset, everything related to the validation set in main() has to be commented out
# if config['dataset_name'] == 'omniglot':
# dataset = Omniglot
# print('Dataset: Omniglot')
if config['dataset_name'] == 'mini-imagenet':
dataset = MiniImagenet
print('Dataset: MiniImagenet')
elif config['dataset_name'] == 'tiered-imagenet':
dataset = TieredImagenet
print('Dataset: TieredImagenet')
elif config['dataset_name'] == 'cifar-fs':
dataset = Cifar
print('Dataset: Cifar')
elif config['dataset_name'] == 'cub-200-2011':
dataset = CUB200
print('Dataset: CUB200')
else:
logger.info('Invalid dataset: {}, please specify a dataset from '
'mini-imagenet, tiered-imagenet, cifar-fs and cub-200-2011.'.format(config['dataset_name']))
# choose the embedding (backbone) network; the default is resnet12
cifar_flag = True if args_opt.exp_name.__contains__('cifar') else False  # cifar_flag is False here
if config['backbone'] == 'resnet12':
enc_module = ResNet12(emb_size=config['emb_size'], cifar_flag=cifar_flag)
print('Backbone: ResNet12')
elif config['backbone'] == 'resnet18':
enc_module = ResNet18(emb_size=config['emb_size'])
print('Backbone: ResNet18')
elif config['backbone'] == 'wrn':
enc_module = WRN(emb_size=config['emb_size'])
print('Backbone: WRN')
elif config['backbone'] == 'convnet':
enc_module = ConvNet(emb_size=config['emb_size'], cifar_flag=cifar_flag)
print('Backbone: ConvNet')
else:
logger.info('Invalid backbone: {}, please specify a backbone model from '
'convnet, resnet12, resnet18 and wrn.'.format(config['backbone']))
gnn_module = DPGN(config['num_generation'], # 6
train_opt['dropout'], # 0.1
train_opt['num_ways'] * train_opt['num_shots'], # 5
train_opt['num_ways'] * train_opt['num_shots'] + train_opt['num_ways'] * train_opt['num_queries'],
# 10
train_opt['loss_indicator']) # [1, 1, 0]
# multi-gpu configuration
[print('GPU: {} Spec: {}'.format(i, torch.cuda.get_device_name(i))) for i in range(args_opt.num_gpu)]
# single GPU, so this branch is not executed
if args_opt.num_gpu > 1:
print('Construct multi-gpu model ...')
enc_module = nn.DataParallel(enc_module, device_ids=range(args_opt.num_gpu), dim=0)
gnn_module = nn.DataParallel(gnn_module, device_ids=range(args_opt.num_gpu), dim=0)
print('done!\n')
if not os.path.exists(os.path.join(args_opt.checkpoint_dir, args_opt.exp_name)):
os.makedirs(os.path.join(args_opt.checkpoint_dir, args_opt.exp_name))
logger.info('no checkpoint for model: {}, make a new one at {}'.format(
args_opt.exp_name,
os.path.join(args_opt.checkpoint_dir, args_opt.exp_name)))
best_step = 0
else:  # load an existing checkpoint so training can resume from where it stopped
if not os.path.exists(os.path.join(args_opt.checkpoint_dir, args_opt.exp_name, 'model_best.pth.tar')):
best_step = 0
else:
logger.info('find a checkpoint, loading checkpoint from {}'.format(
os.path.join(args_opt.checkpoint_dir, args_opt.exp_name)))
best_checkpoint = torch.load(os.path.join(args_opt.checkpoint_dir, args_opt.exp_name, 'model_best.pth.tar'))
logger.info('best model pack loaded')
best_step = best_checkpoint['iteration']
enc_module.load_state_dict(best_checkpoint['enc_module_state_dict'])
gnn_module.load_state_dict(best_checkpoint['gnn_module_state_dict'])
logger.info('current best test accuracy is: {}, at step: {}'.format(best_checkpoint['test_acc'], best_step))
dataset_train = dataset(root=args_opt.dataset_root, partition='train')
dataset_valid = dataset(root=args_opt.dataset_root, partition='val')
dataset_test = dataset(root=args_opt.dataset_root, partition='test')
# see dataloader.py line 278
train_loader = DataLoader(dataset_train,
num_tasks=train_opt['batch_size'], # 25
num_ways=train_opt['num_ways'], # 5
num_shots=train_opt['num_shots'], # 1
num_queries=train_opt['num_queries'], # 1
epoch_size=train_opt['iteration']) # 100000
# print('1', train_loader.get_task_batch()[2])
# print('2', train_loader.get_task_batch()[2].shape)
# print('3', train_loader.get_task_batch()[3])
# print('4', train_loader.get_task_batch()[3].shape)
# # exit()
#
# support_label = train_loader.get_task_batch()[1]
# query_label = train_loader.get_task_batch()[3]
# all_label = torch.cat([support_label, query_label], 1)
# label_i = support_label.unsqueeze(-1).repeat(1, 1, 5) # [25,5,5]
# print('label_i ', label_i)
# label_j = label_i.transpose(1, 2) # [25,5,5]
# print('label_j', label_j)
# # compute edge
# edge = torch.eq(label_i, label_j).float()
# print('edge ', edge)
# exit()
valid_loader = DataLoader(dataset_valid,
num_tasks=eval_opt['batch_size'],
num_ways=eval_opt['num_ways'],
num_shots=eval_opt['num_shots'],
num_queries=eval_opt['num_queries'],
epoch_size=eval_opt['iteration'])
test_loader = DataLoader(dataset_test,
num_tasks=eval_opt['batch_size'],
num_ways=eval_opt['num_ways'],
num_shots=eval_opt['num_shots'],
num_queries=eval_opt['num_queries'],
epoch_size=eval_opt['iteration'])
data_loader = {'train': train_loader,
'val': valid_loader,
'test': test_loader}
# each batch is a list of 4 elements: the 1st is a tensor of shape [1,25,5,3,84,84], the 2nd a tensor of shape [1,25,5]
# the 3rd is a tensor of shape [1,25,5,3,84,84], the 4th a tensor of shape [1,25,5]
# for iteration, batch in enumerate(data_loader['train']()):
# iteration = iteration
# print(iteration)
# exit()
# create trainer
trainer = DPGNTrainer(enc_module=enc_module, # resnet12
gnn_module=gnn_module, # DPGN
data_loader=data_loader,
log=logger,
arg=args_opt,
config=config,
best_step=best_step)
if args_opt.mode == 'train':
trainer.train()
elif args_opt.mode == 'eval':
trainer.eval()
else:
print('select a mode')
exit()
if __name__ == '__main__':
main()
import os
import logging
import torch
import shutil
def allocate_tensors():
"""
init data tensors
:return: data tensors
"""
tensors = dict()
tensors['support_data'] = torch.FloatTensor()
tensors['support_label'] = torch.LongTensor()
tensors['query_data'] = torch.FloatTensor()
tensors['query_label'] = torch.LongTensor()
return tensors
def set_tensors(tensors, batch):
"""
set data to initialized tensors
:param tensors: initialized data tensors
:param batch: current batch of data
:return: None
"""
support_data, support_label, query_data, query_label = batch
tensors['support_data'].resize_(support_data.size()).copy_(support_data)
tensors['support_label'].resize_(support_label.size()).copy_(support_label)
tensors['query_data'].resize_(query_data.size()).copy_(query_data)
tensors['query_label'].resize_(query_label.size()).copy_(query_label)
def set_logging_config(logdir):
"""
set logging configuration
:param logdir: directory to store the logs
:return: None
"""
if not os.path.exists(logdir):
os.makedirs(logdir)
logging.basicConfig(format="[%(asctime)s] [%(name)s] %(message)s",
level=logging.INFO,
handlers=[logging.FileHandler(os.path.join(logdir, 'log.txt')),
logging.StreamHandler(os.sys.stdout)])
def save_checkpoint(state, is_best, exp_name):
"""
save the checkpoint during training stage
:param state: content to be saved
:param is_best: if DPGN model's performance is the best at current step
:param exp_name: experiment name
:return: None
"""
torch.save(state, os.path.join('{}'.format(exp_name), 'checkpoint.pth.tar'))
if is_best:
shutil.copyfile(os.path.join('{}'.format(exp_name), 'checkpoint.pth.tar'),
os.path.join('{}'.format(exp_name), 'model_best.pth.tar'))
def adjust_learning_rate(optimizers, lr, iteration, dec_lr_step):
"""
adjust learning rate after some iterations
:param optimizers: the optimizers
:param lr: learning rate
:param iteration: current iteration
:param dec_lr_step: number of iterations after which the learning rate is decreased
:return: None
"""
new_lr = lr * (0.1 ** (int(iteration / dec_lr_step)))
for optimizer in optimizers:
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
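# Decay sketch (the numbers are illustrative, not taken from the config): with lr = 1e-3 and
# dec_lr_step = 15000, iterations 0-14999 use 1e-3, 15000-29999 use 1e-4, 30000-44999 use 1e-5,
# i.e. the learning rate is divided by 10 every dec_lr_step iterations.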
def label2edge(label, device):
"""
convert ground truth labels into ground truth edges
:param label: ground truth labels
:param device: the gpu device that holds the ground truth edges
:return: ground truth edges
"""
# get size
num_samples = label.size(1)  # e.g. 5 when label has shape [25,5]
# reshape
label_i = label.unsqueeze(-1).repeat(1, 1, num_samples) #[25,5,5]
label_j = label_i.transpose(1, 2) #[25,5,5]
# compute edge
edge = torch.eq(label_i, label_j).float().to(device)
return edge
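# A minimal worked example (values assumed for illustration):
#   label = torch.tensor([[0, 1, 0]])      # shape [1, 3]
#   label2edge(label, 'cpu')
#   -> tensor([[[1., 0., 1.],
#               [0., 1., 0.],
#               [1., 0., 1.]]])            # edge[b, i, j] = 1 iff samples i and j share a label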
def one_hot_encode(num_classes, class_idx, device):
"""
one-hot encode the ground truth
:param num_classes: total number of classes
:param class_idx: index of the class each sample belongs to
:param device: the gpu device that holds the one-hot encoded ground truth label
:return: one-hot encoded ground truth label
"""
return torch.eye(num_classes)[class_idx].to(device)
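# e.g. one_hot_encode(5, torch.tensor([[0, 2]]), 'cpu') returns a [1, 2, 5] tensor:
#   tensor([[[1., 0., 0., 0., 0.],
#            [0., 0., 1., 0., 0.]]])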
def preprocessing(num_ways, num_shots, num_queries, batch_size, device):
"""
prepare for train and evaluation
:param num_ways: number of classes for each few-shot task
:param num_shots: number of samples for each class in few-shot task
:param num_queries: number of queries for each class in few-shot task
:param batch_size: how many tasks per batch
:param device: the gpu device that holds all data
:return: number of samples in support set
number of total samples (support and query set)
mask for edges connecting query nodes
mask for unlabeled data (for semi-supervised setting)
"""
# set size of support set, query set and total number of data in single task
num_supports = num_ways * num_shots
num_samples = num_supports + num_queries * num_ways
# the masks are set up in the same way as in EGNN
# set edge mask (to distinguish support and query edges)
support_edge_mask = torch.zeros(batch_size, num_samples, num_samples).to(device) #[25,10,10]
support_edge_mask[:, :num_supports, :num_supports] = 1
query_edge_mask = 1 - support_edge_mask
evaluation_mask = torch.ones(batch_size, num_samples, num_samples).to(device)
return num_supports, num_samples, query_edge_mask, evaluation_mask
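# Worked example for the 5-way 1-shot, 1-query, batch_size 25 setting used in this file:
# num_supports = 5, num_samples = 10; query_edge_mask has shape [25, 10, 10] with zeros in the
# top-left 5x5 (support-support) block and ones everywhere else; evaluation_mask is an all-ones
# [25, 10, 10] tensor.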
# 5 25 1 5
def initialize_nodes_edges(batch, num_supports, tensors, batch_size, num_queries, num_ways, device):
"""
:param batch: data batch
:param num_supports: number of samples in support set
:param tensors: initialized tensors for holding data
:param batch_size: how many tasks per batch
:param num_queries: number of samples in query set
:param num_ways: number of classes for each few-shot task
:param device: the gpu device that holds all data
:return: data of support set,
label of support set,
data of query set,
label of query set,
data of support and query set,
label of support and query set,
initialized node features of distribution graph (Vd_(0)),
initialized edge features of point graph (Ep_(0)),
initialized edge_features_of distribution graph (Ed_(0))
"""
# allocate data in this batch to specific variables
set_tensors(tensors, batch)
support_data = tensors['support_data'].squeeze(0)  # drop the leading dimension of size 1
support_label = tensors['support_label'].squeeze(0)
query_data = tensors['query_data'].squeeze(0)
query_label = tensors['query_label'].squeeze(0)
# initialize nodes of distribution graph
node_gd_init_support = label2edge(support_label, device)  # shape [25,5,5], diagonal is 1; see main.py 582~589
node_gd_init_query = (torch.ones([batch_size, num_queries * num_ways, num_supports])
* torch.tensor(1. / num_supports)).to(device)
node_feature_gd = torch.cat([node_gd_init_support, node_gd_init_query], dim=1)  # [25,10,5]; the first 5 rows have 1 on the diagonal, the last 5 rows are filled with 0.2
# initialize edges of point graph
all_data = torch.cat([support_data, query_data], 1) #[25,10,3,84,84]
all_label = torch.cat([support_label, query_label], 1) #[25,10]
all_label_in_edge = label2edge(all_label, device) #[25,10,10],main.py 586
edge_feature_gp = all_label_in_edge.clone()  # y = x.clone() makes y a copy of x, so modifying y does not change x
# uniform initialization for point graph's edges
edge_feature_gp[:, num_supports:, :num_supports] = 1. / num_supports  # first 5 columns of the last 5 rows are 0.2
edge_feature_gp[:, :num_supports, num_supports:] = 1. / num_supports  # last 5 columns of the first 5 rows are 0.2
edge_feature_gp[:, num_supports:, num_supports:] = 0  # last 5 columns of the last 5 rows are 0
for i in range(num_ways * num_queries):
edge_feature_gp[:, num_supports + i, num_supports + i] = 1  # set the query self-edges to 1
# initialize edges of distribution graph (same as point graph)
edge_feature_gd = edge_feature_gp.clone()  # [25,10,10]
return support_data, support_label, query_data, query_label, all_data, all_label_in_edge, \
node_feature_gd, edge_feature_gp, edge_feature_gd
def backbone_two_stage_initialization(full_data, encoder):
"""
encode raw data by backbone network
:param full_data: raw data
:param encoder: backbone network
:return: last layer logits from backbone network
second last layer logits from backbone network
"""
# encode data
last_layer_data_temp = []
second_last_layer_data_temp = []
for data in full_data.chunk(full_data.size(1), dim=1):  # split dim=1 into full_data.size(1) chunks
# the encode step  # splits [25,10,3,84,84] into 10 chunks of [25,1,3,84,84]; each chunk holds the 25 samples at the same position of every task
encoded_result = encoder(data.squeeze(1))  # returns two tensors, each of shape [25,128]
# prepare for two stage initialization of DPGN
last_layer_data_temp.append(encoded_result[0])
second_last_layer_data_temp.append(encoded_result[1])
# last_layer_data: (batch_size, num_samples, embedding dimension)
last_layer_data = torch.stack(last_layer_data_temp, dim=1)
# second_last_layer_data: (batch_size, num_samples, embedding dimension)
second_last_layer_data = torch.stack(second_last_layer_data_temp, dim=1)
return last_layer_data, second_last_layer_data
# [25,10,128] [25,10,128]