Modifying test.py
Like the training script, pyskl's test script does not support single-GPU testing out of the box, so the code has to be modified to disable the distributed-testing logic.
The modified code is as follows:
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa: E722
import argparse
import os
import os.path as osp
import time
import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, load
from mmcv.cnn import fuse_conv_bn
from mmcv.engine import multi_gpu_test, single_gpu_test
from mmcv.fileio.io import file_handlers
from mmcv.parallel import MMDistributedDataParallel, MMDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from pyskl.datasets import build_dataloader, build_dataset
from pyskl.models import build_model
from pyskl.utils import mc_off, mc_on, test_port
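# Note: dist, init_dist, get_dist_info, MMDistributedDataParallel and
# multi_gpu_test are now referenced only by the commented-out distributed
# code below; they are kept so the original multi-GPU path is easy to restore.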
def parse_args():
    parser = argparse.ArgumentParser(
        description='pyskl test (and eval) a model')
    parser.add_argument(
        '--config',
        default='../configs/posec3d/c3d_light_gym/joint.py',
        help='test config file path')
    parser.add_argument(
        '-C', '--checkpoint', default='latest.pth', help='checkpoint file')
    parser.add_argument(
        '--out',
        default=None,
        help='output result file in pkl/yaml/json format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        default=['top_k_accuracy', 'mean_class_accuracy'],
        help='evaluation metrics, which depends on the dataset, e.g.,'
        ' "top_k_accuracy", "mean_class_accuracy" for video dataset')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple workers')
    parser.add_argument(
        '--average-clips',
        choices=['score', 'prob', None],
        default=None,
        help='average type when averaging test clips')
    parser.add_argument(
        '--launcher',
        choices=['pytorch', 'slurm'],
        default='pytorch',
        help='job launcher')
    # parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # if 'LOCAL_RANK' not in os.environ:
    #     os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
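# Note: --launcher is retained only because mc_on() in main() still takes a
# launcher argument; the distributed initialization itself (init_dist) is
# disabled below.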
def inference_pytorch(args, cfg, data_loader):
"""Get predictions by pytorch models."""
if args.average_clips is not None:
# You can set average_clips during testing, it will override the
# original setting
if cfg.model.get('test_cfg') is None and cfg.get('test_cfg') is None:
cfg.model.setdefault('test_cfg',
dict(average_clips=args.average_clips))
else:
if cfg.model.get('test_cfg') is not None:
cfg.model.test_cfg.average_clips = args.average_clips
else:
cfg.test_cfg.average_clips = args.average_clips
# build the model and load checkpoint
model = build_model(cfg.model)
if args.checkpoint is None:
work_dir = cfg.work_dir
args.checkpoint = osp.join(work_dir, 'latest.pth')
load_checkpoint(model, args.checkpoint, map_location='cpu')
if args.fuse_conv_bn:
model = fuse_conv_bn(model)
# model = MMDistributedDataParallel(
# model.cuda(),
# device_ids=[torch.cuda.current_device()],
# broadcast_buffers=False)
model = MMDataParallel(model.cuda())
# outputs = multi_gpu_test(model, data_loader, args.tmpdir)
outputs = single_gpu_test(model, data_loader)
return outputs
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    out = osp.join(cfg.work_dir, 'result.pkl') if args.out is None else args.out

    # Load eval_config from cfg
    eval_cfg = cfg.get('evaluation', {})
    keys = [
        'interval', 'tmpdir', 'start', 'save_best', 'rule', 'by_epoch',
        'broadcast_bn_buffers'
    ]
    for key in keys:
        eval_cfg.pop(key, None)
    if args.eval:
        eval_cfg['metrics'] = args.eval

    mmcv.mkdir_or_exist(osp.dirname(out))
    _, suffix = osp.splitext(out)
    assert suffix[1:] in file_handlers, (
        'The format of the output file should be json, pickle or yaml')

    # set cudnn benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True

    # Single-GPU change: skip distributed initialization and pin GPU 0.
    # if not hasattr(cfg, 'dist_params'):
    #     cfg.dist_params = dict(backend='nccl')
    # init_dist(args.launcher, **cfg.dist_params)
    # rank, world_size = get_dist_info()
    cfg.gpu_ids = [0]

    # build the dataloader
    dataset = build_dataset(cfg.data.test, dict(test_mode=True))
    dataloader_setting = dict(
        videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
        workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
        shuffle=False)
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('test_dataloader', {}))
    data_loader = build_dataloader(dataset, **dataloader_setting)

    default_mc_cfg = ('localhost', 22077)
    memcached = cfg.get('memcached', False)

    # Single-GPU change: no rank check needed, every run acts as rank 0.
    # if rank == 0 and memcached:
    if memcached:
        # mc_list is a list of pickle files you want to cache in memory.
        # Basically, each pickle file is a dictionary.
        mc_cfg = cfg.get('mc_cfg', default_mc_cfg)
        assert isinstance(mc_cfg, tuple) and mc_cfg[0] == 'localhost'
        if not test_port(mc_cfg[0], mc_cfg[1]):
            mc_on(port=mc_cfg[1], launcher=args.launcher)
        retry = 3
        while not test_port(mc_cfg[0], mc_cfg[1]) and retry > 0:
            time.sleep(5)
            retry -= 1
        assert retry >= 0, 'Failed to launch memcached. '

    # dist.barrier()
    outputs = inference_pytorch(args, cfg, data_loader)

    # Single-GPU change: dump and evaluate results directly instead of
    # only on rank 0.
    # rank, _ = get_dist_info()
    # if rank == 0:
    #     print(f'\nwriting results to {out}')
    #     dataset.dump_results(outputs, out=out)
    #     if eval_cfg:
    #         eval_res = dataset.evaluate(outputs, **eval_cfg)
    #         for name, val in eval_res.items():
    #             print(f'{name}: {val:.04f}')
    print(f'\nwriting results to {out}')
    dataset.dump_results(outputs, out=out)
    if eval_cfg:
        eval_res = dataset.evaluate(outputs, **eval_cfg)
        for name, val in eval_res.items():
            print(f'{name}: {val:.04f}')

    # dist.barrier()
    # if rank == 0 and memcached:
    if memcached:
        mc_off()
if __name__ == '__main__':
    main()
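With the argparse defaults above, a single-GPU run no longer needs a torch.distributed launcher: running python test.py picks up the default config and latest.pth, and --config / -C override them when needed.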
This article showed how to modify pyskl's test script so it can run on a single GPU without a distributed launcher. The main changes cover config parsing, model loading, dataloader construction, and running prediction and evaluation directly instead of behind rank-0 checks.
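As a quick sanity check after a run, the dumped result file can be inspected offline. The sketch below is illustrative only: it assumes the default output location (result.pkl under cfg.work_dir) and that each entry is a per-class score vector, as pyskl recognizers typically return; the path shown is hypothetical.

import numpy as np
from mmcv import load

# Hypothetical path: substitute your own cfg.work_dir.
results = load('work_dirs/c3d_light_gym_joint/result.pkl')
print(f'loaded {len(results)} samples')
scores = np.stack(results)      # shape: (num_samples, num_classes)
top1 = scores.argmax(axis=1)    # predicted class index per sample
print('first 10 predictions:', top1[:10])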