Visualization Code for Video-Based Person Re-Identification (Complete Implementation)
Preface
There is currently little publicly available visualization code for video-based person re-identification. The code here is adapted from another blogger's article, and the dataset used throughout is the Mars dataset.
I. Usage Steps
1. Import the libraries
The code is as follows (example):
from __future__ import print_function, absolute_import
import scipy.io
import os
from torch.utils.data import DataLoader
import torchvision.transforms as T
from utils import data_manager
from utils.video_loader import VideoDataset, VideoDatasetInfer
import argparse
import numpy as np
import torch
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
2. Load the data
The code is as follows (example):
# data loader
def data_loader():
    dataset = data_manager.init_dataset(name=args.dataset, root=args.root)
    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    pin_memory = True if use_gpu else False
    queryloader = DataLoader(
        VideoDatasetInfer(dataset.query, seq_len=args.seq_len, temporal_sampler='restricted', spatial_transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )
    galleryloader = DataLoader(
        VideoDatasetInfer(dataset.gallery, seq_len=args.seq_len, temporal_sampler='restricted', spatial_transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )
    return dataset, queryloader, galleryloader
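As a quick sanity check, the loaders can be exercised as below. This is only a minimal sketch; the exact batch layout yielded by VideoDatasetInfer is an assumption on my part, not something stated in the original article:

# hypothetical smoke test for the loaders (batch layout assumed)
dataset, queryloader, galleryloader = data_loader()
for batch in queryloader:
    imgs = batch[0]  # assumed shape: (test_batch, seq_len, 3, height, width)
    print(imgs.shape)
    break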
3. Save the test results
Save the test results so that the rank list can be visualized later. This snippet is inserted into the test part of the code; it saves the query and gallery features together with the labels and camera IDs, producing query_feature.pt, gallery_feature.pt, and save_result.mat.
The example code is as follows:
# Note: this function lives inside the test script, which already provides time,
# io (from scipy import io), evaluate, print_time, _feats_of_loader, _cal_dist,
# _eval_format_logger, and the two feature extractors used below.
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]):
    since = time.time()
    model.eval()
    if args.all_frames:
        feat_func = extract_feat_all_frames
    else:
        feat_func = extract_feat_sampled_frames
    qf, q_pids, q_camids = _feats_of_loader(
        model,
        queryloader,
        feat_func,
        use_gpu=use_gpu)
    print_time("Extracted features for query set, obtained {} matrix".format(qf.shape))
    gf, g_pids, g_camids = _feats_of_loader(
        model,
        galleryloader,
        feat_func,
        use_gpu=use_gpu)
    # save features and metadata for the later rank-list visualization
    torch.save(qf, 'query_feature.pt')
    torch.save(gf, 'gallery_feature.pt')
    io.savemat('save_result.mat', {'query_cam': q_camids,
                                   'query_label': q_pids,
                                   'gallery_cam': g_camids,
                                   'gallery_label': g_pids,
                                   })
    print_time("Extracted features for gallery set, obtained {} matrix".format(gf.shape))
    if args.dataset == 'mars':
        # gallery set must contain the query set, otherwise 140 query imgs will not have ground truth
        gf = torch.cat((qf, gf), 0)
        g_pids = np.append(q_pids, g_pids)
        g_camids = np.append(q_camids, g_camids)
    time_elapsed = time.time() - since
    print_time('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print_time("Computing distance matrix")
    distmat = _cal_dist(qf=qf, gf=gf, distance=args.distance)
    print_time("Computing CMC and mAP")
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    _eval_format_logger(cmc, mAP, ranks, '')
    return cmc[0]
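The helper _cal_dist is defined elsewhere in the test script and is not shown in the original article. Purely for reference, here is a minimal sketch of what it could look like; the name and signature follow the call above, but the body is my assumption:

import torch
import torch.nn.functional as F

def _cal_dist(qf, gf, distance='cosine'):
    # qf: (m, d) query features; gf: (n, d) gallery features
    if distance == 'cosine':
        qf = F.normalize(qf, p=2, dim=1)
        gf = F.normalize(gf, p=2, dim=1)
        return 1 - torch.mm(qf, gf.t())  # smaller means more similar
    return torch.cdist(qf, gf, p=2)  # Euclidean distance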
4. Load the dataset and the saved feature files
The example code is as follows:
image_datasets = data_loader()[0]
# print(image_datasets)
result = scipy.io.loadmat('/media/lele/e/zzg/sinet/save_result.mat')
query_feature = torch.load('/media/lele/e/zzg/sinet/query_feature.pt')
gallery_feature = torch.load('/media/lele/e/zzg/sinet/gallery_feature.pt')
query_cam = result['query_cam'][0]
query_label = result['query_label'][0]
gallery_cam = result['gallery_cam'][0]
gallery_label = result['gallery_label'][0]
query_feature = query_feature.cuda()
gallery_feature = gallery_feature.cuda()
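The two .cuda() calls above assume a GPU is available. If the script may also run on CPU, a device-agnostic variant (my addition, not part of the original code) would be:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
query_feature = query_feature.to(device)
gallery_feature = gallery_feature.to(device)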
5. Visualize the rank list
The example code is as follows:
# sort the images
def sort_img(qf, ql, qc, gf, gl, gc):
    query = qf.view(-1, 1)
    # print(query.shape)
    score = torch.mm(gf, query)
    score = score.squeeze(1).cpu()
    score = score.numpy()
    # predict index
    index = np.argsort(score)  # from small to large
    index = index[::-1]
    # index = index[0:2000]
    # good index
    query_index = np.argwhere(gl == ql)
    # same camera
    camera_index = np.argwhere(gc == qc)
    # good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
    junk_index1 = np.argwhere(gl == -1)
    junk_index2 = np.intersect1d(query_index, camera_index)
    junk_index = np.append(junk_index2, junk_index1)
    mask = np.in1d(index, junk_index, invert=True)
    index = index[mask]
    return index
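# Note (added): on Mars, gallery tracklets labeled -1 are distractors and are
# treated as junk, and same-identity matches from the same camera as the query
# are also excluded, following the standard re-ID evaluation protocol. Since
# torch.mm(gf, query) is a plain inner product, this ranking implicitly assumes
# the saved features are L2-normalized (inner product == cosine similarity).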
i = args.query_index
index = sort_img(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label, gallery_cam)
########################################################################
# Visualize the rank result
img_paths, pid, camid = image_datasets.query[i]
# img_paths, _ = image_datasets['query'].imgs[i]
print(img_paths)
query_label = query_label[i]
fig = plt.figure(figsize=(16, 4))
# query tracklet: two rows of three frames on the left
for i in range(2):
    for j in range(3):
        ax = plt.subplot(2, 15, (1 if i == 0 else 16) + j)
        ax.axis('off')
        imshow(img_paths[(0 if i == 0 else 2) + j], 'query')
# top-4 ranked gallery tracklets; to show other ranks, replace range(4) with the desired list
for i in range(4):
    for j in range(2):
        for k in range(3):
            ax = plt.subplot(2, 15, (1 if j == 0 else 16) + i * 3 + k + 3)
            ax.axis('off')
            img_path, pid, camid = image_datasets.gallery[index[i]]
            label = gallery_label[index[i]]
            imshow(img_path[(0 if j == 0 else 2) + k])
            if label == query_label:
                ax.set_title('%d' % (i + 1), color='green')  # correct match
            else:
                ax.set_title('%d' % (i + 1), color='red')  # wrong match
fig.savefig("/media/lele/e/zzg/sinet/show.png")
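Running the script writes show.png: the two rows of three images on the left are frames of the query tracklet, and the columns to their right are frames of the top-4 ranked gallery tracklets. The rank number above each column is drawn in green for a correct match and in red for a wrong one.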
II. Complete Code for Visualizing Video-Based Person Re-Identification Results
The example code is as follows:
from __future__ import print_function, absolute_import
import scipy.io
import os
from torch.utils.data import DataLoader
import torchvision.transforms as T
from utils import data_manager
from utils.video_loader import VideoDataset, VideoDatasetInfer
import argparse
import numpy as np
import torch
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
#######################################################################
# Evaluate
parser = argparse.ArgumentParser(description='Demo')
parser.add_argument('--query_index', default=61, type=int, help='test_image_index')
# parser.add_argument('--test_dir',default='/media/lele/c/zuozhigang/mars/bbox_test',type=str, help='./test_data')
parser.add_argument('--root', type=str, default='/media/lele/c/zuozhigang/mars/')
parser.add_argument('-d', '--dataset', type=str, default='mars',
                    choices=data_manager.get_names())
parser.add_argument('-j', '--workers', default=0, type=int,
                    help="number of data loading workers (default: 0)")
parser.add_argument('--height', type=int, default=256,
                    help="height of an image (default: 256)")
parser.add_argument('--width', type=int, default=128,
                    help="width of an image (default: 128)")
# Augment
parser.add_argument('--sample_stride', type=int, default=8, help="stride of images to sample in a tracklet")
# Optimization options
parser.add_argument('--max_epoch', default=160, type=int,
                    help="maximum epochs to run")
parser.add_argument('--start_epoch', default=0, type=int,
                    help="manual epoch number (useful on restarts)")
parser.add_argument('--train_batch', default=32, type=int,
                    help="train batch size")
parser.add_argument('--test_batch', default=32, type=int, help="test batch size")
parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
                    help="initial learning rate, use 0.0001 for rnn, use 0.0003 for pooling and attention")
parser.add_argument('--stepsize', default=[40, 80, 120], nargs='+', type=int,
                    help="stepsize to decay learning rate")
parser.add_argument('--gamma', default=0.1, type=float,
                    help="learning rate decay")
parser.add_argument('--weight_decay', default=5e-04, type=float,
                    help="weight decay (default: 5e-04)")
parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss")
parser.add_argument('--distance', type=str, default='cosine', help="euclidean or cosine")
parser.add_argument('--num_instances', type=int, default=4, help="number of instances per identity")
parser.add_argument('--losses', default=['xent', 'htri'], nargs='+', type=str, help="losses")
# Architecture
parser.add_argument('-a', '--arch', type=str, default='sinet', help="c2resnet50, nonlocalresnet50")
parser.add_argument('--pretrain', action='store_true', help="load params from a model pretrained on Kinetics")
parser.add_argument('--pretrain_model_path', type=str, default='', metavar='PATH')
# Miscs
parser.add_argument('--seed', type=int, default=1, help="manual seed")
parser.add_argument('--resume', type=str, default='/media/lele/e/zzg/sinet/logs/best_model.pth.tar', metavar='PATH')
parser.add_argument('--evaluate', action='store_true', help="evaluation only")
parser.add_argument('--eval_step', type=int, default=10,
                    help="run evaluation for every N epochs (set to -1 to test after training)")
parser.add_argument('--start_eval', type=int, default=0, help="start to evaluate after specific epoch")
parser.add_argument('--save_dir', '--sd', type=str, default='/media/lele/e/zzg/sinet/logs/')
parser.add_argument('--use_cpu', action='store_true', help="use cpu")
parser.add_argument('--gpu_devices', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')
parser.add_argument('--all_frames', action='store_true', help="evaluate with all frames?")
parser.add_argument('--seq_len', type=int, default=4,
                    help="number of images to sample in a tracklet")
parser.add_argument('--note', type=str, default='', help='additional description of this command')
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
use_gpu = torch.cuda.is_available()
if args.use_cpu: use_gpu = False
# data loader
def data_loader():
    dataset = data_manager.init_dataset(name=args.dataset, root=args.root)
    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    pin_memory = True if use_gpu else False
    queryloader = DataLoader(
        VideoDatasetInfer(dataset.query, seq_len=args.seq_len, temporal_sampler='restricted', spatial_transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )
    galleryloader = DataLoader(
        VideoDatasetInfer(dataset.gallery, seq_len=args.seq_len, temporal_sampler='restricted', spatial_transform=transform_test),
        batch_size=args.test_batch, shuffle=False, num_workers=args.workers,
        pin_memory=pin_memory, drop_last=False,
    )
    return dataset, queryloader, galleryloader
# data_dir = opts.test_dir
# image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ) for x in ['gallery','query']}
#####################################################################
# Show result
def imshow(path, title=None):
    """Imshow for Tensor."""
    im = plt.imread(path)
    plt.imshow(im)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
######################################################################
image_datasets = data_loader()[0]
# print(image_datasets)
result = scipy.io.loadmat('/media/lele/e/zzg/sinet/save_result.mat')
query_feature = torch.load('/media/lele/e/zzg/sinet/query_feature.pt')
gallery_feature = torch.load('/media/lele/e/zzg/sinet/gallery_feature.pt')
query_cam = result['query_cam'][0]
query_label = result['query_label'][0]
gallery_cam = result['gallery_cam'][0]
gallery_label = result['gallery_label'][0]
query_feature = query_feature.cuda()
gallery_feature = gallery_feature.cuda()
#######################################################################
# sort the images
def sort_img(qf, ql, qc, gf, gl, gc):
    query = qf.view(-1, 1)
    # print(query.shape)
    score = torch.mm(gf, query)
    score = score.squeeze(1).cpu()
    score = score.numpy()
    # predict index
    index = np.argsort(score)  # from small to large
    index = index[::-1]
    # index = index[0:2000]
    # good index
    query_index = np.argwhere(gl == ql)
    # same camera
    camera_index = np.argwhere(gc == qc)
    # good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
    junk_index1 = np.argwhere(gl == -1)
    junk_index2 = np.intersect1d(query_index, camera_index)
    junk_index = np.append(junk_index2, junk_index1)
    mask = np.in1d(index, junk_index, invert=True)
    index = index[mask]
    return index
i = args.query_index
index = sort_img(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label, gallery_cam)
########################################################################
# Visualize the rank result
img_paths, pid, camid = image_datasets.query[i]
# img_paths, _ = image_datasets['query'].imgs[i]
print(img_paths)
query_label = query_label[i]
fig = plt.figure(figsize=(16, 4))
# query tracklet: two rows of three frames on the left
for i in range(2):
    for j in range(3):
        ax = plt.subplot(2, 15, (1 if i == 0 else 16) + j)
        ax.axis('off')
        imshow(img_paths[(0 if i == 0 else 2) + j], 'query')
# top-4 ranked gallery tracklets; to show other ranks, replace range(4) with the desired list
for i in range(4):
    for j in range(2):
        for k in range(3):
            ax = plt.subplot(2, 15, (1 if j == 0 else 16) + i * 3 + k + 3)
            ax.axis('off')
            img_path, pid, camid = image_datasets.gallery[index[i]]
            label = gallery_label[index[i]]
            imshow(img_path[(0 if j == 0 else 2) + k])
            if label == query_label:
                ax.set_title('%d' % (i + 1), color='green')  # correct match
            else:
                ax.set_title('%d' % (i + 1), color='red')  # wrong match
fig.savefig("/media/lele/e/zzg/sinet/show.png")
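To reproduce the figure, save the complete script under any name, e.g. visualize_rank.py (the filename is my choice, not the author's), adjust --root and the hard-coded .pt/.mat/show.png paths to your own environment, and run:

python visualize_rank.py --query_index 61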