MuSc: Zero-Shot Industrial Anomaly Classification and Segmentation with Mutual Scoring of the Unlabeled Images
代码
1. 特征提取
使用 CLIP 特征提取器 ViT-L-14-336,提取全局特征(图像级特征)和局部特征(各层的 patch token)。
# Fragment of the feature-extraction loop; one batch `image_info` is handled per iteration.
# When the batch is a dict, record image paths, masks and labels, then encode the
# images with the selected backbone (DINOv2, DINO or CLIP).
if isinstance(image_info, dict):
    image = image_info["image"]
    image_path_list.extend(image_info["image_path"])
    img_masks.append(image_info["mask"])
    # Per-image ground-truth anomaly labels.
    gt_list.extend(list(image_info["is_anomaly"].numpy()))
with torch.no_grad(), torch.cuda.amp.autocast():
    input_image = image.to(torch.float).to(self.device)
    if 'dinov2' in self.model_name:
        patch_tokens = self.dino_model.get_intermediate_layers(x=input_image, n=[l-1 for l in self.features_list], return_class_token=False)
        image_features = self.dino_model(input_image)
        patch_tokens = [patch_tokens[l].cpu() for l in range(len(self.features_list))]
        # Prepend an all-zero dummy class token so the token layout matches the other
        # branches (LNAMD._embed drops token 0 of every layer).
        fake_cls = [torch.zeros_like(p)[:, 0:1, :] for p in patch_tokens]
        patch_tokens = [torch.cat([fake_cls[i], patch_tokens[i]], dim=1) for i in range(len(patch_tokens))]
    elif 'dino' in self.model_name:
        patch_tokens_all = self.dino_model.get_intermediate_layers(x=input_image, n=max(self.features_list))
        image_features = self.dino_model(input_image)
        patch_tokens = [patch_tokens_all[l-1].cpu() for l in self.features_list]
    else:  # this write-up uses the CLIP feature extractor
        # CLIP encoder: first output is the global image feature, second is the list of
        # patch tokens of the selected layers.
        image_features, patch_tokens = self.clip_model.encode_image(input_image, self.features_list)
        # L2-normalise the global feature along its last dimension.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # One tensor per selected layer, each (batch, tokens, C); the write-up quotes
        # batch x 1370 x 1024 (1 class token + 37*37 patches) — input-size dependent, confirm.
        patch_tokens = [patch_tokens[l].cpu() for l in range(len(self.features_list))]
    # Split the batch of global features into one numpy vector per image,
    # e.g. a (batch, 768) tensor -> [array(768), array(768), ...].
    image_features = [image_features[bi].squeeze().cpu().numpy() for bi in range(image_features.shape[0])]
    # Global features of all images: one entry per image (N entries in total).
    class_tokens.extend(image_features)
    # Local features: one entry per batch; each entry is a list with one
    # (batch, tokens, C) tensor per selected layer.
    patch_tokens_list.append(patch_tokens)
2.LNAMD
在三个聚合度 r ∈ {1, 3, 5} 下分别对局部特征做邻域聚合。
# LNAMD: aggregate the patch tokens at every aggregation degree r in self.r_list.
# Channel dimension of the patch tokens (e.g. 1024).
feature_dim = patch_tokens_list[0][0].shape[-1]
# Accumulates one layer-averaged anomaly map per aggregation degree.
anomaly_maps_r = torch.tensor([]).double()
# Aggregation degrees, e.g. [1, 3, 5].
for r in self.r_list:
    start_time = time.time()
    print('aggregation degree: {}'.format(r))
    LNAMD_r = LNAMD(device=self.device, r=r, feature_dim=feature_dim, feature_layer=self.features_list)
    # Maps a layer index (as a string) to the list of aggregated per-batch features of that layer.
    Z_layers = {}
    for im in range(len(patch_tokens_list)):
        # Per-layer patch tokens of one batch, moved to the device.
        patch_tokens = [p.to(self.device) for p in patch_tokens_list[im]]
        with torch.no_grad(), torch.cuda.amp.autocast():
            # Aggregated features of shape (B, L, layers, C); _embed drops the class token.
            features = LNAMD_r._embed(patch_tokens)
            # L2-normalise every C-dimensional feature vector.
            features /= features.norm(dim=-1, keepdim=True)
            for l in range(len(self.features_list)):
                # Store this layer's aggregated features.
                if str(l) not in Z_layers.keys():
                    Z_layers[str(l)] = []
                Z_layers[str(l)].append(features[:, :, l, :])
    # Z_layers now has one key per layer ('0', '1', ...), each holding one (B, L, C) tensor per batch.
    end_time = time.time()
    print('LNAMD-{}: {}ms per image'.format(r, (end_time-start_time)*1000/subset_num))
LNAMD实现,对聚合度不为1的特征分块处理。
class LNAMD(torch.nn.Module):
    """Local neighbourhood aggregation of ViT patch tokens for one aggregation degree ``r``.

    Args:
        device: device on which the per-call LayerNorm in ``_embed`` is placed.
        feature_dim: output size of every per-layer pooling module in ``Preprocessing``.
        feature_layer: selected backbone layers; ``Preprocessing`` creates one pooling
            module per entry (the values themselves are not used there).
        r: aggregation degree, forwarded to ``PatchMaker``; r == 1 skips patchifying.
        patchstride: stride forwarded to ``PatchMaker``.
    """
    def __init__(self, device, feature_dim=1024, feature_layer=[1,2,3,4], r=3, patchstride=1):
        super(LNAMD, self).__init__()
        self.device = device
        self.r = r  # aggregation degree
        self.patch_maker = PatchMaker(r, stride=patchstride)  # patch splitter (defined elsewhere)
        self.LNA = Preprocessing(feature_layer, feature_dim)
    def _embed(self, features):
        """Aggregate per-layer patch tokens into one feature per patch location and layer.

        Args:
            features: list with one tensor per layer, each indexed as (batch, tokens, C).
                Token 0 is dropped as the class token; the remaining token count must be
                a perfect square because it is reshaped into a square grid.

        Returns:
            Detached CPU tensor reshaped to (B, L, num_layers, C_out), where C_out is the
            LNA output size (``feature_dim``) and L the number of patch locations on the
            first layer's grid.
        """
        B = features[0].shape[0]  # batch size
        features_layers = []
        # Turn every layer's tokens into a normalised (B, C, h, w) feature map.
        for feature in features:
            # reshape and layer normalization
            feature = feature[:, 1:, :] # remove the cls token
            feature = feature.reshape(feature.shape[0],
                                      int(math.sqrt(feature.shape[1])),
                                      int(math.sqrt(feature.shape[1])),
                                      feature.shape[2])
            feature = feature.permute(0, 3, 1, 2)  # (B, C, h, w)
            # LayerNorm over (C, h, w). The module is constructed anew on every call, so its
            # affine parameters are always at their initial values.
            feature = torch.nn.LayerNorm([feature.shape[1], feature.shape[2],
                                          feature.shape[3]]).to(self.device)(feature)
            features_layers.append(feature)
        if self.r != 1:  # degree 1 needs no neighbourhood extraction
            # divide into patches
            # patchify(..., return_spatial_info=True) is used as returning (patches, grid shape);
            # presumably patches are (B, h*w, C, r, r) — TODO confirm against PatchMaker.
            features_layers = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features_layers]
            patch_shapes = [x[1] for x in features_layers]  # per-layer patch grid (h, w)
            features_layers = [x[0] for x in features_layers]  # per-layer patches
        else:
            patch_shapes = [f.shape[-2:] for f in features_layers]  # per-layer grid (h, w)
            # Every location is its own 1x1 neighbourhood: (B, C, h, w) -> (B, h*w, C, 1, 1).
            features_layers = [f.reshape(f.shape[0], f.shape[1], -1, 1, 1).permute(0, 2, 1, 3, 4) for f in features_layers]
        ref_num_patches = patch_shapes[0]  # the first layer's grid is the reference resolution
        for i in range(1, len(features_layers)):
            patch_dims = patch_shapes[i]  # this layer's grid
            if patch_dims[0] == ref_num_patches[0] and patch_dims[1] == ref_num_patches[1]:
                continue  # already on the reference grid
            # Bilinearly resize this layer's patch grid to the reference grid.
            _features = features_layers[i]
            # Unflatten the location axis into the (h, w) grid.
            _features = _features.reshape(
                _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
            )
            # Move the grid axes to the end so F.interpolate can resize them.
            _features = _features.permute(0, -3, -2, -1, 1, 2)
            perm_base_shape = _features.shape
            _features = _features.reshape(-1, *_features.shape[-2:])
            _features = F.interpolate(
                _features.unsqueeze(1),
                size=(ref_num_patches[0], ref_num_patches[1]),
                mode="bilinear",
                align_corners=False,
            )
            _features = _features.squeeze(1)
            _features = _features.reshape(
                *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
            )
            # Move the grid axes back next to the batch axis and flatten them again.
            _features = _features.permute(0, -2, -1, 1, 2, 3)
            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
            features_layers[i] = _features
        # Merge batch and location axes: one neighbourhood block per row.
        features_layers = [x.reshape(-1, *x.shape[-3:]) for x in features_layers]
        # aggregation
        # LNA average-pools each neighbourhood block to a feature_dim vector per layer and
        # stacks the layers along dim 1 (resolution alignment happened above).
        features_layers = self.LNA(features_layers)
        features_layers = features_layers.reshape(B, -1, *features_layers.shape[-2:]) # (B, L, layer, C)
        return features_layers.detach().cpu()
LNA
对每个图像块的 r×r 邻域特征做平均池化,聚合为一个特征向量(每层分别处理后在层维度上堆叠);各层分辨率的对齐由 LNAMD._embed 中的插值完成。
class Preprocessing(torch.nn.Module):
    """LNA: pool every layer's neighbourhood features down to ``output_dim`` values.

    One ``MeanMapper`` is created per entry of ``input_layers`` (the entries are only
    counted). ``forward`` applies the i-th mapper to the i-th feature tensor and stacks
    the pooled results along dim 1.
    """

    def __init__(self, input_layers, output_dim):
        super().__init__()
        self.output_dim = output_dim
        self.preprocessing_modules = torch.nn.ModuleList(
            [MeanMapper(output_dim) for _ in input_layers]
        )

    def forward(self, features):
        # e.g. per layer [B*L, C, r, r] -> [B*L, output_dim]; stacked -> [B*L, layers, output_dim]
        pooled = [mapper(feat) for mapper, feat in zip(self.preprocessing_modules, features)]
        return torch.stack(pooled, dim=1)
class MeanMapper(torch.nn.Module):
    """Flatten every row of the input and adaptive-average-pool it to ``preprocessing_dim`` values.

    For a (rows, C, r, r) input with C == preprocessing_dim, each output value is the
    mean of one channel's r*r block, because the flattened layout is channel-major.
    """

    def __init__(self, preprocessing_dim):
        super().__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        num_rows = len(features)
        flat = features.reshape(num_rows, 1, -1)
        pooled = F.adaptive_avg_pool1d(flat, self.preprocessing_dim)
        return pooled.squeeze(1)
3. MSM
通过计算图像块对其他图像的所有图像块的距离获得异常得分。
# --- MSM: this part runs once per aggregation degree r, inside the `for r in self.r_list:` loop ---
# Per-layer anomaly maps for the current degree, stacked along dim 0.
anomaly_maps_l = torch.tensor([]).double()
start_time = time.time()
# Iterate over all layers.
for l in Z_layers.keys():
    # All aggregated features of layer l, shape (N, L, C), e.g. [83, 1369, 1024].
    Z = torch.cat(Z_layers[l], dim=0).to(self.device)
    print('layer-{} mutual scoring...'.format(l))
    # Patch-level anomaly scores of every image for this layer, shape (N, L).
    anomaly_maps_msm = MSM(Z=Z, device=self.device, topmin_min=0, topmin_max=0.3)
    # Stack this layer's map onto the per-layer maps.
    anomaly_maps_l = torch.cat((anomaly_maps_l, anomaly_maps_msm.unsqueeze(0).cpu()), dim=0)
    torch.cuda.empty_cache()
# Average over layers -> (N, L).
anomaly_maps_l = torch.mean(anomaly_maps_l, 0)
# Stack onto the maps of the previous degrees -> (num_r, N, L).
anomaly_maps_r = torch.cat((anomaly_maps_r, anomaly_maps_l.unsqueeze(0)), dim=0)
end_time = time.time()
print('MSM: {}ms per image'.format((end_time-start_time)*1000/subset_num))
# --- after the loop over r ---
# Average (not concatenate) the anomaly maps of all aggregation degrees -> (N, L).
anomaly_maps_iter = torch.mean(anomaly_maps_r, 0).to(self.device)
del anomaly_maps_r
torch.cuda.empty_cache()
B, L = anomaly_maps_iter.shape
# Side length of the patch grid; assumes L is a perfect square.
H = int(np.sqrt(L))
# Upsample the patch-level maps to self.image_size.
anomaly_maps_iter = F.interpolate(anomaly_maps_iter.view(B, 1, H, H),
                                  size=self.image_size, mode='bilinear', align_corners=True)
# Append to the maps accumulated so far (`anomaly_maps` is defined outside this fragment).
anomaly_maps = torch.cat((anomaly_maps, anomaly_maps_iter.cpu()), dim=0)
源码中给出了两种实现方式,这里只展示较快的 compute_scores_fast,其计算步骤如下:
首先计算图像块到所有图像(除本身)的所有图像块的距离,输出num_patch x (N-1) x num_patch。
接着取图像块到所有图像的最相似图像块的距离,输出num_patch x (N-1)。
对每个图像块,只保留距离最小的 30% 参考图像(即按距离升序取前 30%),输出 num_patch x ⌊0.3(N-1)⌋。
对这些图像的距离取均值,作为每个图像块的异常得分,输出 num_patch。
def compute_scores_fast(Z, i, device, topmin_min=0, topmin_max=0.3):
    """Score every patch of image ``i`` by mutual scoring against all other images.

    For each patch of image ``i``, the distance to its nearest patch in every other
    image is computed. Those per-image distances are sorted ascending, the interval
    ``[k_min, k_max)`` of that ordering is kept, and the kept distances are averaged.

    Args:
        Z: tensor of shape (N, patch_num, C) with the features of all N images.
        i: index of the image to score.
        device: unused; kept for interface compatibility.
        topmin_min, topmin_max: interval bounds. Values < 1 are fractions of the
            N - 1 reference images; values >= 1 are absolute counts. The bounds are
            swapped if given in the wrong order. They are then clamped so that at
            least one and at most N - 1 reference images are averaged. Without the
            clamp, a small N (e.g. int(0.3 * 2) == 0) would give an empty interval
            and a NaN mean.

    Returns:
        float32 tensor of shape (patch_num,).

    Raises:
        ValueError: if ``Z`` holds fewer than two images (nothing to score against).
    """
    image_num, patch_num, c = Z.shape
    ref_num = image_num - 1
    if ref_num < 1:
        raise ValueError(
            "mutual scoring needs at least two images, got {}".format(image_num))
    # All images except the i-th one serve as references.
    Z_ref = torch.cat((Z[:i], Z[i + 1:]), dim=0)
    # Distance of every patch of image i to every reference patch -> [patch_num, N-1, patch_num].
    patch2image = torch.cdist(Z[i], Z_ref.reshape(-1, c)).reshape(patch_num, ref_num, patch_num)
    # Nearest-patch distance per reference image -> [patch_num, N-1].
    patch2image = torch.min(patch2image, -1)[0]

    k_max = topmin_max
    k_min = topmin_min
    if k_max < 1:
        k_max = int(ref_num * k_max)
    if k_min < 1:
        k_min = int(ref_num * k_min)
    if k_max < k_min:
        k_max, k_min = k_min, k_max
    # Keep the interval non-empty and within the available reference images.
    k_max = max(1, min(int(k_max), ref_num))
    k_min = max(0, min(int(k_min), k_max - 1))

    # The k_max smallest distances, then drop the k_min smallest of those.
    vals, _ = torch.topk(patch2image.float(), k_max, largest=False, sorted=True)
    vals, _ = torch.topk(vals, k_max - k_min, largest=True, sorted=True)
    # Interval mean is the anomaly score of each patch -> [patch_num].
    return torch.mean(vals, dim=1)
def MSM(Z, device, topmin_min=0, topmin_max=0.3):
    """Mutual scoring: patch-level anomaly scores for every image in ``Z``.

    Args:
        Z: tensor of shape (N, patch_num, C); each image is scored against the others
            by ``compute_scores_fast``.
        device: target device of the returned matrix.
        topmin_min, topmin_max: interval bounds forwarded to ``compute_scores_fast``.

    Returns:
        float64 tensor of shape (N, patch_num) on ``device``. An empty ``Z`` yields an
        empty 1-D tensor, as before.
    """
    # Collect the rows and stack them once. Growing the matrix with torch.cat inside
    # the loop would re-copy all previous rows on every iteration (quadratic in N).
    rows = []
    for i in tqdm(range(Z.shape[0])):
        # Scores of all patches of image i -> [patch_num].
        rows.append(compute_scores_fast(Z, i, device, topmin_min, topmin_max).double())
    if not rows:
        return torch.tensor([]).double().to(device)
    return torch.stack(rows, dim=0).to(device)
4.RsCIN
以异常图的最大值作为图像级异常得分,再利用图像间全局特征的相似度优化该得分。
代码中的class_tokens取自特征提取器中所提取的特征。
B = anomaly_maps.shape[0]
# Image-level score: maximum of each image's anomaly map -> (N,).
ac_score = np.array(anomaly_maps).reshape(B, -1).max(-1)
# RsCIN neighbourhood sizes are chosen per dataset.
if self.dataset == 'visa':
    k_score = [1, 8, 9]
elif self.dataset == 'mvtec_ad':
    k_score = [1, 2, 3]
else:
    # NOTE: identical to the mvtec_ad setting.
    k_score = [1, 2, 3]
# Refine the scores. class_tokens holds the global image features collected during
# feature extraction (one vector per image). Output shape (N,).
scores_cls = RsCIN(ac_score, class_tokens, k_list=k_score)
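MMO:对每个 k ∈ k_list,相似度矩阵的每一行只保留最大的 k 个相似度(其余置 0),按行归一化后对异常得分做加权平均,最后对所有 k 的结果取均值。这样每幅图像的得分由与其最相似的若干图像的得分平滑得到,从而更好地区分正常和异常样本。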
def MMO(W, score, k_list=(1, 2, 3)):
    """Smooth image-level scores over k-nearest-neighbour similarity graphs.

    For every k in ``k_list``:
    - each row of ``W`` keeps only its k largest similarities (the rest are set to 0);
    - the row is normalised to sum to 1;
    - each score is replaced by the resulting weighted average of all scores.
    The per-k results are then averaged.

    Args:
        W: square (N, N) similarity matrix (any floating dtype).
        score: (N,) tensor of scores.
        k_list: neighbourhood sizes, each in [1, N]. The default is a tuple to avoid a
            shared mutable default; any iterable of ints works.

    Returns:
        (N,) tensor of smoothed scores.

    Raises:
        ValueError: if ``W`` is not square, ``k_list`` is empty, or a k is outside [1, N].
    """
    if W.dim() != 2 or W.shape[0] != W.shape[1]:
        raise ValueError("W must be a square 2-D matrix, got shape {}".format(tuple(W.shape)))
    n = W.shape[0]
    k_list = list(k_list)
    if not k_list:
        raise ValueError("k_list must contain at least one neighbourhood size")
    S_list = []
    for k in k_list:
        if not 1 <= k <= n:
            # k == 0 would zero every row (division by zero -> NaN); k > n breaks topk.
            raise ValueError("each k must satisfy 1 <= k <= {}, got {}".format(n, k))
        # Indices of the n-k smallest similarities per row; zeroing them keeps the k largest.
        _, drop_idx = torch.topk(W.float(), n - k, largest=False, sorted=True)
        W_mask = W.clone()
        W_mask.scatter_(1, drop_idx, 0)
        # Row-normalise into transition probabilities (same as D^-1 @ W_mask).
        P = W_mask / W_mask.sum(dim=1, keepdim=True)
        # Weighted average of the scores over each image's kept neighbours.
        S_list.append(P @ score.unsqueeze(-1).to(P.dtype))
    return torch.cat(S_list, dim=-1).mean(dim=-1)
def RsCIN(scores_old, cls_tokens=None, k_list=(0,)):
    """Refine image-level anomaly scores using similarities between global image features.

    Args:
        scores_old: (N,) array-like of image-level scores.
        cls_tokens: sequence of N global feature vectors; ``None`` disables refinement.
        k_list: neighbourhood sizes forwarded to ``MMO``; containing 0 disables
            refinement. The default is a tuple to avoid a shared mutable default.

    Returns:
        ``scores_old`` unchanged when refinement is disabled. Otherwise a float32 numpy
        array of shape (N,) holding the min-max-normalised, MMO-smoothed scores. If all
        scores are equal, min-max normalisation would divide by zero and produce NaN;
        zeros are returned instead.
    """
    if cls_tokens is None or 0 in k_list:
        return scores_old
    scores_old = np.asarray(scores_old, dtype=np.float64)
    lo = scores_old.min()
    span = scores_old.max() - lo
    if span == 0:
        return np.zeros(scores_old.shape, dtype=np.float32)
    # Min-max normalise the scores to [0, 1].
    scores = (scores_old - lo) / span
    # Compute the similarities in float32 even if the features arrive in lower precision
    # (NOTE(review): they may be fp16 when extracted under autocast — confirm).
    tokens = np.asarray(cls_tokens, dtype=np.float32)
    similarity_matrix = torch.from_numpy(tokens @ tokens.T)  # (N, N)
    scores_new = MMO(similarity_matrix.float(), score=torch.from_numpy(scores).float(), k_list=k_list)
    return scores_new.numpy()