pytorch(一)——用python自動生成train,val文件

本文介绍了一种将大规模图片数据集划分为训练集和验证集的方法,通过Python脚本实现,适用于拥有多个类别且每个类别包含大量图片的数据集。同时,针对图片命名规则复杂的情况,提供了数据加载类的定制化解决方案。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

任務——分類

數據集爲5個不同類別的圖片集,每個圖片集大概有3W張圖片。所以要建立一個train訓練的txt文件和一個val驗證的txt文件,裏面放圖片的路徑,因爲只是練手用,所以不放test驗證。

圖爲5個數據集文件
最終要的結果是從每個文件裏拿出28000個訓練和剩下差不多3000個用來測試。

import os
a=0
while(a<5):

    dir = '/home/zyx/data/pic/'+str(a)+'/'
    label = a

    files = os.listdir(dir)
    files.sort()
    train = open('/home/zyx/data/train.txt','a')
    val = open('/home/zyx/data/val.txt', 'a')
    i = 1
    for file in files:
        if i<29000:
            fileType = os.path.split(file)
            if fileType[1] == '.txt':
                continue
            name =  str(dir) +  file + ' ' + str(int(label)) +'\n'
            train.write(name)
            i = i+1
            print(i)
        else:
            fileType = os.path.split(file)
            if fileType[1] == '.txt':
                continue
            name = str(dir) +file + ' ' + str(int(label)) +'\n'
            val.write(name)
            i = i+1
            print(i)


    val.close()
    train.close()
    print(a)
    a = a + 1

結果在这里插入图片描述

然後就可以開始寫網絡和訓練模型了
因爲我的圖片數據集裏有/home/zyx/data/pic/0/0_original_108475 (2).JPG_6c664301-0796-43f1-ba25-f19aa62537b4.JPG 0比較奇怪的命名,所以要把讀取數據的地方稍微做一些修改

class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            if len(words)>2:
                words[0] = str((words[0]))+' '+str((words[1]))
                words[1] = words[2]
            print(len(words))

            imgs.append((words[0],int(words[1])))
            print((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label

    def __len__(self):
        return len(self.imgs)

基本上再用這個沒啥問題可以直接用了

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值