channel_shuffle代码实现

本文详细介绍了如何在Pytorch中实现ShuffleNetv2模型中的关键层ChannelShuffle,包括图像通道拆分、GConv1处理和Feature合并,以及最终的reshape和transpose操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

 结构图,先将输入的图像进行通道拆分为组GConv1,每个GConv1再拆分Feature,每个GConv1的Feature进行合并GConv2,输出Output

输入图像x,拆分为groups个组,每隔组的通道数为channels_per_group 

batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // groups

 进行变换

# reshape
# [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
x = x.view(batch_size, groups, channels_per_group, height, width)

 

再将1和2的维度进行调换 ,就实现了Feature到GConv2

x = torch.transpose(x, 1, 2).contiguous()

 

全部代码

def channel_shuffle(x: Tensor, groups: int) -> Tensor:

    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    # [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
    x = x.view(batch_size, groups, channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batch_size, -1, height, width)

    return x

参考视频:

8.2 使用Pytorch搭建ShuffleNetv2_哔哩哔哩_bilibili

### Channel Split 和 Channel Shuffle 的概念 Channel Split 是指将输入张量按照通道维度分成两个或多个子集的操作。这种操作通常用于设计轻量化神经网络架构,通过减少计算复杂度来提高模型效率。 Channel Shuffle 则是在执行分组卷积之后重新排列通道顺序的过程[^1]。具体来说,当完成一组独立的分组卷积运算后,为了使后续层能充分利用之前各组的信息,需要对这些经过不同路径处理过的特征图进行混合。这一步骤有助于打破因过度分离而导致的信息孤岛现象,从而增强跨组间信息交流的有效性。 ### 应用场景与实现方式 #### Channel Split 实现 在 TensorFlow 中可以通过 `tf.split` 函数轻松实现这一功能: ```python import tensorflow as tf def channel_split(x, num_splits=2): channels = x.shape[-1] split_channels = [channels // num_splits] * (num_splits - 1) + [channels - (channels // num_splits) * (num_splits - 1)] return tf.split(x, split_channels, axis=-1) input_tensor = tf.random.normal((1, 56, 56, 64)) split_tensors = channel_split(input_tensor) print([t.shape for t in split_tensors]) ``` 这段代码展示了如何把一个形状为 `(batch_size, height, width, depth)` 的四维张量按最后一个轴(即深度方向)均匀分割成两部分。 #### Channel Shuffle 实现 对于 Channel Shuffle,则可通过先 reshape 后 transpose 来达到目的: ```python def channel_shuffle(x, groups): batch_size, height, width, num_channels = x.shape channels_per_group = num_channels // groups # Reshape to group the channels together. x_reshaped = tf.reshape(x, [-1, height, width, groups, channels_per_group]) # Transpose and then flatten back out again. x_transposed = tf.transpose(x_reshaped, perm=[0, 1, 2, 4, 3]) output_shape = (-1, height, width, num_channels) result = tf.reshape(x_transposed, shape=output_shape) return result shuffled_tensor = channel_shuffle(split_tensors[0], groups=2) print(shuffled_tensor.shape) ``` 此函数接收一个张量以及指定要划分多少个小组作为参数,并返回一个新的具有相同大小但内部数据已被打乱分布的新张量。 ### 主要区别 - **作用对象**: Channel Split 关注的是如何合理分配资源给不同的分支去并行工作;而 Channel Shuffle 更侧重于解决由于上述过程可能引发的信息隔离问题。 - **应用场景**: 前者多见于构建高效的瓶颈结构中,后者则是为了让各个支路之间更好地共享学到的知识,在某些特定类型的网络比如 ShuffleNet V1/V2 中被广泛应用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值