使用预训练的ImageNet1K模型在ImageNet100上测试
保持1000类分类头,进行类别映射
此时模型的输出维度是1000,因此必须先把 ImageNet100 的标签映射到对应的 ImageNet1K 类别索引,才能得到正确的准确率
from torchvision.models import vit_b_16
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import loralib as lora
import os
# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build ViT-B/16 with the ImageNet-1K pretrained weights and place it on the device.
model = vit_b_16(weights="IMAGENET1K_V1")
model = model.to(device)
# Switch to evaluation mode for inference.
model.eval()
# Custom dataset that remaps folder labels to ImageNet class indices.
class CustomImageFolder(ImageFolder):
    """ImageFolder whose labels are remapped through a class-name lookup.

    Args:
        root: Dataset root directory, forwarded to ``ImageFolder``.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to the final label.
        class_to_imagenet_idx: Optional dict mapping a class (folder) name to
            the label that should be returned for it. When ``None`` the
            folder index assigned by ``ImageFolder`` is returned unchanged.

    Raises:
        KeyError: From ``__getitem__`` when a mapping is given but does not
            contain the sample's class name.
    """

    def __init__(self, root, transform=None, target_transform=None, class_to_imagenet_idx=None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.class_to_imagenet_idx = class_to_imagenet_idx

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        # Load the image and the folder-order label recorded by ImageFolder.
        path, original_label = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        if self.class_to_imagenet_idx is None:
            # No mapping supplied: keep the folder-order label instead of
            # failing on a None subscript.
            label = original_label
        else:
            # Translate folder index -> class name -> mapped label.
            label = self.class_to_imagenet_idx[self.classes[original_label]]

        # Honour target_transform, which the constructor accepts.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label
# Preprocessing (the original note says this matches the training pipeline).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
test_transforms = transforms.Compose([
    transforms.Resize(256),      # resize the shorter side to 256
    transforms.CenterCrop(224),  # centre-crop to 224x224
    transforms.ToTensor(),       # convert to a tensor
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # normalise
])

# Load the ImageNet class list. The line number of each entry is used as its
# class index below, so the order of the file matters.
IMAGENET_CLASSES_PATH = "data/imagenet100/imagenet_classes.txt"
# Explicit raise instead of assert: asserts are removed under `python -O`.
if not os.path.exists(IMAGENET_CLASSES_PATH):
    raise FileNotFoundError("请下载 ImageNet 类别标签文件 imagenet_classes.txt!")
# Explicit encoding so the result does not depend on the platform locale.
with open(IMAGENET_CLASSES_PATH, encoding="utf-8") as f:
    imagenet_classes = [line.strip() for line in f]

# Load the test set once with plain ImageFolder to discover its class names.
test_dataset_path = "data/imagenet100/test"
if not os.path.exists(test_dataset_path):
    raise FileNotFoundError("测试集路径不存在!请检查路径。")
test_dataset = ImageFolder(root=test_dataset_path, transform=test_transforms)

# Class (folder) names as reported by ImageFolder.
test_classes = list(test_dataset.class_to_idx)
if not set(test_classes).issubset(set(imagenet_classes)):
    raise ValueError("测试集类别不在 ImageNet 类别中!")

# Map every test-set class name to its index in the full class list.
imagenet_class_to_idx = {cls_name: idx for idx, cls_name in enumerate(imagenet_classes)}
test_class_to_imagenet_idx = {cls: imagenet_class_to_idx[cls] for cls in test_classes}

# Re-create the dataset so that it yields the remapped labels, which is what
# the unmodified 1000-way classification head predicts.
test_dataset = CustomImageFolder(
    root=test_dataset_path,
    transform=test_transforms,
    class_to_imagenet_idx=test_class_to_imagenet_idx,
)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
# Measure test accuracy.
def test_accuracy(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` over ``test_loader`` in percent.

    The model's train/eval mode is not changed here; gradients are disabled
    for the duration of the evaluation.

    Args:
        model: Callable mapping an input batch to per-class scores, with the
            class dimension at index 1.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        ``correct / total * 100`` as a float, or ``0.0`` when the loader
        yields no samples (avoids a ZeroDivisionError).
    """
    correct = 0
    total = 0
    with torch.no_grad():  # no gradient bookkeeping needed for evaluation
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Predicted class = index of the highest score along dim 1.
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    # Raw counts, printed once after the whole loader has been consumed.
    print(correct, total)
    if total == 0:
        return 0.0
    return correct / total * 100


# Despite the ``test_`` prefix this is a helper, not a pytest test; opt out of
# collection in case the module is imported into a test suite.
test_accuracy.__test__ = False
# Compute the test-set accuracy (in percent) and print it with two decimals.
accuracy = test_accuracy(model, test_loader, device)
print(f"ImageNet100 测试集分类准确率: {accuracy:.2f}%")
修改分类头,只保留测试集中存在的类别
from torchvision.models import vit_b_16
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import loralib as lora
import os
# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build ViT-B/16 with the ImageNet-1K pretrained weights and place it on the device.
model = vit_b_16(weights="IMAGENET1K_V1")
model = model.to(device)
# Switch to evaluation mode for inference.
model.eval()
# Custom dataset whose targets can be rewritten through a class-name mapping.
class CustomImageFolder(ImageFolder):
    """ImageFolder that optionally replaces its targets via a name lookup.

    Args:
        root: Dataset root directory, forwarded to ``ImageFolder``.
        transform: Optional image transform, forwarded to ``ImageFolder``.
        class_to_imagenet_idx: Optional dict from class (folder) name to the
            label to report. With ``None`` the folder-order labels are kept.
    """

    def __init__(self, root, transform=None, class_to_imagenet_idx=None):
        super().__init__(root, transform)
        self.class_to_imagenet_idx = class_to_imagenet_idx
        # Folder-order labels as recorded in ``samples``.
        folder_labels = [folder_label for _, folder_label in self.samples]
        if self.class_to_imagenet_idx is None:
            # Default behaviour: targets mirror the sample labels.
            self.targets = folder_labels
        else:
            # Rewrite targets: folder index -> class name -> mapped label.
            self.targets = [
                self.class_to_imagenet_idx[self.classes[folder_label]]
                for folder_label in folder_labels
            ]

    def __getitem__(self, index):
        """Return ``(image, label)``; the label is mapped when a mapping is set."""
        img, folder_label = super().__getitem__(index)
        has_mapping = self.class_to_imagenet_idx is not None
        return img, (self.targets[index] if has_mapping else folder_label)
# Preprocessing (the original note says this matches the training pipeline).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
test_transforms = transforms.Compose([
    transforms.Resize(256),      # resize the shorter side to 256
    transforms.CenterCrop(224),  # centre-crop to 224x224
    transforms.ToTensor(),       # convert to a tensor
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # normalise
])

# Load the ImageNet class list. The line number of each entry is used as its
# class index below, so the order of the file matters.
IMAGENET_CLASSES_PATH = "data/imagenet100/imagenet_classes.txt"
# Explicit raise instead of assert: asserts are removed under `python -O`.
if not os.path.exists(IMAGENET_CLASSES_PATH):
    raise FileNotFoundError("请下载 ImageNet 类别标签文件 imagenet_classes.txt!")
# Explicit encoding so the result does not depend on the platform locale.
with open(IMAGENET_CLASSES_PATH, encoding="utf-8") as f:
    imagenet_classes = [line.strip() for line in f]

# Load the test set; its folder-order labels are used unchanged below.
test_dataset_path = "data/imagenet100/test"
if not os.path.exists(test_dataset_path):
    raise FileNotFoundError("测试集路径不存在!请检查路径。")
test_dataset = ImageFolder(root=test_dataset_path, transform=test_transforms)

# Class (folder) names as reported by ImageFolder.
test_classes = list(test_dataset.class_to_idx)
if not set(test_classes).issubset(set(imagenet_classes)):
    raise ValueError("测试集类别不在 ImageNet 类别中!")

# Map each dataset class id to its index in the full class list.
imagenet_class_to_idx = {cls_name: idx for idx, cls_name in enumerate(imagenet_classes)}
current_id_to_original_id = {
    test_dataset.class_to_idx[cls]: imagenet_class_to_idx[cls]
    for cls in test_dataset.classes
}
print("当前 ID 到原始 ID 的映射:", current_id_to_original_id)

# Original classification head and its parameters.
old_classifier = model.heads.head
old_weight = old_classifier.weight.data
old_bias = old_classifier.bias.data

# Number of outputs of the reduced head.
new_num_classes = len(current_id_to_original_id)

# Rows to keep, ordered explicitly by dataset class id so that output i of the
# new head corresponds to dataset label i regardless of dict iteration order.
original_ids = torch.tensor(
    [current_id_to_original_id[new_id] for new_id in sorted(current_id_to_original_id)],
    dtype=torch.long,
    device=old_weight.device,
)
# Select weights and biases with one indexed gather each, rather than
# rebuilding the bias element by element.
new_weight = old_weight.index_select(0, original_ids)
new_bias = old_bias.index_select(0, original_ids)

# Create the reduced head and copy the selected parameters into it.
new_classifier = nn.Linear(old_classifier.in_features, new_num_classes)
with torch.no_grad():
    new_classifier.weight.copy_(new_weight)
    new_classifier.bias.copy_(new_bias)

# Swap the head in and make sure the new layer lives on the target device.
model.heads.head = new_classifier
model = model.to(device)

# The plain ImageFolder labels already index the rows of the reduced head,
# so no label remapping is applied here.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
# Measure test accuracy.
def test_accuracy(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` over ``test_loader`` in percent.

    The model's train/eval mode is not changed here; gradients are disabled
    for the duration of the evaluation.

    Args:
        model: Callable mapping an input batch to per-class scores, with the
            class dimension at index 1.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device both tensors are moved to before the forward pass.

    Returns:
        ``correct / total * 100`` as a float, or ``0.0`` when the loader
        yields no samples (avoids a ZeroDivisionError).
    """
    correct = 0
    total = 0
    with torch.no_grad():  # no gradient bookkeeping needed for evaluation
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Predicted class = index of the highest score along dim 1.
            predicted = model(inputs).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    # Raw counts, printed once after the whole loader has been consumed.
    print(correct, total)
    if total == 0:
        return 0.0
    return correct / total * 100


# Despite the ``test_`` prefix this is a helper, not a pytest test; opt out of
# collection in case the module is imported into a test suite.
test_accuracy.__test__ = False
# Compute the test-set accuracy (in percent) of the model with the reduced head.
accuracy = test_accuracy(model, test_loader, device)