MuSc代码阅读记录

MUSC: ZERO-SHOT INDUSTRIAL ANOMALY CLASSIFICATION AND SEGMENTATION WITH MUTUAL SCORING OF THE UNLABELED IMAGES

代码

GitHub - xrli-U/MuSc: This is an official PyTorch implementation for "MuSc : Zero-Shot Industrial Anomaly Classification and Segmentation with Mutual Scoring of the Unlabeled Images" (MuSc ICLR2024).

1. 特征提取

使用clip特征提取器ViT-L-14-336,提取全局特征和局部特征。

if isinstance(image_info, dict):
    image = image_info["image"]
    image_path_list.extend(image_info["image_path"])
    img_masks.append(image_info["mask"])
    gt_list.extend(list(image_info["is_anomaly"].numpy()))
with torch.no_grad(), torch.cuda.amp.autocast():
    input_image = image.to(torch.float).to(self.device)
    if 'dinov2' in self.model_name:      
        patch_tokens = self.dino_model.get_intermediate_layers(x=input_image, n=[l-1 for l in self.features_list], return_class_token=False)
        image_features = self.dino_model(input_image)
        patch_tokens = [patch_tokens[l].cpu() for l in range(len(self.features_list))]
        fake_cls = [torch.zeros_like(p)[:, 0:1, :] for p in patch_tokens]
        patch_tokens = [torch.cat([fake_cls[i], patch_tokens[i]], dim=1) for i in range(len(patch_tokens))]
    elif 'dino' in self.model_name:
        patch_tokens_all = self.dino_model.get_intermediate_layers(x=input_image, n=max(self.features_list))
        image_features = self.dino_model(input_image)
        patch_tokens = [patch_tokens_all[l-1].cpu() for l in self.features_list]
    else: # 本文使用的是clip特征提取器
        # 使用CLIP特征提取器,返回第一项为图像特征,第二项为局部特征
        image_features, patch_tokens = self.clip_model.encode_image(input_image, self.features_list)                        
        # 对全局特征最后一个维度计算L2范数,并进行归一化
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # 局部特征,形状为batch_size x 1370 x 1024
        patch_tokens = [patch_tokens[l].cpu() for l in range(len(self.features_list))]
# 将特征转换为列表tensor([4, 769]) -> [numpy(768), numpy(768), numpy(768), numpy(768)]
image_features = [image_features[bi].squeeze().cpu().numpy() for bi in range(image_features.shape[0])]
# 图片的全局特征,长度为N图片数,元素为numpy张量(768)
class_tokens.extend(image_features)
# 图片的局部特征,长度为N/batch_size,每个元素为列表(长度为batch_size,每个元素为tensor张量[4, 1370, 1024],代表四个不同尺度的特征)
patch_tokens_list.append(patch_tokens)

2.LNAMD

分三个聚合度{1,3,5}进行特征提取。

# LNAMD
# 特征维度,1024
feature_dim = patch_tokens_list[0][0].shape[-1]
anomaly_maps_r = torch.tensor([]).double()
# r_list聚合度列表[1, 3, 5]
for r in self.r_list:
    start_time = time.time()
    print('aggregation degree: {}'.format(r))
    LNAMD_r = LNAMD(device=self.device, r=r, feature_dim=feature_dim, feature_layer=self.features_list)
    # 创建空字典,键是层级索引,值是该层的所有特征列表
    Z_layers = {}
    for im in range(len(patch_tokens_list)):
        # 获取特征块,大小为3 x 1370 x 1024
        patch_tokens = [p.to(self.device) for p in patch_tokens_list[im]]
        with torch.no_grad(), torch.cuda.amp.autocast():
            # 聚合模块,聚合后特征为 [3, 1369, 4, 1024] 去掉了类别维度
            features = LNAMD_r._embed(patch_tokens)
            features /= features.norm(dim=-1, keepdim=True)
            for l in range(len(self.features_list)):
                # 存储聚合后的特征
                if str(l) not in Z_layers.keys():
                    Z_layers[str(l)] = []
                Z_layers[str(l)].append(features[:, :, l, :])
                # 最终输出的Z有0123四个键,值为长度为21的特征列表,特征形状为4 x 1369 x 1024
    end_time = time.time()
    print('LNAMD-{}: {}ms per image'.format(r, (end_time-start_time)*1000/subset_num))

LNAMD实现,对聚合度不为1的特征分块处理。 

class LNAMD(torch.nn.Module):
    def __init__(self, device, feature_dim=1024, feature_layer=[1,2,3,4], r=3, patchstride=1):
        super(LNAMD, self).__init__()
        self.device = device
        self.r = r # 聚合度
        self.patch_maker = PatchMaker(r, stride=patchstride) # 切块函数
        self.LNA = Preprocessing(feature_layer, feature_dim)

    def _embed(self, features):
        B = features[0].shape[0]

        features_layers = []
        # 对特征进行形状变换
        for feature in features:
            # reshape and layer normalization
            feature = feature[:, 1:, :] # remove the cls token
            feature = feature.reshape(feature.shape[0],
                                      int(math.sqrt(feature.shape[1])),
                                      int(math.sqrt(feature.shape[1])),
                                      feature.shape[2])
            feature = feature.permute(0, 3, 1, 2)
            feature = torch.nn.LayerNorm([feature.shape[1], feature.shape[2],
                                          feature.shape[3]]).to(self.device)(feature)
            features_layers.append(feature)

        if self.r != 1: # 聚合度为1则不需要分块
            # divide into patches
            features_layers = [self.patch_maker.patchify(x, return_spatial_info=True) for x in features_layers] # 切块
            patch_shapes = [x[1] for x in features_layers] # 存储特征块的高宽信息
            features_layers = [x[0] for x in features_layers] # 存储分块后的特征
        else:
            patch_shapes = [f.shape[-2:] for f in features_layers] # 长度为batch,每个元素为块的wxh,存储每个块的高宽信息
            # 存储分块后的特征
            features_layers = [f.reshape(f.shape[0], f.shape[1], -1, 1, 1).permute(0, 2, 1, 3, 4) for f in features_layers]

        ref_num_patches = patch_shapes[0] # 将第一层特征分辨率设置为参考分辨率
        for i in range(1, len(features_layers)):
            patch_dims = patch_shapes[i] # 获取当前层特征分辨率
            if patch_dims[0] == ref_num_patches[0] and patch_dims[1] == ref_num_patches[1]:
                continue # 如果和参考分辨率一致则不需要处理
            # 插值对齐分辨率
            _features = features_layers[i]
            _features = _features.reshape(
                _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
            )
            _features = _features.permute(0, -3, -2, -1, 1, 2)
            perm_base_shape = _features.shape
            _features = _features.reshape(-1, *_features.shape[-2:])
            _features = F.interpolate(
                _features.unsqueeze(1),
                size=(ref_num_patches[0], ref_num_patches[1]),
                mode="bilinear",
                align_corners=False,
            )
            _features = _features.squeeze(1)
            _features = _features.reshape(
                *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
            )
            _features = _features.permute(0, -2, -1, 1, 2, 3)
            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
            features_layers[i] = _features
        features_layers = [x.reshape(-1, *x.shape[-3:]) for x in features_layers]
        
        # aggregation
        features_layers = self.LNA(features_layers) # 将特征采样到同一分辨率
        features_layers = features_layers.reshape(B, -1, *features_layers.shape[-2:])   # (B, L, layer, C)

        return features_layers.detach().cpu()

LNA

用于统一特征分辨率

class Preprocessing(torch.nn.Module):
    def __init__(self, input_layers, output_dim):
        super(Preprocessing, self).__init__()
        self.output_dim = output_dim
        self.preprocessing_modules = torch.nn.ModuleList()
        # 根据特征层数添加平均层
        for input_layer in input_layers:
            module = MeanMapper(output_dim)
            self.preprocessing_modules.append(module)

    def forward(self, features):
        # 对特征层进行下采样 [batchx1369, 1024, 5, 5] -> [batchx1369, 1024]
        _features = []
        for module, feature in zip(self.preprocessing_modules, features):
            _features.append(module(feature))
        return torch.stack(_features, dim=1)


class MeanMapper(torch.nn.Module):
    def __init__(self, preprocessing_dim):
        super(MeanMapper, self).__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        features = features.reshape(len(features), 1, -1)
        return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) # 平均池化

3. MSM

通过计算图像块对其他图像的所有图像块的距离获得异常得分。

    anomaly_maps_l = torch.tensor([]).double()
    start_time = time.time()
    # 遍历所有层的特征
    for l in Z_layers.keys():
        # 将第l层上所有特征拼接,形状为[83, 1369, 1024]
        Z = torch.cat(Z_layers[l], dim=0).to(self.device)
        print('layer-{} mutual scoring...'.format(l))
        anomaly_maps_msm = MSM(Z=Z, device=self.device, topmin_min=0, topmin_max=0.3)
        # 将第l层计算的异常图加入最终异常图中
        anomaly_maps_l = torch.cat((anomaly_maps_l, anomaly_maps_msm.unsqueeze(0).cpu()), dim=0)
        torch.cuda.empty_cache()
    # 计算所有层异常得分图的平均 [N, patch_num]
    anomaly_maps_l = torch.mean(anomaly_maps_l, 0)
    # 将所有聚合度计算得到的异常图拼接 [r, N, patch_num]
    anomaly_maps_r = torch.cat((anomaly_maps_r, anomaly_maps_l.unsqueeze(0)), dim=0)
    end_time = time.time()
    print('MSM: {}ms per image'.format((end_time-start_time)*1000/subset_num))
# 拼接所有聚合度的异常图
anomaly_maps_iter = torch.mean(anomaly_maps_r, 0).to(self.device)
del anomaly_maps_r
torch.cuda.empty_cache()

B, L = anomaly_maps_iter.shape
H = int(np.sqrt(L))
# 对特征计算得到的异常图上采样到输入图像的分辨率
anomaly_maps_iter = F.interpolate(anomaly_maps_iter.view(B, 1, H, H),
                            size=self.image_size, mode='bilinear', align_corners=True)
anomaly_maps = torch.cat((anomaly_maps, anomaly_maps_iter.cpu()), dim=0)

源码中给出了两种实现方式。

首先计算图像块到所有图像(除本身)的所有图像块的距离,输出num_patch x (N-1) x num_patch。

接着取图像块到所有图像的最相似图像块的距离,输出num_patch x (N-1)。

只取距离在前30%的图像,输出num_patch x 0.3(N-1)。

取到30%图像距离的均值,作为每个图像块的异常得分,输出num_patch。

def compute_scores_fast(Z, i, device, topmin_min=0, topmin_max=0.3):
    image_num, patch_num, c = Z.shape
    patch2image = torch.tensor([]).to(device)
    Z_ref = torch.cat((Z[:i], Z[i+1:]), dim=0) # 将要计算的第i张图特征排除
    # 计算图像块到所有其他图像的图像块的距离,输出[patch_num, N-1, patch_num]
    patch2image = torch.cdist(Z[i:i+1], Z_ref.reshape(-1, c)).reshape(patch_num, image_num-1, patch_num)
    # 取图像块到一幅图像中所有图像块距离的最小值,输出[patch_num, N-1]
    patch2image = torch.min(patch2image, -1)[0]
    # 设置取值范围
    k_max = topmin_max
    k_min = topmin_min
    if k_max < 1:
        k_max = int(patch2image.shape[1]*k_max)
    if k_min < 1:
        k_min = int(patch2image.shape[1]*k_min)
    if k_max < k_min:
        k_max, k_min = k_min, k_max
    # 取异常得分升序排列前30%的块,输出[patch_num, (N-1) x 0.3]
    vals, _ = torch.topk(patch2image.float(), k_max, largest=False, sorted=True)
    vals, _ = torch.topk(vals.float(), k_max-k_min, largest=True, sorted=True)
    patch2image = vals.clone()
    # 取平均,结果为每个块的异常得分,输出[patch_num]
    return torch.mean(patch2image, dim=1)

def MSM(Z, device, topmin_min=0, topmin_max=0.3):
    anomaly_scores_matrix = torch.tensor([]).double().to(device)
    for i in tqdm(range(Z.shape[0])): # Z[N x feature_dim x wh]
        # 对于一幅图像,计算其所有块的异常得分,输出[patch_num]
        anomaly_scores_i = compute_scores_fast(Z, i, device, topmin_min, topmin_max).unsqueeze(0)
        # 将所有图像计算的异常得分拼接,输出[N, patch_num]
        anomaly_scores_matrix = torch.cat((anomaly_scores_matrix, anomaly_scores_i.double()), dim=0)    # (N, B)
    return anomaly_scores_matrix

4.RsCIN

计算异常得分,优化异常得分。

代码中的class_tokens取自特征提取器中所提取的特征。

B = anomaly_maps.shape[0]
# 对异常图取最大值计算异常得分 [N]
ac_score = np.array(anomaly_maps).reshape(B, -1).max(-1)
# RsCIN根据数据集设置超参数
if self.dataset == 'visa':
    k_score = [1, 8, 9]
elif self.dataset == 'mvtec_ad':
    k_score = [1, 2, 3]
else:
    k_score = [1, 2, 3]
# 优化异常得分,class_tokens为特征提取过程中的图像特征(长度为N的列表,特征维度为768), 输出 [N]
scores_cls = RsCIN(ac_score, class_tokens, k_list=k_score)

降低异常样本的得分,从而区分正常和异常样本。 

def MMO(W, score, k_list=[1, 2, 3]):
    S_list = []
    for k in k_list:
        # 将相似矩阵W中最小的W.shape[0]-k个值置0
        _, topk_matrix = torch.topk(W.float(), W.shape[0]-k, largest=False, sorted=True)
        W_mask = W.clone()
        for i in range(W.shape[0]):
            W_mask[i, topk_matrix[i]] = 0
        n = W.shape[-1]
        # 计算对角矩阵和概率转移矩阵
        D_ = torch.zeros_like(W).float()
        for i in range(n):
            # 计算每一行的权重和,将其倒数存储在对角线上,用于对样本权重进行缩放
            D_[i, i] = 1 / (W_mask[i,:].sum())
        P = D_ @ W_mask # 将每一行归一化为概率分布
        S = score.clone().unsqueeze(-1)
        # 通过矩阵乘法更新异常得分
        S = P @ S
        S_list.append(S)
    S = torch.concat(S_list, -1).mean(-1)
    return S

def RsCIN(scores_old, cls_tokens=None, k_list=[0]):
    if cls_tokens is None or 0 in k_list:
        return scores_old
    cls_tokens = np.array(cls_tokens) # 类型转换
    # 对异常得分归一化,使其范围在[0, 1]之间
    scores = (scores_old - scores_old.min()) / (scores_old.max() - scores_old.min())
    # 与转置相乘,计算相似矩阵,输出[N, N]
    similarity_matrix = cls_tokens @ cls_tokens.T
    similarity_matrix = torch.tensor(similarity_matrix)
    # 输入归一化得分和相似度矩阵
    scores_new = MMO(similarity_matrix.clone().float(), score=torch.tensor(scores).clone().float(), k_list=k_list)
    scores_new = scores_new.numpy()
    return scores_new

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值