上篇做了简单的介绍以及数据集的训练:请点这里
下篇将借助OpenCV试着实际使用一下我们训练的模型
首先:对于上篇描述的训练好的模型model,将其保存下来,使用时再加载出来
#训练模型
for epoch in range(10): #训练10轮
train(epoch) #执行训练
test() #每轮训练完都测试一下正确率
#保存模型
path = "C:/Users/yas/Desktop/pytorch/MNIST/model/model1.pth"
torch.save(model, path)
#加载模型
model = torch.load("C:/Users/yas/Desktop/pytorch/MNIST/model/model1.pth")
我们想要实现的是,通过摄像头捕获图片,实时识别出数字。大致有几个步骤:
1.调用摄像头,拿出每一帧
2.将每一帧输入进模型,得到输出的预测结果
3.将摄像头图像和预测值实时显示出来
这部分代码都包含在一个循环里面:
这其中包括了:获取每一帧图像,转换为灰度图像,反向二值化,改变其大小为28*28,最后还要转换为张量才能送到模型中去。(由于数据集中数字是黑底白字,我们日常手写是白底黑字,所以需要反向二值化操作)
然后是使用模型进行预测。
最后输出预测结果即可。
#1.调用摄像头,拿出每一帧
cap = cv2.VideoCapture(0) #定义视频来源为摄像头
while 1:
ret, frame = cap.read() # 摄像头读取,ret为是否读取成功,frame为视频的每一帧图像
#展示原图像
frame = cv2.resize(frame, (300, 200))
cv2.imshow("source", frame)
#展示灰度,二值化后图像
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) #灰度化
res, frame = cv2.threshold(frame, 90, 255, cv2.THRESH_BINARY_INV) #反向二值化
cv2.imshow("gray", frame)
#展示输入模型的图像(方形)
frame = cv2.resize(frame, (140, 140))
cv2.imshow("28*28", frame)
cv2.waitKey(100) #延时,控制帧率
frame = cv2.resize(frame, (28, 28)) #转换成28*28大小
#两次升维,使其能送入模型
testimg = torch.unsqueeze(testimg, dim=0)
testimg = torch.unsqueeze(testimg, dim=0)
#转换为浮点数
testimg = testimg.to(torch.float32)
#2.将每一帧输入进模型,得到输出的预测结果
#加载模型
model = torch.load("C:/Users/yas/Desktop/pytorch/MNIST/model/model1.pth")
predimg = model(testimg) #进行预测
_, pred = torch.max(predimg.data, dim=1) #获得最大值
print('the predict num is', int(pred.data[0])) #输出结果
最后结果如图:
总的来看,当送进模型的图像(左上角)周围全黑且数字在中央时,准确率较高,一旦周围有杂质或是数字不在中央或是大小不合适,准确率就十分低了。
以上就是我正在入门pytorch的一些学习过程,还请各位指出其中错误指出,谢谢阅读!