Siammask代码阅读笔记(一)

test_mask_refine.sh运行过程详细分析
博客详细分析了指令的运行过程,从程序入口test_mask_refine.sh开始,介绍其运行/test.py的参数,还讲解了Linux shell中‘2>&1’含义和tee命令。接着阐述了$SiamMask/tools/test.py用parser解析参数,遍历数据集视频调用track_vot()函数。

详细分析如下指令的运行过程

# Evaluate performance on VOT
bash test_mask_refine.sh config_vot.json SiamMask_VOT.pth VOT2018 0

1. 程序入口 test_mask_refine.sh

Evaluate performance on VOT

bash test_mask_refine.sh config_vot.json SiamMask_VOT.pth VOT2016 0

查看文件内容 “$SiamMask/experiments/siammask_sharp/test_mask_refine.sh ”

CUDA_VISIBLE_DEVICES=$gpu python -u ../../tools/test.py \
    --config $config \
    --resume $model \
    --mask --refine \
    --dataset $dataset 2>&1 | tee logs/test_$dataset.log

1, 运行 …/…/tools/test.py
2, 输入到 /test.py 的参数包括:
(1) config file: config_vot.json; (2)mode: SiamMask_VOT.pth; (3)dataset:VOT2016; (4)gpu:0
3,关于本语句的语法。
(a)linux shell中"2>&1"含义
https://www.cnblogs.com/zhenghongxin/p/7029173.html
对于& 1 更准确的说应该是文件描述符 1,而1标识标准输出,stdout。
对于2 ,表示标准错误,stderr。
2>&1 的意思就是将标准错误重定向到标准输出。
(b)为初学者介绍的 Linux tee 命令(6 个例子)
https://linux.cn/article-9435-1.html
有时候,你会想手动跟踪命令的输出内容,同时又想将输出的内容写入文件,确保之后可以用来参考。如果你想寻找这相关的工具,那么恭喜你,Linux 已经有了一个叫做 tee 的命令可以帮助你。
tee 命令:从标准输入中复制到每一个文件,并输出到标准输出。

2. $SiamMask/tools/test.py

def main():
    global args, logger, v_id
    args = parser.parse_args()
    cfg = load_config(args)

使用 parser 解析参数。

参数的默认值如下:

parser = argparse.ArgumentParser(description='Test SiamMask')
parser.add_argument('--arch', dest='arch', default='', choices=['Custom',],
                    help='architecture of pretrained model')
parser.add_argument('--config', dest='config', required=True, help='hyper-parameter for SiamMask')
parser.add_argument('--resume', default='', type=str, required=True,
                    metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--mask', action='store_true', help='whether use mask output')
parser.add_argument('--refine', action='store_true', help='whether use mask refine output')
parser.add_argument('--dataset', dest='dataset', default='VOT2018', choices=dataset_zoo,
                    help='datasets')
parser.add_argument('-l', '--log', default="log_test.txt", type=str, help='log file')
parser.add_argument('-v', '--visualization', dest='visualization', action='store_true',
                    help='whether visualize result')
parser.add_argument('--save_mask', action='store_true', help='whether use save mask for davis')
parser.add_argument('--gt', action='store_true', help='whether use gt rect for davis (Oracle)')
parser.add_argument('--video', default='', type=str, help='test special video')
parser.add_argument('--cpu', action='store_true', help='cpu mode')
parser.add_argument('--debug', action='store_true', help='debug mode')

遍历数据集中的所有视频,对于每个视频调用一次 track_vot() 函数。

    for v_id, video in enumerate(dataset.keys(), start=1):
        if args.video != '' and video != args.video:
            continue

        if vos_enable:
            iou_list, speed = track_vos(model, dataset[video], cfg['hp'] if 'hp' in cfg.keys() else None,
                                 args.mask, args.refine, args.dataset in ['DAVIS2017', 'ytb_vos'], device=device)
            iou_lists.append(iou_list)
        else:
            lost, speed = track_vot(model, dataset[video], cfg['hp'] if 'hp' in cfg.keys() else None,
                             args.mask, args.refine, device=device)
            total_lost += lost
        speed_list.append(speed)

3. track_vot() 函数


def track_vot(model, video, hp=None, mask_enable=False, refine_enable=False, device='cpu'):
    regions = []  # result and states[1 init / 2 lost / 0 skip]
    image_files, gt = video['image_files'], video['gt']

    start_frame, end_frame, lost_times, toc = 0, len(image_files), 0, 0

    # 遍历当前视频中的所有图像
    for f, image_file in enumerate(image_files):
        im = cv2.imread(image_file)
        tic = cv2.getTickCount()
        
        if f == start_frame:  # init
            # 初始化跟踪器
            cx, cy, w, h = get_axis_aligned_bbox(gt[f])
            target_pos = np.array([cx, cy])
            target_sz = np.array([w, h])
            state = siamese_init(im, target_pos, target_sz, model, hp, device)  # init tracker
            location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
            regions.append(1 if 'VOT' in args.dataset else gt[f])
        
        elif f > start_frame:  # tracking
            # 调用函数siamese_track(),对当前图像运行目标跟踪算法
            state = siamese_track(state, im, mask_enable, refine_enable, device, args.debug)  # track
            
            if mask_enable:
                # 启用mask选项的时候,保存当前跟踪结果:location 和 mask
                location = state['ploygon'].flatten()
                mask = state['mask']
            else:
                location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
                mask = []

            if 'VOT' in args.dataset:
                gt_polygon = ((gt[f][0], gt[f][1]), (gt[f][2], gt[f][3]),
                              (gt[f][4], gt[f][5]), (gt[f][6], gt[f][7]))
                if mask_enable:
                    pred_polygon = ((location[0], location[1]), (location[2], location[3]),
                                    (location[4], location[5]), (location[6], location[7]))
                else:
                    pred_polygon = ((location[0], location[1]),
                                    (location[0] + location[2], location[1]),
                                    (location[0] + location[2], location[1] + location[3]),
                                    (location[0], location[1] + location[3]))
                # 判断当前跟踪算法的输出框 和 目标真实位置框 是否有重叠区域。
                b_overlap = vot_overlap(gt_polygon, pred_polygon, (im.shape[1], im.shape[0]))
            else:
                b_overlap = 1

            if b_overlap:
                regions.append(location)
            else:  # lost
                # 目标丢失,需要间隔5帧之后重启跟踪器
                regions.append(2)
                lost_times += 1
                start_frame = f + 5  # skip 5 frames
        else:  # skip
            regions.append(0)
        toc += cv2.getTickCount() - tic

    toc /= cv2.getTickFrequency()

    # save result
    name = args.arch.split('.')[0] + '_' + ('mask_' if mask_enable else '') + ('refine_' if refine_enable else '') +\
           args.resume.split('/')[-1].split('.')[0]

    if 'VOT' in args.dataset:
        video_path = join('test', args.dataset, name,
                          'baseline', video['name'])
        if not isdir(video_path): makedirs(video_path)
        result_path = join(video_path, '{:s}_001.txt'.format(video['name']))
        with open(result_path, "w") as fin:
            for x in regions:
                fin.write("{:d}\n".format(x)) if isinstance(x, int) else \
                        fin.write(','.join([vot_float2str("%.4f", i) for i in x]) + '\n')
    else:  # OTB
        video_path = join('test', args.dataset, name)
        if not isdir(video_path): makedirs(video_path)
        result_path = join(video_path, '{:s}.txt'.format(video['name']))
        with open(result_path, "w") as fin:
            for x in regions:
                fin.write(','.join([str(i) for i in x])+'\n')

    logger.info('({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps Lost: {:d}'.format(
        v_id, video['name'], toc, f / toc, lost_times))

    return lost_times, f / toc

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值