convtranspose3d

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Union, Tuple

class ConvTranspose3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        output_padding: Union[int, Tuple[int, int, int]] = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        padding_mode: str = 'zeros'
    ):
        super().__init__()
        if padding_mode != 'zeros':
            raise NotImplementedError("Only 'zeros' padding mode is supported")
        
        # 参数规范化
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = self._triple(kernel_size)
        self.stride = self._triple(stride)
        self.padding = self._triple(padding)
        self.output_padding = self._triple(output_padding)
        self.groups = groups
        self.dilation = self._triple(dilation)

        # 通道数校验
        if in_channels % groups != 0:
            raise ValueError(f"in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError(f"out_channels must be divisible by groups")

        # 初始化参数
        self.weight = nn.Parameter(torch.empty(
            in_channels,
            out_channels // groups,
            self.kernel_size[0],
            self.kernel_size[1],
            self.kernel_size[2]
        ))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def _triple(self, x: Union[int, Tuple[int, int, int]]) -> Tuple[int, int, int]:
        """将输入转换为三元组"""
        if isinstance(x, int):
            return (x, x, x)
        elif isinstance(x, tuple) and len(x) == 3:
            return x
        else:
            raise ValueError(f"Expected int or 3-element tuple, got {type(x)}")

    def reset_parameters(self):
        """参数初始化"""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 输入维度解析
        N, C_in, D_in, H_in, W_in = x.shape

        # Step 1: 上采样(插入零)
        D_up = (D_in - 1) * self.stride[0] + 1
        H_up = (H_in - 1) * self.stride[1] + 1
        W_up = (W_in - 1) * self.stride[2] + 1
        
        x_up = torch.zeros(N, C_in, D_up, H_up, W_up, dtype=x.dtype, device=x.device)
        x_up[:, :, ::self.stride[0], ::self.stride[1], ::self.stride[2]] = x

        # Step 2: 计算填充量
        pad_d = (self.kernel_size[0] - 1) * self.dilation[0] - self.padding[0]
        pad_h = (self.kernel_size[1] - 1) * self.dilation[1] - self.padding[1]
        pad_w = (self.kernel_size[2] - 1) * self.dilation[2] - self.padding[2]

        # 确保填充量非负
        pad_d = max(0, pad_d)
        pad_h = max(0, pad_h)
        pad_w = max(0, pad_w)

        # 分解为单侧填充(关键修复:扁平化元组)
        pad_config = (
            pad_w // 2, pad_w - pad_w // 2,  # 宽度维度
            pad_h // 2, pad_h - pad_h // 2,  # 高度维度
            pad_d // 2, pad_d - pad_d // 2   # 深度维度
        )
        x_pad = F.pad(x_up, pad_config)

        # Step 3: 权重变换
        G = self.groups
        conv_weight = self.weight.view(
            G,
            self.in_channels // G,
            self.out_channels // G,
            *self.kernel_size
        )
        # 维度转置和通道重组
        conv_weight = conv_weight.permute(0, 2, 1, 3, 4, 5).contiguous()
        conv_weight = conv_weight.view(G * (self.out_channels // G), 
                                      self.in_channels // G,
                                      *self.kernel_size)
        # 卷积核翻转
        conv_weight = conv_weight.flip(dims=[2, 3, 4])

        # Step 4: 执行普通卷积
        output = F.conv3d(
            input=x_pad,
            weight=conv_weight,
            bias=None,
            stride=1,
            padding=0,
            dilation=self.dilation,
            groups=self.groups
        )

        # Step 5: 应用输出填充
        output = F.pad(output, (
            0, self.output_padding[2],  # 宽度
            0, self.output_padding[1],  # 高度
            0, self.output_padding[0]   # 深度
        ))

        # Step 6: 添加偏置
        if self.bias is not None:
            output += self.bias.view(1, -1, 1, 1, 1)

        return output

# ------------------------ 测试代码 ------------------------
def run_test(test_name, **kwargs):
    print(f"\n=== 测试案例: {test_name} ===")
    
    # 初始化模型
    official = nn.ConvTranspose3d(3, 6, **kwargs, bias=True)
    custom = ConvTranspose3d(3, 6, **kwargs, bias=True)
    
    # 参数同步
    custom.load_state_dict(official.state_dict())
    
    # 生成输入
    x = torch.randn(2, 3, 5, 7, 11)  # (batch, channels, depth, height, width)
    
    # 前向计算
    out_official = official(x)
    out_custom = custom(x)
    
    # 结果验证
    print(f"官方形状: {out_official.shape}")
    print(f"自定义形状: {out_custom.shape}")
    print(f"形状一致: {out_official.shape == out_custom.shape}")
    print(f"数值误差: {torch.max(torch.abs(out_official - out_custom)):.2e}")

if __name__ == "__main__":
    test_configs = {
        "基础测试": {
            "kernel_size": 3,
            "stride": 2,
            "padding": 1,
            "output_padding": 1,
            "dilation": 1
        },
        "非对称参数": {
            "kernel_size": (3, 5, 2),
            "stride": (2, 1, 3),
            "padding": (0, 2, 1),
            "output_padding": (1, 0, 1),
            "dilation": (2, 1, 3)
        },
        "大膨胀率": {
            "kernel_size": 3,
            "stride": 2,
            "padding": 2,
            "output_padding": 1,
            "dilation": 3
        },
        "无输出填充": {
            "kernel_size": 4,
            "stride": 3,
            "padding": 2,
            "output_padding": 0,
            "dilation": 1
        },
        "分组卷积": {
            "kernel_size": 3,
            "stride": 2,
            "padding": 1,
            "output_padding": 1,
            "dilation": 1,
            "groups": 3,
            "in_channels": 6,
            "out_channels": 6
        }
    }

    for name, config in test_configs.items():
        run_test(name, **config)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值