PyTorch 分类任务训练模板

简介

想用PyTorch 做分类任务的模型训练,找到一个使用模板,稍加调整并附上我的理解。

1. 数据准备

在这个阶段,传入batch_size, 传入训练样本的存储路径(image_path),数据储存格式如下:

Data
   ----class1
        -----image01.png
        -----image02.png
        ……
    ----class2
        -----image11.png
        -----image12.png
        ……     
      ----class3
        -----image21.png
        -----image22.png
        ……               

接下来就采用torch.utils.data.DataLoader将数据按照train 和 val 打包(这个函数的用法放在最后), 同时也使用了数据增强。

# 传入 batch_size
def train_val_data_process(batch_size:Int,image_path:str):      
    data_transform = {
   
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}    
                                   
    # check the image_path exist or not
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)    
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    
    cl_list = train_dataset.class_to_idx
    num_classes = len(cl_list)
    print("Number of classes:", num_classes) # 
    
    cla_dict = dict((val, key) for key, val in cl_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值