The conv3d operator is fairly complex, so the multi-core implementation needs to be validated ahead of time. The PyTorch script below is used for that testing.
Test script
import torch
import torch.nn as nn
import numpy as np
# test module
class Conv2dFrom3D(nn.Module):
    # emulate a Conv3d by decomposing it into per-depth-tap Conv2d layers
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
        # kernel size along depth; the output depth is computed in forward()
self.kd = kernel_size[0]
self.output_depth = None
        # create kd Conv2d layers without bias (one per depth tap); the shared bias below is added once in forward()
self.conv2d_layers = nn.ModuleList()
for i in range(self.kd):
conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size[1:],
stride=stride[1:] if isinstance(stride, tuple) else stride,
padding=padding[1:] if isinstance(padding, tuple) else padding,
dilation=dilation[1:] if isinstance(dilation, tuple) else dilation,
groups=groups,
                bias=False
)
self.conv2d_layers.append(conv)
        # shared bias parameter, applied once per output position
self.bias = nn.Parameter(torch.zeros(out_channels))
def forward(self, x):
N, C, D, H, W = x.shape
        # compute the output depth (stride=1 and no padding along depth are assumed here)
self.output_depth = D - self.kd + 1
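        # general case, for reference: D_out = (D + 2 * pad_d - dil_d * (kd - 1) - 1) // stride_d + 1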
        # collect the output for each depth position
outputs = []
        # slide a window along the depth dimension
for d_start in range(self.output_depth):
            # extract the current depth window: [N, C, kd, H, W]
window = x[:, :, d_start:d_start+self.kd, :, :]
            # apply the matching Conv2d to each depth position in the window
window_outputs = []
for i in range(self.kd):
                # current depth slice: [N, C, H, W]
depth_slice = window[:, :, i, :, :]
                # run the Conv2d layer for this depth tap
conv_output = self.conv2d_layers[i](depth_slice)
window_outputs.append(conv_output)
            # stack the per-tap outputs: [N, out_channels, kd, H_out, W_out]
window_output = torch.stack(window_outputs, dim=2)
            # sum over the depth taps (emulating the Conv3d depth reduction) and add the shared bias once: [N, out_channels, H_out, W_out]
            depth_sum = torch.sum(window_output, dim=2) + self.bias.view(1, -1, 1, 1)
outputs.append(depth_sum)
        # stack all window outputs along depth: [N, out_channels, output_depth, H_out, W_out]
output = torch.stack(outputs, dim=2)
return output
def copy_weights_from_conv3d(self, conv3d):
"""从Conv3d复制权重和偏置"""
with torch.no_grad():
            # Conv3d weight shape: [out_channels, in_channels, kd, kh, kw]
            weights_3d = conv3d.weight.data
            # distribute the weights to the per-tap Conv2d layers
for i in range(self.kd):
                # weights for this depth tap: [out_channels, in_channels, kh, kw]
weights_2d = weights_3d[:, :, i, :, :]
self.conv2d_layers[i].weight.data.copy_(weights_2d)
            # copy the Conv3d bias once into the shared bias parameter
            if conv3d.bias is not None:
                self.bias.data.copy_(conv3d.bias.data)
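# Why the per-tap decomposition is exact (assuming stride=1 and no padding along depth):
#   y[n, o, d, h', w'] = b[o] + sum_{i < kd} ( Conv2d_i(x[:, :, d + i, :, :]) )[n, o, h', w']
# where Conv2d_i uses the weight slice W[:, :, i, :, :]. A Conv3d is therefore the sum over its
# kd depth taps of independent Conv2d's, with the bias added exactly once per output position.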
class Conv2dFrom3D_SpatialSplit(nn.Module):
    # spatial (H & W) split variant: the depth window is folded into the channel dimension so a single Conv2d does the work
    # in theory, splitting along any of the D/H/W dimensions is equivalent
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.kd = kernel_size[0]
self.output_depth = None
        # a single Conv2d handles the whole depth window at once
self.conv2d = nn.Conv2d(
            in_channels * self.kd,  # input channels expanded by the depth kernel size
out_channels,
kernel_size[1:],
stride=stride[1:] if isinstance(stride, tuple) else stride,
padding=padding[1:] if isinstance(padding, tuple) else padding,
dilation=dilation[1:] if isinstance(dilation, tuple) else dilation,
groups=groups,
bias=True
)
def forward(self, x):
N, C, D, H, W = x.shape
self.output_depth = D - self.kd + 1
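        # same assumption as in Conv2dFrom3D: stride=1 and no padding along the depth dimension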
outputs = []
for d_start in range(self.output_depth):
            # extract the current depth window
window = x[:, :, d_start:d_start+self.kd, :, :]
            # flatten the depth window into the channel dimension
window_flat = window.permute(0, 2, 1, 3, 4).reshape(N, C * self.kd, H, W)
            # apply the single Conv2d
conv_output = self.conv2d(window_flat)
outputs.append(conv_output)
return torch.stack(outputs, dim=2)
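# The flattened-channel Conv2d is equivalent because its weight is the Conv3d weight with the depth
# taps folded into the input channels:
#   W2d = W3d.permute(0, 2, 1, 3, 4).reshape(out_channels, in_channels * kd, kh, kw)
# which mirrors the permute/reshape applied to each depth window in forward(). A per-core H/W tiling
# on top of this follows the same halo pattern sketched near the end of this file.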
# test function
def verify_conv3d_oc_split_equivalence():
# config seed
torch.manual_seed(42)
# config parameter
batch_size = 1
in_channels = 6
    out_channels = 64  # output channels of the original conv
depth = 28
height = 16
width = 16
kernel_size = [2,4,4]
stride = 1
padding = 0
groups = 1
dilation = [1,1,1]
# generate input
input_data = torch.randn(batch_size, in_channels, depth, height, width)
    # create original conv3d and calculate original_output
original_conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
padding=padding, groups=groups, dilation=dilation)
original_output = original_conv(input_data)
    # split the output channels across num_splits slices
num_splits = 8
assert out_channels % num_splits == 0, "out_channels must be divisible by num_splits"
split_size = out_channels // num_splits
# create slice conv3d
split_convs = nn.ModuleList([
nn.Conv3d(in_channels, split_size, kernel_size,
stride=stride, padding=padding, groups=groups, dilation=dilation)
for _ in range(num_splits)
])
    # use conv2d building blocks to assemble the multi-core conv3d operator
    # other dimensions can be split as well to find the optimal split (see the depth / H split sketches further below)
    # initialize weights for each slice
for i in range(num_splits):
start_idx = i * split_size
end_idx = (i + 1) * split_size
# copy weight
split_convs[i].weight.data = original_conv.weight.data[start_idx:end_idx, :, :, :, :].clone()
# copy bias
if original_conv.bias is not None:
split_convs[i].bias.data = original_conv.bias.data[start_idx:end_idx].clone()
    # calculate concatenated_output
split_outputs = []
for conv in split_convs:
split_outputs.append(conv(input_data))
concatenated_output = torch.cat(split_outputs, dim=1)
    # calculate diff
diff = torch.abs(original_output - concatenated_output)
max_diff = torch.max(diff)
mean_diff = torch.mean(diff)
print(f"max_diff: {max_diff.item()}")
print(f"mean_diff: {mean_diff.item()}")
# Check if it is equivalent within the allowable margin of error
tolerance = 1e-15
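    # note: atol=1e-15 essentially requires bit-exact results; depending on the backend's conv algorithm, differences up to ~1e-6 can still occur in float32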
if torch.allclose(original_output, concatenated_output, atol=tolerance):
print("verify conv3d oc split test pass! \n")
else:
print("verify conv3d oc split test fail! \n")
def verify_conv2d_from_3d_split_equivalence():
# config seed
torch.manual_seed(42)
# config parameter
batch_size = 1
in_channels = 4
out_channels = 16
depth = 16
height = 32
width = 32
kernel_size = [1, 4, 4]
stride = 1
padding = 0
# generate input
input_data = torch.randn(batch_size, in_channels, depth, height, width)
# create original conv3d and calculate original_output
original_conv = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding
)
original_output = original_conv(input_data)
# create replacement and copy weights
conv2d_replacement = Conv2dFrom3D(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding
)
conv2d_replacement.copy_weights_from_conv3d(original_conv)
# calculate outputs
conv2d_replacement_output = conv2d_replacement(input_data)
# calculate diff
diff = torch.abs(original_output - conv2d_replacement_output)
max_diff = torch.max(diff)
mean_diff = torch.mean(diff)
print(f"max_diff: {max_diff.item()}")
print(f"mean_diff: {mean_diff.item()}")
# Check equivalence
tolerance = 1e-5
if torch.allclose(original_output, conv2d_replacement_output, atol=tolerance):
print("verify conv2d from 3d test pass! \n")
else:
print("verify conv2d from 3d test fail! \n")
def verify_conv2d_spatial_split_equivalence():
# config seed
torch.manual_seed(42)
# config parameter
batch_size = 1
in_channels = 4
out_channels = 16
depth = 16
height = 32
width = 32
kernel_size = [1, 4, 4]
stride = 1
padding = 0
# generate input
input_3d = torch.randn(batch_size, in_channels, depth, height, width)
# create original conv3d and calculate original_output
conv3d = nn.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding
)
original_output = conv3d(input_3d)
# create replacement and copy weights
conv2d_replacement = Conv2dFrom3D_SpatialSplit(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding
)
with torch.no_grad():
weights_3d = conv3d.weight.data
kd = kernel_size[0]
weights_reshaped = weights_3d.permute(0, 2, 1, 3, 4).reshape(
out_channels, in_channels * kd, kernel_size[1], kernel_size[2]
)
# copy weight
conv2d_replacement.conv2d.weight.data.copy_(weights_reshaped)
# copy bias
conv2d_replacement.conv2d.bias.data.copy_(conv3d.bias.data)
# calculate outputs
conv2d_output = conv2d_replacement(input_3d)
# calculate diff
diff = torch.abs(original_output - conv2d_output)
max_diff = torch.max(diff).item()
mean_diff = torch.mean(diff).item()
# print(f"Original output shape: {original_output.shape}")
# print(f"Conv2d replacement output shape: {conv2d_output.shape}")
print(f"max_diff: {max_diff}")
print(f"mean_diff: {mean_diff}")
# Check equivalence
tolerance = 1e-5
if torch.allclose(original_output, conv2d_output, atol=tolerance):
print("verify conv2d spatial split test pass! \n")
return True
else:
print("verify conv2d spatial split test fail! \n")
return False
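# A minimal sketch of splitting the work along the output depth across cores (assumes stride=1 and
# no padding along depth): each chunk needs a halo of (kd - 1) extra input slices so the windows at
# its boundary stay complete, and the per-chunk outputs are stitched back along the depth axis.
def verify_conv3d_depth_split_equivalence(num_splits=3):
    torch.manual_seed(42)
    in_channels, out_channels = 4, 16
    depth, height, width = 28, 16, 16
    kernel_size = [2, 4, 4]
    kd = kernel_size[0]
    input_data = torch.randn(1, in_channels, depth, height, width)
    conv3d = nn.Conv3d(in_channels, out_channels, kernel_size)
    original_output = conv3d(input_data)
    out_depth = original_output.shape[2]  # depth - kd + 1 = 27
    assert out_depth % num_splits == 0, "output depth must divide evenly in this sketch"
    chunk = out_depth // num_splits
    split_outputs = []
    for i in range(num_splits):
        d0 = i * chunk                 # first output depth index handled by this core
        d1 = d0 + chunk + kd - 1       # last input slice needed (exclusive), including the halo
        split_outputs.append(conv3d(input_data[:, :, d0:d1, :, :]))
    stitched_output = torch.cat(split_outputs, dim=2)
    diff = torch.abs(original_output - stitched_output)
    print(f"max_diff: {torch.max(diff).item()}")
    print(f"mean_diff: {torch.mean(diff).item()}")
    if torch.allclose(original_output, stitched_output, atol=1e-5):
        print("verify conv3d depth split test pass! \n")
    else:
        print("verify conv3d depth split test fail! \n")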
if __name__ == "__main__":
# verify_conv3d_oc_split_equivalence()
verify_conv2d_from_3d_split_equivalence()
verify_conv2d_spatial_split_equivalence()
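The same halo idea covers the spatial direction. Below is a minimal, self-contained sketch that splits the output rows of a Conv3d across cores, assuming stride=1 and no padding: each height tile reads (kh - 1) extra input rows as halo and the tile outputs are concatenated back along H. The function name and sizes here are illustrative.

import torch
import torch.nn as nn

def verify_conv3d_h_split_equivalence(num_splits=2):
    torch.manual_seed(42)
    in_channels, out_channels = 4, 16
    depth, height, width = 16, 32, 32
    kernel_size = [1, 4, 4]
    kh = kernel_size[1]
    input_data = torch.randn(1, in_channels, depth, height, width)
    conv3d = nn.Conv3d(in_channels, out_channels, kernel_size)
    original_output = conv3d(input_data)
    out_h = original_output.shape[3]  # height - kh + 1 = 29
    base = out_h // num_splits
    split_outputs = []
    h0 = 0
    for i in range(num_splits):
        rows = base + (1 if i < out_h % num_splits else 0)  # spread the remainder over the first tiles
        h1 = h0 + rows + kh - 1                             # input rows needed, including the (kh - 1) halo
        split_outputs.append(conv3d(input_data[:, :, :, h0:h1, :]))
        h0 += rows
    stitched_output = torch.cat(split_outputs, dim=3)
    diff = torch.abs(original_output - stitched_output)
    print(f"max_diff: {torch.max(diff).item()}")
    return torch.allclose(original_output, stitched_output, atol=1e-5)

Calling verify_conv3d_h_split_equivalence() should print a near-zero max_diff and return True when the stitched tiles match the full Conv3d output; which split (output channels, D, H, or W) balances the cores best is then a matter of measuring on the target hardware.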