import models
import torch
import cv2
import os
import time
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# TODO: label names used when the classifier was trained. The order matters:
# it must be exactly the same order that was used during model training!
CLASSNAMES = ['1', '2', '3', '4']
def inference_one(model, img_path, out_path, threshold=0.5):
    """Run multi-label inference on a single image and save an annotated preview.

    The image is loaded with PIL, forced to RGB, resized to 256x256, converted
    to a tensor and passed through ``model``. A sigmoid is applied to the model
    output, the scores are printed, and a figure showing the resized image
    titled with the three highest-scoring classes is written to ``out_path``.

    Args:
        model: Callable network. It is called with a ``(1, 3, 256, 256)`` float
            tensor on ``device`` and is expected to return one row of raw
            scores, one per entry of ``CLASSNAMES`` (the code indexes
            ``outputs[0]`` and looks class names up by score position).
        img_path: Path of the image file to classify.
        out_path: Existing directory that receives the annotated figure. The
            file is named ``<stem>_res<ext>`` after the input image.
        threshold: Sigmoid score strictly above which a label is reported as
            present. Defaults to 0.5.

    Returns:
        A list of ``0``/``1`` ints, one per model output, where ``1`` means the
        corresponding score exceeded ``threshold``.
    """
    imgname = os.path.basename(img_path)
    print(imgname)

    # ========== (1) PIL-based loading ==========
    transform_valid = transforms.Compose([
        transforms.Resize((256, 256), interpolation=2),
        transforms.ToTensor(),
    ])
    # Force three channels so grayscale / RGBA / palette files do not reach the
    # model with the wrong channel count; the context manager closes the file.
    with Image.open(img_path) as img:
        img_ = transform_valid(img.convert('RGB')).unsqueeze(0)  # add batch dimension

    # ========== (2) OpenCV-based loading; to use it, comment out (1) above ==========
    # img = cv2.imread(img_path)
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # img = cv2.resize(img, (256, 256))
    # img_ = torch.from_numpy(img).float().permute(2, 0, 1).unsqueeze(0)/255

    img_ = img_.to(device)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        # Multi-label models use sigmoid (single-label would use softmax); this
        # must match what was used during training.
        outputs = torch.sigmoid(model(img_))
    print(outputs)

    scores = outputs[0]
    y_pred = [1 if float(score) > threshold else 0 for score in scores]
    # print(imgname, y_pred)

    percentage = scores * 100

    # Report the single most probable class.
    _, best = torch.max(outputs, 1)
    best_idx = int(best)
    print('max predicted:', CLASSNAMES[best_idx], round(percentage[best_idx].item(), 2))

    # Report every class, sorted from highest to lowest score.
    _, order = torch.sort(outputs, descending=True)
    predict = [
        (CLASSNAMES[idx], round(percentage[idx].item(), 2))
        for idx in order[0].tolist()
    ]
    print('all predicted:', predict)

    # Visualise the top 3 predictions on the (resized) input image.
    image = np.transpose(img_.squeeze(0).detach().cpu().numpy(), (1, 2, 0))
    # Use a dedicated figure per call and close it afterwards; otherwise every
    # call keeps drawing onto the same implicit figure during a batch run.
    fig = plt.figure()
    plt.imshow(image)
    plt.axis('off')
    plt.title(f"{predict[:3]}")
    # Build the output name from the real extension so that names without a
    # lowercase ".jpg" suffix still get the "_res" marker.
    stem, ext = os.path.splitext(imgname)
    plt.savefig(os.path.join(out_path, f"{stem}_res{ext}"))
    plt.show()
    plt.close(fig)
    return y_pred
if __name__ == "__main__":
    # Batch driver: load the trained checkpoint, run inference_one() on every
    # ".jpg" file in root_path, and log one "<filename>\t<y_pred>" line per image.
    start_time = time.time()

    model = models.model(pretrained=False, requires_grad=False).to(device)
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only
    # host, matching the cuda/cpu fallback used when `device` is chosen.
    checkpoint = torch.load(
        '/workspace/code/multi-label_image_classification/outputs/model_1117.pth',
        map_location=device,
    )
    model.load_state_dict(checkpoint['model_state_dict'])  # load model weights state_dict
    model.eval()  # switch the module to evaluation mode

    root_path = "/workspace/code/multi-label_image_classification/input/biaoji-classifier/Multi_Label_dataset/Images"
    out_path = "/workspace/code/multi-label_image_classification/resimg"
    os.makedirs(out_path, exist_ok=True)

    # The context manager closes the results file even if inference raises.
    with open("/workspace/code/multi-label_image_classification/resimg.txt", "w", encoding="utf-8") as txt:
        # Sorted so the results file has a reproducible order
        # (os.listdir returns entries in arbitrary order).
        for filename in sorted(os.listdir(root_path)):
            # Match on the suffix so names such as "x.jpg.txt" are skipped.
            if not filename.endswith(".jpg"):
                continue
            img_path = os.path.join(root_path, filename)
            y_pred = inference_one(model, img_path, out_path)
            txt.write(f"{filename}\t{y_pred}\n")
            # Flush per image so progress is visible while the batch is running.
            txt.flush()

    end_time = time.time()
    print("cost times:", end_time - start_time)
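# Single-image inference for a PyTorch multi-label classification model.
# (Scraped page footer: "Latest recommended article published 2024-12-17 19:26:54".)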