python 复现 Unet 论文中的 Weight Map

 

相信大家对 Unet 论文中的 weight map 并不陌生。但是我翻遍 github上的几个 Unet 经典实现(包括论文开源的 caffe 版本),都没有找到损失加权图的实现代码。我今天尝试着实现了一下,得到了与论文中几乎一致的结果,分享给大家。转载请在文首注明出处,尊重原创,谢谢

 

作 者: 月牙眼的楼下小黑
联 系: zhanglf_tmac (WeChat)
声 明: 欢迎转载本文中的图片或文字,请说明出处


1. 实验图片的获取和加载

我没有论文中的数据集,我只是简单地从 pdf 中截取了示例图,并用 win10 自带的画图软件另存为单通道 bmp 文件。

实验图片

 

我习惯用 skimage 库的 io 函数加载图片:

from skimage import io
import matplotlib.pyplot as plt

gt = io.imread('bin_img.bmp')
gt = 1 * (gt >0)
plt.figure(figsize = (10,10))
plt.imshow(gt)
plt.show()

实验图的加载

2. 获得 class weight map

根据 unet 论文, class_weight_map 的作用是: to balance the class frequencies。 我们分别统计前景(即细胞)和 背景的像素数目 , 取倒数并归一化作为 class map。 这只是一种简单的做法,读者亦可设计更复杂的类别平衡权重计算机制。

import numpy as np

# 【1】计算细胞和背景的像素频率
c_weights = np.zeros(2)
c_weights[0] = 1.0 / ((gt == 0).sum())
c_weights[1] = 1.0 / ((gt == 1).sum())

# 【2】归一化
c_weights /= c_weights.max()

# 【3】得到 class_weight map(cw_map)
cw_map = np.where(gt==0, c_weights[0], c_weights[1])
plt.figure(figsize = (10,10))
plt.imshow(cw_map)
plt.show()

class_weight_map

3. 连通域分析

对每个实例(即每个细胞),我们需要获取只包含该实例的掩码。例如,如果一幅图像中有 2个细胞: 细胞 a, 细胞 b, 我们需要获得只包含细胞 a 的二值掩码图 msk1, 以及只包含细胞 b 的二值掩码图 msk2。所以我们需要对原图像进行连通域分析。

我们调用 skimage.measure下的 label函数实现二值图像的连通区域标记:

skimage.measure.label(image, connectivity= None) : connectivity=1代表4邻接, 2代表 8邻接

该函数会返回跟 image同等大小的标记矩阵,不同值的数字代表不同的连通域。

from skimage import measure,color

#【4】连通域分析,并彩色化
cells = measure.label(gt, connectivity=2)
cells_color = color.label2rgb(cells , bg_label = 0,  bg_color = (0, 0, 0)) 
plt.figure(figsize = (20,20))
plt.imshow(cells_color)
plt.show()

连通域分析并彩色化

4. 得到 distance_weight_map

解释代码是一件非常困难的事,有必要把公式再贴一遍,请读者对照代码和公式理解:

 

对一幅图像,我们可以得到一个二值掩码集合 :

{ msk1, msk2, ...mski,...mskn }

每个掩码只包含一个细胞 。 我们对每个二值掩码进行 距离变化, 使用 opencv 中的 distanceTransform 函数即可轻易实现。我们得到一个距离变换后的浮点图集合:

{dist1, dist2, ..., disti,...distn }, 其中 disti = distanceTransform(mski)

我们要求的 d1_map, 即公式中的 d1(x), 在 (i,j) 处的值为 :

d1_map(i,j) = min (dist1(i,j), dist2(i,j), ...distn(i,j))

笔者不知道求一个集合中 次最小 的数学运算符是什么,暂且用 mins 代替吧, 即对一个集合 amins(a) = min (a - {min(a)})

那么 d2_map, 即公式中的 d2(x), 在 (i,j) 处的值为 :

d2_map(i,j) = mins (dist1(i,j), dist2(i,j), ...distn(i,j))

得到 d1_map 和 d2_map, 我们不难获得对应公式中的第二项的 distance_weight_map

import cv2

#【5】计算得到 distance weight map (dw_map)
w0 = 10
sigma = 5
dw_map = np.zeros_like(gt)
maps = np.zeros((gt.shape[0], gt.shape[1], cells.max()))
if cells.max()>=2:
    for i in range(1, cells.max() + 1):
        maps[:,:,i-1] =  cv2.distanceTransform(1- (cells == i ).astype(np.uint8), cv2.DIST_L2, 3)
    maps = np.sort(maps, axis = 2)
    d1 = maps[:,:,0]
    d2 = maps[:,:,1]
    dis = ((d1 + d2)**2) / (2 * sigma * sigma)
    dw_map = w0*np.exp(-dis) * (cells == 0)
plt.figure(figsize = (10,10))
plt.imshow(dw_map, cmap = 'jet')
plt.show()
<think> 我们正在处理关于“weight reorder”的用户查询。用户希望了解其含义、应用及相关技术。 首先,我们需要进行网络搜索来收集有关“weight reorder”的信息。注意:用户可能指的是“weight reordering”或类似术语,但根据查询,我们将搜索“weight reorder”。 步骤: 1. 搜索“weight reorder meaning”或类似关键词,以了解其基本定义。 2. 搜索“weight reorder applications”来找出其应用场景。 3. 搜索“weight reorder techniques”或“weight reorder methods”来收集相关技术。 由于“weight reorder”可能是一个特定领域的术语,我们需要在深度学习、模型压缩或类似上下文中查找。 根据搜索,我们发现“weight reorder”通常与神经网络模型压缩和优化技术相关,特别是在模型部署和硬件加速方面。它可能指重新排列神经网络中的权重以改善内存访问模式、提高缓存利用率或方便特定硬件指令的处理。 具体信息整合如下: 含义: Weight reorder(权重重排序)是一种模型优化技术,通过改变神经网络权重在内存中的排列顺序,以提升推理效率。这种优化通常针对特定硬件架构,旨在减少内存访问延迟、提高数据局部性,从而加速计算。常见于卷积神经网络(CNN)中,其中权重张量的维度(例如输出通道、输入通道、高度、宽度)可以被重新排列。 应用: - 硬件加速:适配专用硬件(如NPU、GPU)的内存访问模式,减少缓存未命中。 - 模型压缩:与量化、剪枝等技术结合,在压缩模型的同时保持或提升推理速度。 - 推理优化:在移动端和边缘设备上部署模型时,通过权重重排序减少延迟和能耗。 相关技术: 1. **数据布局转换(Data Layout Transformation)**:将权重从NCHW格式转换为NHWC或其他格式,以适应硬件要求。 2. **权重量化(Weight Quantization)**:将浮点权重转换为低精度表示(如INT8),重排序可与量化协同优化。 3. **权重置换(Weight Permutation)**:改变权重矩阵的行列顺序,以改善矩阵乘法的计算效率。 4. **稀疏权重重新排序(Sparse Weight Reordering)**:针对稀疏模型(如剪枝后的模型),重新排列非零权重的位置以提升压缩率和计算效率。 示例:在TensorRT中,权重重排序被用于优化卷积层的执行。具体步骤包括对权重进行预处理,使其在内存中按特定顺序排列,以便在推理时实现更高效的访存。 注意:我们可能还会发现一些具体的实现方式或算法,例如在论文或开源库中。 下面,我们将按照要求生成答案,避免使用第一人称和步骤词汇,将答案组织为多个部分(含义、应用、相关技术),并在最后添加相关问题部分。 回答结构: 含义:... 应用:... 相关技术:... §§相关问题§§: ... </think> ### Weight Reorder的含义 Weight reorder(权重重排序)是一种神经网络模型优化技术,通过调整权重参数的存储顺序或排列方式,提升计算效率和硬件兼容性。其核心目标是: - **改善内存访问模式**:重新组织权重数据布局(如将NCHW格式转为NHWC),匹配硬件(如GPU/NPU)的内存连续访问特性,减少缓存未命中。 - **加速矩阵运算**:优化权重矩阵的稀疏性或分块结构,提高矩阵乘法(GEMM)的计算吞吐量。 - **压缩模型体积**:结合量化(如INT8)或剪枝技术,减少冗余权重存储空间。 --- ### Weight Reorder的应用场景 1. **深度学习推理加速** 在边缘设备(手机、IoT芯片)部署模型时,通过权重重排序适配硬件指令集(如ARM NEON),提升实时性。 **示例**:TensorRT对CNN权重进行通道重排,优化NVIDIA GPU的Tensor Core利用率。 2. **模型压缩与部署** 结合量化感知训练(QAT),将浮点权重重排序为低比特整数排列,降低存储需求。 **工具支持**:TVM、ONNX Runtime支持自动权重重排优化。 3. **稀疏模型优化** 对剪枝后的稀疏网络权重按非零元素密度重排序,提升压缩率(如CSR/CSC格式转换)。 --- ### 关键技术方法 #### 1. 数据布局转换(Data Layout Transformation) - 将权重张量维度从$N\times C\times H\times W$(NCHW)转换为$N\times H\times W\times C$(NHWC),匹配硬件内存对齐。 - **代码示例(PyTorch)**: ```python import torch weight = torch.randn(32, 3, 3, 3) # NCHW格式 reordered_weight = weight.permute(0, 2, 3, 1) # 转为NHWC ``` #### 2. 权重置换(Weight Permutation) - 利用置换矩阵$P$对全连接层权重$W$重排:$$W_{\text{new}} = P \cdot W \cdot P^T$$ - 优化目标:最大化权重矩阵对角线区块的稀疏性,减少计算量。 #### 3. 基于硬件的定制重排 - **GPU优化**:将卷积核权重按输出通道分组存储,适配warp级并行计算。 - **NPU优化**:根据存算一体架构特征,调整权重位宽排列顺序(如华为昇腾的权重交织技术)。 #### 4. 联合量化重排序 - 在量化过程中同步优化权重排列: ```python # 伪代码:量化感知重排序 quantized_weight = quantize(float_weight, scale) reordered_indices = find_optimal_permutation(quantized_weight) output = reorder(quantized_weight, reordered_indices) ``` ---
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值