Pytorch|YOWO原理及代码详解(三)

Pytorch|YOWO原理及代码详解(三)

本博客承接以下两篇文章,阅读本文前建议先阅读:
Pytorch|YOWO原理及代码详解(一)
Pytorch|YOWO原理及代码详解(二)

1. test分析

    # Entry point of the script: either run a single evaluation pass or the
    # full train/validate/checkpoint loop, depending on opt.evaluate.
    if opt.evaluate:
        # Evaluation-only mode: no training; test() is called once with 0 as
        # its epoch argument.
        logging('evaluating ...')
        test(0)
    else:
        # The range is inclusive of opt.end_epoch (note the "+ 1").
        for epoch in range(opt.begin_epoch, opt.end_epoch + 1):
            # Train the model for 1 epoch
            train(epoch)

            # Validate the model
            fscore = test(epoch)

            # Track the best validation fscore seen so far; best_fscore is
            # only overwritten when the new score is strictly greater.
            is_best = fscore > best_fscore
            if is_best:
                print("New best fscore is achieved: ", fscore)
                print("Previous fscore was: ", best_fscore)
                best_fscore = fscore

            # Save the model to backup directory
            # The checkpoint bundles the epoch number, the model and optimizer
            # state dicts, and this epoch's fscore (not necessarily the best).
            state = {
   
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'fscore': fscore
            }
            # A checkpoint is written every epoch; is_best is forwarded to
            # save_checkpoint -- presumably so it can also keep a separate
            # "best" copy (TODO confirm against save_checkpoint).
            save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration)
            logging('Weights are saved to backup directory: %s' % (backupdir))

上一节已将 train 的整个流程分析完毕,本节主要分析 test 流程。从 fscore = test(epoch) 处进入(step into)函数,完整代码如下:

    def test(epoch):
        def truths_length(truths):
            """Return the number of leading valid ground-truth rows in ``truths``.

            Rows are scanned in order and the index of the first row whose
            element at position 1 equals 0 is returned. At most 50 rows are
            examined, and never more rows than ``truths`` actually has.

            If no such row is found, the number of rows examined is returned.
            The caller therefore always receives an integer it can add to a
            running total, instead of ``None`` (all 50 rows filled) or an
            ``IndexError`` (fewer than 50 rows).

            NOTE(review): assumes unused rows are zero-padded, so a zero at
            position 1 marks the end of the real annotations -- confirm
            against the dataset loader.
            """
            max_rows = min(50, len(truths))
            for i in range(max_rows):
                if truths[i][1] == 0:
                    return i
            # Every examined row was a real annotation.
            return max_rows

        test_loader = torch.utils.data.DataLoader(
            dataset.listDataset(basepath, testlist, dataset_use=dataset_use, shape=(init_width, init_height),
                                shuffle=False,
                                transform=transforms.Compose([
                                    transforms.ToTensor()
                                ]), train=False),
            batch_size=batch_size, shuffle=False, **kwargs)

        num_classes = region_loss.num_classes
        anchors = region_loss.anchors
        num_anchors = region_loss.num_anchors
        conf_thresh_valid = 0.005
        total = 0.0
        proposals = 0.0
        correct = 0.0
        fscore = 0.0

        correct_classification = 0.0
        total_detected = 0.0

        nbatch = file_lines(testlist) // batch_size

        logging('validation at epoch %d' % (epoch))
        model.eval()

        for batch_idx, (frame_idx, data, target) in enumerate(test_loader):
            if use_cuda:
                data = data.cuda()
            with torch.no_grad():
                output = model(data).data
                all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)
                for i in range(output.size(0)):
                    boxes = all_boxes[i]
                    boxes = nms(boxes, nms_thresh)
                    if dataset_use == 'ucf101-24':
                        detection_path = os.path.join('ucf_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('ucf_detections', 'detections_' + str(epoch))
                        if not os.path.exists('ucf_detections'):
                            os.makedirs(current_dir)
                        if not os.path.exists(current_dir):
                            os.makedirs(current_dir)
                    else:
                        detection_path = os.path.join('jhmdb_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('jhmdb_detections', 'detections_' + str(epoch))
                        if not os.path.exists('jhmdb_detections'):
                            os.mkdir(current_dir)
                        if not os.path.exists(current_dir):
                            os.mkdir(current_dir)

                    with open(detection_path, 'w+') as f_detect:
                        for box in boxes:
                            x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
                            y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
                            x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
                            y2 = round(float(box[1] + box[3] / 2.0) * 240.0)

                            det_conf = float(box[4])
                            for j in range((len(box) - 5) // 2):
                                cls_conf = float(box[5 + 2 * j].item())

                                if type(box[6 + 2 * j]) == torch.Tensor:
                                    cls_id = int(box[6 + 2 * j].item())
                                else:
                                    cls_id = int(box[6 + 2 * j])
                                prob = det_conf * cls_conf

                                f_detect.write(
                                    str(int(box[6]) + 1) + ' ' + str(prob) + ' ' + str(x1) + ' ' + str(y1) + ' ' + str(
                                        x2) + ' ' + str(y2) + '\n')
                    truths = target[i].view(-1, 5)
                    num_gts = truths_length(truths)

                    total = total + num_gts

                    for i in range(len(boxes)):
                        if boxes[i][4] > 0.25:
                            proposals = proposals + 1

                    for i in range(num_gts):
                        box_gt = [truths[i][1], truths[i][2], truths[i][3], truths[i][4], 1.0, 1.0, truths[i][0]]
                        best_iou = 0
                        best_j = -1
                        for j in range(len(boxes)):
                            iou = bbox_iou(box_gt, boxes[j], x1y1x2y2=False)
                            if iou > best_iou:
                                best_j 
评论 30
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值