- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目标
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架: Pytorch
(二)具体步骤
1. 文件结构
2. config.py
import argparse
def get_options(parser=argparse.ArgumentParser()):
parser.add_argument('--workers', type=int, default=0, help='Number of parallel workers')
parser.add_argument('--batch-size', type=int, default=4, help='input batch size, default=32')
parser.add_argument('--size', type=tuple, default=(224, 224), help='input image size')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate, default=0.0001')
parser.add_argument('--epochs', type=int, default=20, help='number of epochs')
parser.add_argument('--seed', type=int, default=112, help='random seed')
parser.add_argument('--save-path', type=str, default='./models/', help='path to save checkpoints')
opt = parser.parse_args()
if opt:
print(f'num_workers:{opt.workers}')
print(f'batch_size:{opt.batch_size}')
print(f'learn rate:{opt.lr}')
print(f'epochs:{opt.epochs}')
print(f'random seed:{opt.seed}')
print(f'save_path:{opt.save_path}')
return opt
if __name__ == '__main__':
opt = get_options()
3. Utils.py
import torch
import pathlib
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
# 第一步:设置GPU
def USE_GPU():
if torch.cuda.is_available():
print('CUDA is available, will use GPU')
device = torch.device("cuda")
else:
print('CUDA is not available. Will use CPU')
device = torch.device("cpu")
return device
temp_dict = dict()
def recursive_iterate(path):
"""
根据所提供的路径遍历该路径下的所有子目录,列出所有子目录下的文件
:param path: 路径
:return: 返回最后一级目录的数据
""" path = pathlib.Path(path)
for file in path.iterdir():
if file.is_file():
temp_key = str(file).split('\\')[-2]
if temp_key in temp_dict:
temp_dict.update({temp_key: temp_dict[temp_key] + 1})
else:
temp_dict.update({temp_key: 1})
# print(file)
elif file.is_dir():
recursive_iterate(file)
return temp_dict
def data_from_directory(directory, train_dir=None, test_dir=None, show=False):
"""
提供是的数据集是文件形式的,提供目录方式导入数据,简单分析数据并返回数据分类
:param test_dir: 是否设置了测试集目录
:param train_dir: 是否设置了训练集目录
:param directory: 数据集所在目录
:param show: 是否需要以柱状图形式显示数据分类情况,默认显示
:return: 数据分类列表,类型: list
""" global total_image
print("数据目录:{}".format(directory))
data_dir = pathlib.Path(directo