Lite-HRNet: A Lightweight High-Resolution Network原理解析及代码分享

论文《Lite-HRNet: A Lightweight High-Resolution Network》提出者用于人体姿态估计,本文将其用于语义分割训练,损失无法下降,基本全无法提取,本文仅作为结构学习记录,若有读者找出原因请不吝指教。

实时语义分割之BiSeNetv2(2020)结构原理解析及建筑物提取实践

概要

Lite-HRNet是针对实时语义分割任务优化的轻量级网络,其核心思想是在高分辨率特征保持与计算效率之间取得平衡。相比原版HRNet,通过引入多分辨率特征交互机制与通道重参数化技术。
提示: (1)论文地址:https://arxiv.org/abs/2104.06403 (2)Github网址:https://github.com/HRNet/Lite-HRNet (3)本人打包好数据可运行代码:Lite-HRNet.zip 链接: https://pan.baidu.com/s/1ZJpGchNDfLMosCOBNKt65A?pwd=v7yf 提取码: v7yf

理论知识

整体架构流程

大概流程
LiteHRNet模型结构示意如下图2。灰色块状部分为stage2-4,它由一个高分辨率的主干作为第一阶段,逐渐添加高到低分辨率流作为主体。主体包含一系列阶段,每个阶段包含并行的多分辨率流和重复的多分辨率融合。
在这里插入图片描述

class LiteHRNet(nn.Module):
    def __init__(self, num_class=1, n_channel=3, base_ch=40, arch_type='litehrnet18',  repeat=2, act_type='relu'):

        ......
        self.stem = nn.Sequential(
                        ConvBNAct(n_channel, 32, 3, 2, act_type=act_type),
                        ShuffleBlock(32, base_ch, 2, act_type)
                    )
        self.stage1_down = DSConvBNAct(base_ch, base_ch*2, 3, 2, act_type=act_type)
        self.stage2 = StageBlock(base_ch, 2, repeat, num_modules[0], act_type)
        self.stage3 = StageBlock(base_ch, 3, repeat, num_modules[1], act_type)
        self.stage4 = StageBlock(base_ch, 4, repeat, num_modules[2], act_type)
        self.rep_head = RepresentationHead(base_ch, num_class, 4, act_type)

    def forward(self, x):
        x = self.stem(x)
        x2 = self.stage1_down(x)
        feats = [x, x2]
        feats = self.stage2(feats)
        feats = self.stage3(feats)
        feats = self.stage4(feats)
        x = self.rep_head(feats)
        ......
        return x


关键模块解析

模块名称功能描述结构在代码定义参数量占比
Stem模块初始特征提取与通道扩展class ShuffleBlock8%
StageBlock多分辨率特征交互核心单元class StageBlock65%
CrossResolutionWeight跨分辨率特征权重生成class CrossResolutionWeightModule12%
CCWBlock通道-空间联合注意力计算class CCWBlock10%
FusionBlock多分支特征融合class FusionBlock5%

以下关键模块解析时只取形,舍去具体操作。

Stem模块:轻量级数据输入设计(ShuffleBlock)

class ShuffleBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride, act_type):
        # 通道分割:将输入分为左右两部分
        in_ch_l = in_ch//2
        self.left_branch = ConvBNAct(in_ch_l, out_ch//2, 1, stride) 
        self.right_branch = nn.Sequential(
            ConvBNAct(in_ch-in_ch_l, out_ch//2, 1),
            DWConvBNAct(..., stride),  # 深度可分离卷积
            ConvBNAct(...)
        )

设计:

  • 通道分割策略:将输入通道对半拆分,左分支保持原分辨率,右分支进行深度卷积下采样,减少50%计算量。

  • 通道混洗(Channel Shuffle):促进左右分支间的信息交互,避免特征僵化。

  • 深度可分离卷积:相比标准卷积减少8倍参数量。

StageBlock:多分辨率特征交互

class StageBlock(nn.Module):
    def __init__(self, base_ch, stage, repeat, num_modules, act_type):
        super().__init__()
            crw_module = CrossResolutionWeightModule(crw_ch, act_type)
            ccw_block = nn.ModuleList([CCWBlock(chs[j], chs[j], 1, act_type) for _ in range(repeat)])
            # Fusion Block
            fusion_block = FusionBlock(base_ch, stage, extra_output, act_type)
            self.stage_blocks.append(crw_module)
            self.stage_blocks.append(ccw_blocks)
            self.stage_blocks.append(fusion_block)
            ....


    def forward(self, feats):
        for i in range(len(self.stage_blocks) // 3):
            crw_module = self.stage_blocks[i*3]
            ccw_blocks = self.stage_blocks[i*3+1]
            fusion_block = self.stage_blocks[i*3+2]
            cr_weight = crw_module(feats)
            for j, ccw_block in enumerate(ccw_blocks):
                for m in ccw_block:
                    feats[j] = m(feats[j], cr_weight[j])
            feats = fusion_block(feats)
        return feats   

特征提取时,StageBlock 类依次通过3个模块进行特征处理。

  • 跨分辨率权重生成: 通过 CrossResolutionWeightModule 生成跨分辨率权重。
  • 特征处理: 对于每个阶段的 CCWBlock,使用生成的权重处理输入特征。
  • 特征融合: 最后,通过融合块fusion_block将处理后的特征进行合并。

CCWBlock:双路特征增强单元

class CCWBlock(nn.Module):
    def forward(self, feats, cr_weight):
        # 特征分割
        feats_l, feats_r = split(feats)  
        
        # 左分支:恒等映射或1x1卷积
        feats_l = self.left_branch(feats_l)  
        
        # 右分支:空间权重调制
        feats_r = feats_r * cr_weight  # 跨分辨率权重
        feats_r = self.right_branch(feats_r)  
        spatial_weight = self.sw(feats_r)  # 空间注意力
        feats_r = feats_r * spatial_weight
  
        return shuffle(concat(feats_l, feats_r))

设计:

  • 跨分辨率权重融合:接收来自CrossResolutionWeight模块的权重图,实现多尺度特征校准。

  • 空间注意力机制:通过SpatialWeightModule动态调整特征图各位置的重要性。

  • 双路异构设计:左路保持特征完整性,右路增强局部细节,兼顾效率与精度。

CrossResolutionWeight:多尺度关联模块

class CrossResolutionWeightModule(nn.Module):
    def forward(self, feats):
        # 多分辨率特征池化
        pooled_feats = [adaptive_pool(feat) for feat in feats]  
        concat_feat = torch.cat(pooled_feats, dim=1)
        
        # 权重生成
        weight = self.conv(concat_feat)  # 1x1卷积+sigmoid
        return split(weight)

设计:

  • 自适应池化层:将不同分辨率的特征图统一到相同尺度。

  • 通道压缩:采用8倍通道缩减比,生成紧凑的权重表示。

  • 动态权重分配:为每个分辨率分支生成独立的注意力图。

FusionBlock:多流特征融合


class FusionBlock(nn.Module):
    def __init__(self, base_ch, stage, extra_output, act_type):
        # 构建上采样和下采样流
        self.stream1 = [DownsampleBlock(...) for _ in ...]  
        self.stream2 = [UpsampleBlock(...) for _ in ...]
        
    def forward(self, feats):
        # 多分辨率特征融合
        x1 = stream1[0](feats[0]) + stream2[0](feats[1])
        x2 = stream1[1](feats[0]) + stream2[1](feats[1]) 
        ...

设计:

  • 双向特征流:包含上采样(UpsampleBlock)和下采样(DownsampleBlock)两条通路。

  • 残差连接结构:通过Element-wise Add实现多尺度特征融合。

  • 可扩展架构:支持2-4个输入分支的动态融合。

模型实践

训练数据准备

提示:云盘代码已内置少量高分二号卫星影像建筑物数据集

训练数据分为原始影像和标签(二值化,0和255),均位于Sample文件夹内,本示例数据为尺度不一致,将在训练数据导入时批量规范为256*256,数据相对路径为:

Sample\build\train\ IMG_T1
------------------\ IMG_LABEL
-------------\val \ IMG_T1
------------------\IMG_LABEL

模型训练

运行dp0_train.py,模型开始训练,核心参数包括:

parser参数说明
num_epochs训练批次
learning_rate初始学习率
batch_size单次样本数量
dataset数据集名字
crop_height训练时影像重采样尺度

数据结构CDDataset_Seg定义在utils文件夹dataset.py中,注意读取后进行了数据增强(随机翻转),灰度化,尺寸调整,标签归一化、 one-hot 编码,以及维度和数据类型的转换,最终得到适用于 PyTorch 模型训练的张量。

        # 读取训练图片和标签图片
        image_t1 = cv2.imread(image_t1_path,-1)
        #image_t2 = cv2.imread(image_t2_path)
        label = cv2.imread(label_path)
       
        # 随机进行数据增强,为2时不做处理
        if self.data_augment:
            flipCode = random.choice([-1, 0, 1, 2])
            if flipCode != 2:
#                image_t1 = normalized(image_t1, 'tif')
                image_t1 = self.augment(image_t1, flipCode)
                #image_t2 = self.augment(image_t2, flipCode)
                label = self.augment(label, flipCode)        
            
        label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
        
        image_t1 = cv2.resize(image_t1, (self.img_h, self.img_w))
        #image_t2 = cv2.resize(image_t2, (Config.img_h, Config.img_w))

        label = cv2.resize(label, (self.img_h, self.img_w))
        label = label/255
        label = label.astype('uint8')
        label = onehot(label, 2)
        label = label.transpose(2, 0, 1)
        label = torch.FloatTensor(label)

训练过程如下图所示,模型保存至checkpoints数据集同名文件夹内。

影像测试

运行dp0_AllPre.py,核心参数包括:

parser参数说明
Checkpointspath预训练模型位置名称
Dataset批量化预测数据文件夹
Outputpath输出数据文件夹

数据加载方式:

    pre_dataset = CDDataset_Pre(data_path=pre_imgpath1,
                                img_h=Config.img_h, img_w=Config.img_w,
                                transform=transforms.Compose([
                                transforms.ToTensor()]))

    pre_dataloader = torch.utils.data.DataLoader(dataset=pre_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=0)

需要注意,预测定义的数据结构CDDataset_Pre与训练时CDDataset_Seg略有区别,主要是不进行label处理。

结果示例

效果太差,理论上不应该,LiteHRNet结构中的分割头

        x = self.rep_head(feats)
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)

第一个x从rep_head输出, 尺度为128*128,后续有空考虑换分割头测试。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值