import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Union, Tuple
class ConvTranspose3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
output_padding: Union[int, Tuple[int, int, int]] = 0,
groups: int = 1,
bias: bool = True,
dilation: Union[int, Tuple[int, int, int]] = 1,
padding_mode: str = 'zeros'
):
super().__init__()
if padding_mode != 'zeros':
raise NotImplementedError("Only 'zeros' padding mode is supported")
# 参数规范化
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = self._triple(kernel_size)
self.stride = self._triple(stride)
self.padding = self._triple(padding)
self.output_padding = self._triple(output_padding)
self.groups = groups
self.dilation = self._triple(dilation)
# 通道数校验
if in_channels % groups != 0:
raise ValueError(f"in_channels must be divisible by groups")
if out_channels % groups != 0:
raise ValueError(f"out_channels must be divisible by groups")
# 初始化参数
self.weight = nn.Parameter(torch.empty(
in_channels,
out_channels // groups
convtranspose3d
于 2025-04-07 02:54:07 首次发布

最低0.47元/天 解锁文章
527

被折叠的 条评论
为什么被折叠?



