Pytorch|YOWO原理及代码详解(三)
本博客承接《Pytorch|YOWO原理及代码详解(一)》和《Pytorch|YOWO原理及代码详解(二)》，建议先阅读这两篇。
1. test分析
# Entry point of the run: when opt.evaluate is set, do a single validation
# pass (reported as epoch 0); otherwise alternate training and validation for
# every epoch in [opt.begin_epoch, opt.end_epoch], tracking the best F-score
# and handing a checkpoint to save_checkpoint after each epoch.
if not opt.evaluate:
    for epoch in range(opt.begin_epoch, opt.end_epoch + 1):
        train(epoch)  # one training pass
        fscore = test(epoch)  # validation F-score for this epoch

        is_best = fscore > best_fscore
        if is_best:
            print("New best fscore is achieved: ", fscore)
            print("Previous fscore was: ", best_fscore)
            best_fscore = fscore

        # Everything needed to resume training; is_best is forwarded so
        # save_checkpoint can treat the best-scoring epoch specially.
        checkpoint = dict(
            epoch=epoch,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
            fscore=fscore,
        )
        save_checkpoint(checkpoint, is_best, backupdir, opt.dataset, clip_duration)
        logging('Weights are saved to backup directory: %s' % (backupdir))
else:
    logging('evaluating ...')
    test(0)
上一节已将train的完整流程分析完毕，本节主要分析test流程，即 fscore = test(epoch) 这一调用。单步进入(step into)该函数，其完整代码如下：
def test(epoch):
def truths_length(truths, max_truths=50):
    """Return how many leading ground-truth rows are valid.

    The caller builds ``truths`` with ``target[i].view(-1, 5)``. Rows are
    scanned in order, and the index of the first row whose column 1 equals 0
    is returned: that row is treated as the start of the padding.
    (Presumably padding rows are all zeros and column 1 is the box
    x-coordinate; TODO confirm against the dataset loader.)

    Args:
        truths: 2-D indexable of ground-truth rows (tensor or nested list).
        max_truths: maximum number of rows to scan (formerly hard-coded 50).

    Returns:
        Index of the first padding row. If no padding row exists among the
        scanned rows, the number of rows scanned. The previous version fell
        off the loop and returned None, which made ``total + num_gts`` raise
        TypeError. It also raised IndexError for inputs with fewer than 50
        rows.
    """
    # Never index past the actual number of rows.
    limit = min(max_truths, len(truths))
    for i in range(limit):
        if truths[i][1] == 0:
            return i
    # Every scanned row is a real ground truth.
    return limit
test_loader = torch.utils.data.DataLoader(
dataset.listDataset(basepath, testlist, dataset_use=dataset_use, shape=(init_width, init_height),
shuffle=False,
transform=transforms.Compose([
transforms.ToTensor()
]), train=False),
batch_size=batch_size, shuffle=False, **kwargs)
num_classes = region_loss.num_classes
anchors = region_loss.anchors
num_anchors = region_loss.num_anchors
conf_thresh_valid = 0.005
total = 0.0
proposals = 0.0
correct = 0.0
fscore = 0.0
correct_classification = 0.0
total_detected = 0.0
nbatch = file_lines(testlist) // batch_size
logging('validation at epoch %d' % (epoch))
model.eval()
for batch_idx, (frame_idx, data, target) in enumerate(test_loader):
if use_cuda:
data = data.cuda()
with torch.no_grad():
output = model(data).data
all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)
for i in range(output.size(0)):
boxes = all_boxes[i]
boxes = nms(boxes, nms_thresh)
if dataset_use == 'ucf101-24':
detection_path = os.path.join('ucf_detections', 'detections_' + str(epoch), frame_idx[i])
current_dir = os.path.join('ucf_detections', 'detections_' + str(epoch))
if not os.path.exists('ucf_detections'):
os.makedirs(current_dir)
if not os.path.exists(current_dir):
os.makedirs(current_dir)
else:
detection_path = os.path.join('jhmdb_detections', 'detections_' + str(epoch), frame_idx[i])
current_dir = os.path.join('jhmdb_detections', 'detections_' + str(epoch))
if not os.path.exists('jhmdb_detections'):
os.mkdir(current_dir)
if not os.path.exists(current_dir):
os.mkdir(current_dir)
with open(detection_path, 'w+') as f_detect:
for box in boxes:
x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
y2 = round(float(box[1] + box[3] / 2.0) * 240.0)
det_conf = float(box[4])
for j in range((len(box) - 5) // 2):
cls_conf = float(box[5 + 2 * j].item())
if type(box[6 + 2 * j]) == torch.Tensor:
cls_id = int(box[6 + 2 * j].item())
else:
cls_id = int(box[6 + 2 * j])
prob = det_conf * cls_conf
f_detect.write(
str(int(box[6]) + 1) + ' ' + str(prob) + ' ' + str(x1) + ' ' + str(y1) + ' ' + str(
x2) + ' ' + str(y2) + '\n')
truths = target[i].view(-1, 5)
num_gts = truths_length(truths)
total = total + num_gts
for i in range(len(boxes)):
if boxes[i][4] > 0.25:
proposals = proposals + 1
for i in range(num_gts):
box_gt = [truths[i][1], truths[i][2], truths[i][3], truths[i][4], 1.0, 1.0, truths[i][0]]
best_iou = 0
best_j = -1
for j in range(len(boxes)):
iou = bbox_iou(box_gt, boxes[j], x1y1x2y2=False)
if iou > best_iou:
best_j