1. 代码整体结构
这段代码实现了一个图像分类的任务,主要包含以下几个模块:
- 数据处理:加载和预处理图像数据。
- 半监督学习:在训练过程中加入无标签数据,并根据模型预测的置信度来选择哪些无标签样本加入训练。
- 模型定义:使用卷积神经网络(CNN)模型进行图像分类。
- 训练与评估:训练模型并在验证集上评估其性能。
2. 数据预处理与加载
2.1 seed_everything:设置随机种子
def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
设置随机种子是为了确保结果的可重现性。此函数会为torch、numpy和python的随机数生成器设置种子。
2.2 food_Dataset类:自定义数据集类
该类继承自torch.utils.data.Dataset,用于处理和加载图像数据。
2.2.1 __init__方法
def __init__(self, path, mode="train"):
self.mode = mode
if mode == "semi":
self.X = self.read_file(path)
else:
self.X, self.Y = self.read_file(path)
self.Y = torch.LongTensor(self.Y) #标签转为长整形\
path:数据所在路径。mode:数据集模式(训练集train,验证集val,半监督学习semi)。- 如果是半监督模式,加载无标签数据
X;否则,加载有标签数据X和Y(标签)。
2.2.2 read_file方法
def read_file(self, path):
if self.mode == "semi":
file_list = os.listdir(path)
xi = np.zeros((len(file_list), HW, HW, 3), dtype=np.uint8)
# 列出文件夹下所有文件名字
for j, img_name in enumerate(file_list):
img_path = os.path.join(path, img_name)
img = Image.open(img_path)
img = img.resize((HW, HW))
xi[j, ...] = img
print("读到了%d个数据" % len(xi))
return xi
else:
for i in tqdm(range(11)):
file_dir = path + "/%02d" %

最低0.47元/天 解锁文章
1007

被折叠的 条评论
为什么被折叠?



