第一步:准备数据
4种肺部数据,新冠肺炎阳性病例,正常、肺部不透明(非COVID肺部感染)和病毒性肺炎图像:
self.class_indict = ["COVID", "Lung_Opacity", "Normal", "Viral_Pneumonia"]
,总共有21161张图片,每个文件夹单独放一种
第二步:搭建模型
本文选择MobileNetV2,其网络结构如下:
第三步:训练代码
1)损失函数为:交叉熵损失函数
2)MobileNetV2代码:
from torch import nn
import torch
def _make_divisible(ch, divisor=8, min_ch=None):
"""
This function is taken from the original tf repo.
It ensures th