【OWOD】EADA代码解读 - 1

本文详细描述了一个名为EADA的项目,涉及使用VisDA2017数据集进行跨域图像识别,通过PyTorch实现ResNet模型的训练,包括参数加载、数据处理、模型训练过程(包括数据预处理、模型初始化、训练步骤、主动选择等)、CUDA支持以及结果保存。

1. 项目构成

在这里插入图片描述

2. 数据格式

以visda2017为例,其txt文件格式为:path_img [class_number]

data/visda2017/validation/aeroplane/aeroplane_1363127.jpg 0
data/visda2017/validation/aeroplane/aeroplane_1363134.jpg 0

3. main.py

3.1 main()

i. 加载yaml文件中的参数
	1) 使用到argparse,cfg模块;
	2) 使用merge_from_file读取默认yaml文件中的参数
	3) 使用merge_from_list从命令行获取参数
ii. 新建output文件夹
	1) 使用utils.py中的mkdir函数
iii. 建立logger和seed
	1) 使用logger记录日志
	2) 从yaml文件读取使用的seed号码,并固定种子
iv. 训练
	1) 从yaml文件获取数据路径
	2) 冻结参数
	3) 训练 (train())
v. 记录结果生成csv文件

3.2 train()

i. 检查cuda是否可用
ii. 数据预处理
	1) 使用transforms.py中的build_transform函数(torchvision.transforms方法)
	2) 装载源数据和目标数据
	3) 初始化selected数据(此时为空)
iii. 实例化模型、优化器和损失函数
	1) 模型使用network.py中的ResNet
iv. 开始训练:记录信息和时间
	1) 使用metric_logger.py中的metriclogger储存训练信息
	2) 迭代epochs
		a) 使用iter迭代访问数据
		b) 梯度清零:optimizer.zero_grad()
		c) 先叠加源数据的free energy作为损失函数
		d) 再叠加目标数据的free energy作为最终损失函数
		e) 反向传播:total_loss.backward()
		f) 梯度下降:optimizer.step()
	3) 测试:每五个epoch测试一次(test())
	4) 主动选择数据
		a) 根据默认yaml文件在个别epoch选择数据
			i) 可以选择随机或者EADA
	5) 保存模型和主动选择的数据集

3.3 test()

i. evaluation
ii. 计算准确率

3.4 完整代码

from __future__ import print_function
import argparse
import os.path
import os
import logging
import time
import datetime

import torch
import torch.optim as optim
import numpy as np
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from core.datasets.image_list import ImageList
from core.models.network import ResNetFc
from core.active.active import EADA_active, RAND_active
from core.utils.utils import set_random_seed, mkdir, momentum_update
from core.datasets.transforms import build_transform
from core.active.loss import NLLLoss, FreeEnergyAlignmentLoss
from core.utils.metric_logger import MetricLogger
from core.utils.logger import setup_logger
from core.config import cfg


def test(model, test_loader):
    start_test = True
    model.eval()  # evaluation
    with torch.no_grad():
        for batch_idx, test_data in enumerate(test_loader):
            img, labels = test_data['img0'], test_data['label']
            img = img.cuda()
            outputs = model(img, return_feat=False)
            if start_test:
                all_output = outputs.float().cpu()
                all_label = labels.float()
                start_test = False
            else:
                all_output = torch.cat((all_output, outputs.float().cpu()), 0)
                all_label = torch.cat((all_label, labels.float()), 0)

    _, predict = torch.min(all_output, 1)
    acc = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0]) * 100  # 计算acc

    return acc


def train(cfg, task):
    logger = logging.getLogger("EADA.trainer")

    # 1. check cuda
    use_cuda = True if torch.cuda.is_available() else False

    kwargs = {
   
   'num_workers': 2, 'pin_memory': True} if use_cuda else {
   
   }

    # 2. transform and prepare data
    source_transform = build_transform(cfg, is_train=True, choices=cfg.INPUT.SOURCE_TRANSFORMS)
    target_transform = build_transform(cfg, is_train=True, choices=cfg.INPUT.TARGET_TRANSFORMS)
    test_transform = build_transform(cfg, is_train=False, choices=cfg.INPUT.TEST_TRANSFORMS)

    src_train_ds = ImageList(os.path.join(cfg.DATASET.R
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值