智能数字图像处理:图卷积SGN代码(pytorch)之main.py解读

本文详细解读了基于PyTorch的图卷积模型SGN的main.py代码,包括设置CUDA环境、参数解析、模型构建、训练过程、验证与测试。内容涵盖数据加载、损失函数、优化器、学习率调整、模型保存与评估等关键步骤。

1.os.environ["CUDA_VISIBLE_DEVICES"] = '0'-》目的:使用cuda的环境变量cuda_visible_Devices莱限定CUDA程序所能使用的GPU设备

2.parser = argparse.ArgumentParser(description='Skeleton-Based Action Recgnition')-》创建解析器,使用 argparse 的第一步是创建一个 ArgumentParser 对象。
fit.add_fit_args(parser)-》读入命令行参数,该调用有多个参数

3.parser.set_defaults(
    network='SGN',-》网络名称
    dataset = 'NTU',-》数据集名称
    case = 0,-》默认实例为0
    batch_size=16,-》批处理大小
    max_epochs=40,-》最大训练轮数
    monitor='val_acc',-》监视器为验证集的准确率
    lr=0.001,-》初始学习率
    weight_decay=0.0001,-》衰减权重
    lr_factor=0.1,
    workers=0,-》线程数
    print_freq = 20,
    train = 0,
    seg = 20,
    )、

4. args.num_classes = get_num_classes(args.dataset)-》调用get_num_classes得到分类数
    model = SGN(args.num_classes, args.dataset, args.seg, args)-》调用model的SGN方法得到模型

5.total = get_n_params(model)-》调用get_n_params得到model的累乘。
    print(model)-》打印模型结构哦
    print('The number of parameters: ', total)-》打印参数个数
    print('The modes is:', args.network)-》打印模式

6.if torch.cuda.is_available():-》核实显卡驱动是不是可用
        print('It is using GPU!')
        model = model.cuda()-》模型加载GPU驱动

7. criterion = LabelSmoothingLoss(args.num_classes, smoothing=0.1).cuda()-》调用LabelSmoothingLoss方法进行loss的传播。
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)-》构建一个优化器optimizer,你必须给它一个可进行迭代优化的包含了所有参数

8. if args.monitor == 'val_acc':-》若监听到验证集准确率
        mode = 'max'
        monitor_op = np.greater
        best = -np.Inf
        str_op = 'improve'-》则增加
    elif args.monitor == 'val_loss':-》若监听到验证集损失值
        mode = 'min'
        monitor_op = np.less
        best = np.Inf
        str_op = 'reduce'-》则减少

9.scheduler = MultiStepLR(optimizer, milestones=[60, 90, 110], gamma=0.1)-》根据优化器进行学习率调整
    ntu_loaders = NTUDataLoaders(args.dataset, args.case, seg=args.seg)-》加载数据
    train_loader = ntu_loaders.get_train_loader(args.batch_size, args.workers)-》加载训练数据
    val_loader = ntu_loaders.get_val_loader(args.batch_size, args.workers)-》加载验证数据
    train_size = ntu_loaders.get_train_size()-》获取训练数据大小
    val_size = ntu_loaders.get_val_size()-》获取验证数据大小

10.test_loader = ntu_loaders.get_test_loader(32, args.workers)-》加载测试数据集

    print('Train on %d samples, validate on %d samples' % (train_size, val_size))-》打印训练集大小验证集大小

    best_epoch = 0-》初始化epoch
    output_dir = make_dir(args.dataset)-》创建目录

11. save_path = os.path.join(output_dir, args.network)-》设置保存路径
    if not os.path.exists(save_path):-》如果路径不存在
        os.makedirs(save_path)-》创建路径

12.checkpoint = osp.join(save_path, '%s_best.pth' % args.case)-》设置训练完权重文件保存路径和文件名格式
    earlystop_cnt = 0
    csv_file = osp.join(save_path, '%s_log.csv' % args.case)-》设置日志文件保存路径和文件名格式
    log_res = list()

13. lable_path = osp.join(save_path, '%s_lable.txt'% args.case)-》设置标签文件保存路径和文件名格式
    pred_path = osp.join(save_path, '%s_pred.txt' % args.case)-》设置预测文件保存路径和文件名格式

14.开始训练
    if args.train ==1:-》第一轮
        for epoch in range(args.start_epoch, args.max_epochs):-》循环,范围在第一轮到最后一轮(40)之间

            print(epoch, optimizer.param_groups[0]['lr'])-》打印训练轮数,优化器的参数

15. t_start = time.time()-》记录开始时间
     train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch)-》计算训练损失和准确率
     val_loss, val_acc = validate(val_loader, model, criterion)-》计算验证集损失和准确率
   

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值