import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Union, Tuple
class ConvTranspose3d(nn.Module):
    """Re-implementation of a 3-D transposed convolution on top of ``F.conv3d``.

    The transposed convolution is computed as an ordinary stride-1 convolution:

    1. insert ``stride - 1`` zeros between neighbouring input elements;
    2. pad each spatial side by ``dilation * (kernel_size - 1) - padding``
       (a negative amount crops instead), and give the trailing side an extra
       ``output_padding`` elements;
    3. convolve with the channel-transposed, spatially flipped weight.

    Parameters are laid out like ``nn.ConvTranspose3d``: ``weight`` has shape
    ``(in_channels, out_channels // groups, kD, kH, kW)`` and ``bias`` has shape
    ``(out_channels,)``, so state dicts can be loaded from the built-in module.

    Only ``padding_mode='zeros'`` is supported.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        output_padding: Union[int, Tuple[int, int, int]] = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        padding_mode: str = 'zeros'
    ):
        """Store the hyper-parameters and allocate/initialise weight and bias.

        Raises:
            NotImplementedError: if ``padding_mode`` is not ``'zeros'``.
            ValueError: if a spatial argument is not an int or 3-element
                tuple/list, or if the channel counts are not divisible by
                ``groups``.
        """
        super().__init__()
        if padding_mode != 'zeros':
            raise NotImplementedError("Only 'zeros' padding mode is supported")

        # Normalise every spatial hyper-parameter to a (depth, height, width) triple.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._triple(kernel_size)
        self.stride = self._triple(stride)
        self.padding = self._triple(padding)
        self.output_padding = self._triple(output_padding)
        self.groups = groups
        self.dilation = self._triple(dilation)

        # Grouped convolution needs both channel counts to split evenly.
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")

        # Transposed-conv weight layout: (C_in, C_out / groups, kD, kH, kW).
        self.weight = nn.Parameter(torch.empty(
            in_channels,
            out_channels // groups,
            *self.kernel_size,
        ))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def _triple(self, x: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
        """Return ``x`` as a 3-tuple.

        An int is repeated three times. A 3-element tuple or list is returned
        as a tuple.

        Raises:
            ValueError: for any other input.
        """
        if isinstance(x, int):
            return (x, x, x)
        if isinstance(x, (tuple, list)) and len(x) == 3:
            return tuple(x)
        raise ValueError(f"Expected int or 3-element tuple, got {type(x)}")

    def reset_parameters(self):
        """Initialise the parameters.

        ``weight`` is Kaiming-uniform with ``a=sqrt(5)``. ``bias`` is uniform
        in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            # fan_in is derived from weight dim 1 (C_out / groups) times the
            # kernel volume, because of the transposed weight layout.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the 3-D transposed convolution.

        Args:
            x: input of shape ``(N, C_in, D, H, W)``, or unbatched
                ``(C_in, D, H, W)``.

        Returns:
            Tensor of shape ``(N, C_out, D_out, H_out, W_out)``. The batch
            dimension is dropped for unbatched input. Each spatial size is
            ``(in - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
            + output_padding + 1``.

        Raises:
            ValueError: if ``x`` is neither 4-D nor 5-D.
        """
        if x.dim() not in (4, 5):
            raise ValueError(
                f"Expected 4D (unbatched) or 5D (batched) input, got {x.dim()}D"
            )
        unbatched = x.dim() == 4
        if unbatched:
            x = x.unsqueeze(0)

        n, c_in = x.shape[:2]
        spatial_in = x.shape[2:]

        # Step 1: dilate the input by inserting (stride - 1) zeros between elements.
        up_shape = [(size - 1) * s + 1 for size, s in zip(spatial_in, self.stride)]
        x_up = x.new_zeros((n, c_in, *up_shape))
        sd, sh, sw = self.stride
        x_up[:, :, ::sd, ::sh, ::sw] = x

        # Step 2: per-side padding.
        # BOTH sides need the full dilation*(k-1) - padding. The trailing side
        # also gets output_padding, so those extra output positions receive
        # their real contributions instead of being zero-filled afterwards.
        lead = [
            d * (k - 1) - p
            for k, d, p in zip(self.kernel_size, self.dilation, self.padding)
        ]
        trail = [lo + op for lo, op in zip(lead, self.output_padding)]

        # F.pad takes pairs starting from the last dim: (W_lo, W_hi, H_lo, H_hi, D_lo, D_hi).
        # Negative amounts are not padded here; they are cropped below.
        pad_cfg = []
        for lo, hi in zip(reversed(lead), reversed(trail)):
            pad_cfg.extend((max(lo, 0), max(hi, 0)))
        x_pad = F.pad(x_up, tuple(pad_cfg))

        # If padding exceeds dilation*(k-1), the "padding" is negative: crop.
        crop = tuple(
            slice(max(-lo, 0), size - max(-hi, 0))
            for lo, hi, size in zip(lead, trail, x_pad.shape[2:])
        )
        x_pad = x_pad[(slice(None), slice(None)) + crop]

        # Step 3: convert the transposed-conv weight into a regular conv weight.
        # (C_in, C_out/G, k...) -> (G, C_in/G, C_out/G, k...) -> swap channel
        # axes -> (C_out, C_in/G, k...), then flip the kernel spatially.
        g = self.groups
        in_per_group = self.in_channels // g
        out_per_group = self.out_channels // g
        conv_weight = self.weight.reshape(g, in_per_group, out_per_group, *self.kernel_size)
        conv_weight = conv_weight.transpose(1, 2).reshape(
            self.out_channels, in_per_group, *self.kernel_size
        )
        conv_weight = conv_weight.flip(dims=(2, 3, 4))

        # Step 4: ordinary stride-1 convolution. Bias can be fused here because
        # output_padding has already been accounted for in the input padding.
        output = F.conv3d(
            x_pad,
            conv_weight,
            bias=self.bias,
            stride=1,
            padding=0,
            dilation=self.dilation,
            groups=g,
        )
        return output.squeeze(0) if unbatched else output
# ------------------------ Test code ------------------------
def run_test(test_name, **kwargs):
    """Compare the custom ConvTranspose3d with ``nn.ConvTranspose3d`` and print the results.

    Args:
        test_name: label printed in the report header.
        **kwargs: layer hyper-parameters forwarded to both modules.
            ``in_channels`` / ``out_channels`` may be given here and default
            to 3 and 6.

    Returns:
        The maximum absolute difference between the two outputs as a float,
        or ``float('inf')`` if their shapes differ.
    """
    print(f"\n=== 测试案例: {test_name} ===")
    # Take the channel counts out of kwargs. Otherwise they would be passed
    # twice (positionally and by keyword), which raised TypeError for configs
    # that set them, such as the grouped-conv case.
    in_channels = kwargs.pop("in_channels", 3)
    out_channels = kwargs.pop("out_channels", 6)

    # Build both modules and give them identical parameters.
    official = nn.ConvTranspose3d(in_channels, out_channels, **kwargs, bias=True)
    custom = ConvTranspose3d(in_channels, out_channels, **kwargs, bias=True)
    custom.load_state_dict(official.state_dict())

    # Input layout: (batch, channels, depth, height, width).
    # Its channel count must match the layer.
    x = torch.randn(2, in_channels, 5, 7, 11)

    with torch.no_grad():
        out_official = official(x)
        out_custom = custom(x)

    same_shape = out_official.shape == out_custom.shape
    print(f"官方形状: {out_official.shape}")
    print(f"自定义形状: {out_custom.shape}")
    print(f"形状一致: {same_shape}")
    if not same_shape:
        # Subtracting tensors of mismatched shapes would raise, so report instead.
        print("数值误差: N/A (形状不一致)")
        return float("inf")
    max_err = torch.max(torch.abs(out_official - out_custom)).item()
    print(f"数值误差: {max_err:.2e}")
    return max_err
if __name__ == "__main__":
    # Test-case name -> keyword arguments that run_test forwards to both
    # nn.ConvTranspose3d and the custom ConvTranspose3d.
    test_configs = {
        "基础测试": {
            "kernel_size": 3,
            "stride": 2,
            "padding": 1,
            "output_padding": 1,
            "dilation": 1,
        },
        "非对称参数": {
            "kernel_size": (3, 5, 2),
            "stride": (2, 1, 3),
            "padding": (0, 2, 1),
            "output_padding": (1, 0, 1),
            "dilation": (2, 1, 3),
        },
        "大膨胀率": {
            "kernel_size": 3,
            "stride": 2,
            "padding": 2,
            "output_padding": 1,
            "dilation": 3,
        },
        "无输出填充": {
            "kernel_size": 4,
            "stride": 3,
            "padding": 2,
            "output_padding": 0,
            "dilation": 1,
        },
        "分组卷积": {
            "kernel_size": 3,
            "stride": 2,
            "padding": 1,
            "output_padding": 1,
            "dilation": 1,
            "groups": 3,
            "in_channels": 6,
            "out_channels": 6,
        },
        # padding > dilation * (kernel_size - 1) combined with output_padding:
        # the zero-inserted input must be cropped rather than padded.
        "大填充(裁剪)": {
            "kernel_size": 2,
            "stride": 2,
            "padding": 2,
            "output_padding": 1,
            "dilation": 1,
        },
    }
    for name, config in test_configs.items():
        run_test(name, **config)
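# convtranspose3d
# 最新推荐文章于 2025-05-08 12:11:49 发布
# (Scraped web-page text, "latest recommended article published 2025-05-08 12:11:49";
#  commented out because it is not valid Python.)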