在 Windows 上使用单 GPU 测试 PySKL

本文介绍如何修改 PySKL 的测试脚本(test.py),去掉其中的分布式相关逻辑,使其能够在单 GPU 上运行。主要内容包括配置文件解析、模型加载、数据加载器构建,以及在单 GPU 设置下进行预测和评估。
部署运行你感兴趣的模型镜像

test.py修改

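pyskl 的测试脚本和训练脚本一样,默认只支持分布式测试(通过 init_dist 初始化进程组,默认后端为 nccl),无法直接在单 GPU 环境(例如 Windows)下运行。因此需要修改代码:屏蔽分布式初始化和 multi_gpu_test,改用 MMDataParallel + single_gpu_test。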

修改后代码如下:

# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa: E722
import argparse
import os
import os.path as osp
import time

import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, load
from mmcv.cnn import fuse_conv_bn
from mmcv.engine import multi_gpu_test, single_gpu_test
from mmcv.fileio.io import file_handlers
from mmcv.parallel import MMDistributedDataParallel, MMDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

from pyskl.datasets import build_dataloader, build_dataset
from pyskl.models import build_model
from pyskl.utils import mc_off, mc_on, test_port


def parse_args():
    """Parse the command-line options of the single-GPU test script.

    Returns:
        argparse.Namespace: Parsed options. ``checkpoint`` is ``None`` unless
        ``-C/--checkpoint`` is given; ``inference_pytorch`` then falls back to
        ``<cfg.work_dir>/latest.pth``.
    """
    parser = argparse.ArgumentParser(
        description='pyskl test (and eval) a model')
    parser.add_argument(
        '--config',
        default='../configs/posec3d/c3d_light_gym/joint.py',
        help='test config file path')
    # Default must be None: a literal 'latest.pth' default would be resolved
    # against the current working directory and make the work_dir fallback
    # in inference_pytorch unreachable.
    parser.add_argument(
        '-C', '--checkpoint',
        default=None,
        help='checkpoint file (default: <work_dir>/latest.pth)')
    parser.add_argument(
        '--out',
        default=None,
        help='output result file in pkl/yaml/json format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        default=['top_k_accuracy', 'mean_class_accuracy'],
        help='evaluation metrics, which depends on the dataset, e.g.,'
        ' "top_k_accuracy", "mean_class_accuracy" for video dataset')
    # Kept for CLI compatibility with the distributed script; the single-GPU
    # path (single_gpu_test) does not use it.
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple workers')
    parser.add_argument(
        '--average-clips',
        choices=['score', 'prob', None],
        default=None,
        help='average type when averaging test clips')
    # Only forwarded to mc_on() when memcached is enabled in the config.
    parser.add_argument(
        '--launcher',
        choices=['pytorch', 'slurm'],
        default='pytorch',
        help='job launcher')
    return parser.parse_args()


def inference_pytorch(args, cfg, data_loader):
    """Get predictions by pytorch models on a single GPU.

    Builds ``cfg.model`` and loads ``args.checkpoint`` on CPU, falling back to
    ``<cfg.work_dir>/latest.pth`` when it is None. It optionally fuses conv+bn,
    wraps the model in ``MMDataParallel`` and runs ``single_gpu_test``.

    Note: mutates ``cfg`` (``test_cfg.average_clips``) and ``args``
    (``checkpoint``) in place.

    Args:
        args (argparse.Namespace): Parsed CLI options; reads ``average_clips``,
            ``checkpoint`` and ``fuse_conv_bn``.
        cfg (mmcv.Config): Full config; reads ``model``, ``test_cfg``,
            ``work_dir`` and ``gpu_ids`` (``[0]`` if absent).
        data_loader: Test dataloader handed to ``single_gpu_test``.

    Returns:
        The outputs returned by ``single_gpu_test``.
    """
    if args.average_clips is not None:
        # --average-clips overrides the config. test_cfg may be defined under
        # cfg.model or at the top level of cfg; create it under cfg.model
        # only when neither exists.
        if cfg.model.get('test_cfg') is None and cfg.get('test_cfg') is None:
            cfg.model.setdefault('test_cfg',
                                 dict(average_clips=args.average_clips))
        elif cfg.model.get('test_cfg') is not None:
            cfg.model.test_cfg.average_clips = args.average_clips
        else:
            cfg.test_cfg.average_clips = args.average_clips

    # build the model and load checkpoint
    model = build_model(cfg.model)

    if args.checkpoint is None:
        args.checkpoint = osp.join(cfg.work_dir, 'latest.pth')
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)

    # Pin DataParallel to an explicit device list: with device_ids=None it
    # would scatter every batch across *all* visible GPUs, not just one.
    device_ids = cfg.get('gpu_ids', [0])
    model = MMDataParallel(model.cuda(device_ids[0]), device_ids=device_ids)
    return single_gpu_test(model, data_loader)


def _start_memcached(mc_cfg, launcher, retries=3, interval=5):
    """Ensure a local memcached server is listening on ``mc_cfg``.

    Args:
        mc_cfg (tuple): ``(host, port, ...)``; the host must be ``'localhost'``.
        launcher (str): Forwarded to ``mc_on``.
        retries (int): How many times to wait for the port before giving up.
        interval (float): Seconds to sleep between port checks.

    Raises:
        RuntimeError: If the port is still unreachable after all retries.
    """
    assert isinstance(mc_cfg, tuple) and mc_cfg[0] == 'localhost'
    host, port = mc_cfg[0], mc_cfg[1]
    if not test_port(host, port):
        mc_on(port=port, launcher=launcher)
    for _ in range(retries):
        if test_port(host, port):
            return
        time.sleep(interval)
    # Check the port itself after the last wait: a retry counter that stops
    # at 0 cannot distinguish "came up on the last try" from "never came up".
    if not test_port(host, port):
        raise RuntimeError('Failed to launch memcached. ')


def main():
    """Test (and optionally evaluate) a pyskl model on a single GPU.

    Reads the config given by ``--config`` and runs inference on
    ``cfg.data.test`` with GPU 0. It dumps the outputs to ``--out`` (default
    ``<work_dir>/result.pkl``) and prints each metric returned by
    ``dataset.evaluate``. Distributed initialisation (``init_dist``) is
    intentionally not performed, so the script runs as one plain process.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)

    out = osp.join(cfg.work_dir, 'result.pkl') if args.out is None else args.out

    # Reuse the config's `evaluation` dict, dropping scheduling/bookkeeping
    # keys so the rest can be passed straight to dataset.evaluate(**eval_cfg).
    eval_cfg = cfg.get('evaluation', {})
    keys = ['interval', 'tmpdir', 'start', 'save_best', 'rule', 'by_epoch', 'broadcast_bn_buffers']
    for key in keys:
        eval_cfg.pop(key, None)
    if args.eval:
        eval_cfg['metrics'] = args.eval

    mmcv.mkdir_or_exist(osp.dirname(out))
    _, suffix = osp.splitext(out)
    assert suffix[1:] in file_handlers, ('The format of the output file should be json, pickle or yaml')

    # set cudnn benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True

    # Single-process run: no init_dist()/get_dist_info(); use GPU 0 only.
    cfg.gpu_ids = [0]

    # build the dataloader
    dataset = build_dataset(cfg.data.test, dict(test_mode=True))
    dataloader_setting = dict(
        videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
        workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
        shuffle=False)
    # Options under data.test_dataloader override the generic ones above.
    dataloader_setting = dict(dataloader_setting, **cfg.data.get('test_dataloader', {}))
    data_loader = build_dataloader(dataset, **dataloader_setting)

    default_mc_cfg = ('localhost', 22077)
    memcached = cfg.get('memcached', False)
    if memcached:
        # mc_list is a list of pickle files you want to cache in memory.
        # Basically, each pickle file is a dictionary.
        _start_memcached(cfg.get('mc_cfg', default_mc_cfg), args.launcher)

    try:
        outputs = inference_pytorch(args, cfg, data_loader)

        print(f'\nwriting results to {out}')
        dataset.dump_results(outputs, out=out)
        if eval_cfg:
            eval_res = dataset.evaluate(outputs, **eval_cfg)
            for name, val in eval_res.items():
                print(f'{name}: {val:.04f}')
    finally:
        # Stop memcached even when inference/evaluation raises, so no server
        # process is left behind.
        if memcached:
            mc_off()


# Script entry point. The guard is essential on Windows: there Python's
# multiprocessing uses the "spawn" start method, which re-imports this module
# in each child process (presumably the dataloader workers configured via
# workers_per_gpu), and an unguarded call would re-run main() in every worker.
if __name__ == '__main__':
    main()

您可能感兴趣的与本文相关的镜像

Wan2.2-I2V-A14B

Wan2.2-I2V-A14B

图生视频
Wan2.2

Wan2.2 是由通义万相开源的高效视频生成模型,拥有 50 亿参数,属于轻量级模型,专为快速内容创作优化;支持 480P 视频生成,具备优秀的时序连贯性和运动推理能力。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值