conv3d 算子的多核实现

conv3d 算子相当复杂，开发其多核版本之前，需要先验证各种切分方案与原算子是否等价。下面是用于验证的 PyTorch 脚本。

测试脚本

import torch
import torch.nn as nn
import numpy as np

# test module
class Conv2dFrom3D(nn.Module):
    """Emulate ``nn.Conv3d`` with one ``nn.Conv2d`` per kernel-depth tap.

    For every output depth position, each input depth slice covered by the
    3-D kernel is convolved with its own bias-free 2-D convolution (holding
    the matching ``[:, :, i]`` slice of the 3-D weight). The partial results
    are summed and a single shared bias is added once.

    Depth stride, padding and dilation are applied manually in ``forward``;
    the height/width components are delegated to the ``nn.Conv2d`` layers.
    ``kernel_size``, ``stride``, ``padding`` and ``dilation`` accept an int or
    a length-3 sequence (list or tuple), like ``nn.Conv3d``; ``padding`` may
    also be the string ``"valid"``. Only zero padding is supported.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1):
        super().__init__()
        # Keep the constructor arguments exactly as given, for introspection.
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        if isinstance(padding, str):
            if padding != "valid":
                raise ValueError(
                    f"only the string padding mode 'valid' is supported, got {padding!r}"
                )
            padding = 0
        kernel_3d = self._to_triple(kernel_size, "kernel_size")
        stride_3d = self._to_triple(stride, "stride")
        padding_3d = self._to_triple(padding, "padding")
        dilation_3d = self._to_triple(dilation, "dilation")

        # Depth components are consumed by forward(); H/W ones go to Conv2d.
        self.kd = kernel_3d[0]
        self.stride_d = stride_3d[0]
        self.padding_d = padding_3d[0]
        self.dilation_d = dilation_3d[0]
        self.output_depth = None  # set by the most recent forward() call

        # One bias-free Conv2d per kernel-depth tap. The bias must be added
        # once per output (not once per tap), so it lives in self.bias.
        self.conv2d_layers = nn.ModuleList(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_3d[1:],
                stride=stride_3d[1:],
                padding=padding_3d[1:],
                dilation=dilation_3d[1:],
                groups=groups,
                bias=False,
            )
            for _ in range(self.kd)
        )

        # Global bias shared by every output depth position.
        self.bias = nn.Parameter(torch.zeros(out_channels))

    @staticmethod
    def _to_triple(value, name):
        """Return ``value`` as a (depth, height, width) 3-tuple."""
        if isinstance(value, int):
            return (value, value, value)
        triple = tuple(value)
        if len(triple) != 3:
            raise ValueError(f"{name} must be an int or a sequence of 3 ints, got {value!r}")
        return triple

    def forward(self, x):
        """Apply the emulated 3-D convolution.

        Args:
            x: 5-D input tensor ``[N, C, D, H, W]``.

        Returns:
            Tensor ``[N, out_channels, D_out, H_out, W_out]``.

        Raises:
            ValueError: if ``x`` is not 5-D, or its (padded) depth is too small
                for the dilated kernel depth.
        """
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D input [N, C, D, H, W], got shape {tuple(x.shape)}")
        kd, sd, pd, dd = self.kd, self.stride_d, self.padding_d, self.dilation_d

        if pd > 0:
            # Pad tuple is ordered last dim first: (W_l, W_r, H_t, H_b, D_front, D_back).
            x = nn.functional.pad(x, (0, 0, 0, 0, pd, pd))

        # Same per-axis output-size formula nn.Conv3d uses.
        self.output_depth = (x.shape[2] - dd * (kd - 1) - 1) // sd + 1
        if self.output_depth <= 0:
            raise ValueError(
                f"input depth {x.shape[2]} (after padding) is too small for kernel depth "
                f"{kd} with dilation {dd}"
            )

        bias = self.bias.view(1, -1, 1, 1)
        outputs = []
        for d_out in range(self.output_depth):
            start = d_out * sd
            # Sum the 2-D convolutions of every depth tap under the kernel.
            depth_sum = sum(
                layer(x[:, :, start + i * dd])
                for i, layer in enumerate(self.conv2d_layers)
            )
            outputs.append(depth_sum + bias)

        # Stack along a new depth axis: [N, out_channels, D_out, H_out, W_out].
        return torch.stack(outputs, dim=2)

    def copy_weights_from_conv3d(self, conv3d):
        """Copy weights and bias from an ``nn.Conv3d`` with the same config.

        Conv3d weights are ``[out_channels, in_channels / groups, kd, kh, kw]``;
        depth slice ``i`` goes to ``conv2d_layers[i]``. A bias-free Conv3d
        zeroes ``self.bias``.

        Raises:
            ValueError: if the Conv3d kernel depth differs from ``self.kd``.
        """
        weight_3d = conv3d.weight
        if weight_3d.shape[2] != self.kd:
            raise ValueError(
                f"conv3d kernel depth {weight_3d.shape[2]} does not match kd={self.kd}"
            )
        with torch.no_grad():
            for i, layer in enumerate(self.conv2d_layers):
                layer.weight.copy_(weight_3d[:, :, i])
            if conv3d.bias is not None:
                self.bias.copy_(conv3d.bias)
            else:
                self.bias.zero_()

            

class Conv2dFrom3D_SpatialSplit(nn.Module):
    """Emulate ``nn.Conv3d`` by folding the kernel depth into channels.

    For every output depth position, the ``kd`` input depth slices under the
    kernel are stacked along the channel axis and fed to a single
    ``nn.Conv2d`` with ``in_channels * kd`` input channels. Despite the class
    name, H and W are not split here; only the depth axis is iterated.

    Folded channel layout is group-major, then depth, then channel within the
    group: index ``g * kd * Cg + d * Cg + c`` with ``Cg = in_channels // groups``.
    With ``groups=1`` this is ``d * in_channels + c``, i.e. the layout of
    ``window.permute(0, 2, 1, 3, 4)``. The matching 2-D weight is
    ``conv3d.weight.permute(0, 2, 1, 3, 4).reshape(out, kd * Cg, kh, kw)``;
    ``copy_weights_from_conv3d`` performs that copy.

    ``kernel_size``, ``stride``, ``padding`` and ``dilation`` accept an int or
    a length-3 sequence; ``padding`` may also be ``"valid"``. Only zero
    padding is supported.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1):
        super().__init__()
        # Keep the constructor arguments exactly as given, for introspection.
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        if in_channels % groups != 0:
            raise ValueError(f"in_channels ({in_channels}) must be divisible by groups ({groups})")
        if isinstance(padding, str):
            if padding != "valid":
                raise ValueError(
                    f"only the string padding mode 'valid' is supported, got {padding!r}"
                )
            padding = 0
        kernel_3d = self._to_triple(kernel_size, "kernel_size")
        stride_3d = self._to_triple(stride, "stride")
        padding_3d = self._to_triple(padding, "padding")
        dilation_3d = self._to_triple(dilation, "dilation")

        # Depth components are consumed by forward(); H/W ones go to Conv2d.
        self.kd = kernel_3d[0]
        self.stride_d = stride_3d[0]
        self.padding_d = padding_3d[0]
        self.dilation_d = dilation_3d[0]
        self.output_depth = None  # set by the most recent forward() call

        # A single Conv2d whose input channels are the kd depth slices folded in.
        self.conv2d = nn.Conv2d(
            in_channels * self.kd,
            out_channels,
            kernel_3d[1:],
            stride=stride_3d[1:],
            padding=padding_3d[1:],
            dilation=dilation_3d[1:],
            groups=groups,
            bias=True,
        )

    @staticmethod
    def _to_triple(value, name):
        """Return ``value`` as a (depth, height, width) 3-tuple."""
        if isinstance(value, int):
            return (value, value, value)
        triple = tuple(value)
        if len(triple) != 3:
            raise ValueError(f"{name} must be an int or a sequence of 3 ints, got {value!r}")
        return triple

    def forward(self, x):
        """Apply the emulated 3-D convolution.

        Args:
            x: 5-D input tensor ``[N, C, D, H, W]``.

        Returns:
            Tensor ``[N, out_channels, D_out, H_out, W_out]``.

        Raises:
            ValueError: if ``x`` is not 5-D, or its (padded) depth is too small
                for the dilated kernel depth.
        """
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D input [N, C, D, H, W], got shape {tuple(x.shape)}")
        kd, sd, pd, dd, g = self.kd, self.stride_d, self.padding_d, self.dilation_d, self.groups

        if pd > 0:
            # Pad tuple is ordered last dim first: (W_l, W_r, H_t, H_b, D_front, D_back).
            x = nn.functional.pad(x, (0, 0, 0, 0, pd, pd))
        N, C, D, H, W = x.shape

        # Same per-axis output-size formula nn.Conv3d uses.
        self.output_depth = (D - dd * (kd - 1) - 1) // sd + 1
        if self.output_depth <= 0:
            raise ValueError(
                f"input depth {D} (after padding) is too small for kernel depth "
                f"{kd} with dilation {dd}"
            )

        outputs = []
        for d_out in range(self.output_depth):
            start = d_out * sd
            # Depth slices under the (dilated) kernel: [N, C, kd, H, W].
            window = x[:, :, start:start + dd * (kd - 1) + 1:dd]
            # Fold depth into channels group-major: [N, g, kd, C/g, H, W] -> [N, C*kd, H, W].
            window_flat = (
                window.reshape(N, g, C // g, kd, H, W)
                .permute(0, 1, 3, 2, 4, 5)
                .reshape(N, C * kd, H, W)
            )
            outputs.append(self.conv2d(window_flat))

        return torch.stack(outputs, dim=2)

    def copy_weights_from_conv3d(self, conv3d):
        """Copy weights and bias from an ``nn.Conv3d`` with the same config.

        The Conv3d weight ``[out, in / groups, kd, kh, kw]`` is permuted to
        ``[out, kd, in / groups, kh, kw]`` and flattened to match the folded
        channel layout. A bias-free Conv3d zeroes ``self.conv2d.bias``.

        Raises:
            ValueError: if the Conv3d kernel depth differs from ``self.kd``.
        """
        weight_3d = conv3d.weight
        if weight_3d.shape[2] != self.kd:
            raise ValueError(
                f"conv3d kernel depth {weight_3d.shape[2]} does not match kd={self.kd}"
            )
        with torch.no_grad():
            folded = weight_3d.permute(0, 2, 1, 3, 4).reshape(self.conv2d.weight.shape)
            self.conv2d.weight.copy_(folded)
            if conv3d.bias is not None:
                self.conv2d.bias.copy_(conv3d.bias)
            else:
                self.conv2d.bias.zero_()


# test function
def verify_conv3d_oc_split_equivalence(num_splits=8, tolerance=1e-5):
    """Check that splitting an ``nn.Conv3d`` along output channels is lossless.

    A reference Conv3d is built and its weight/bias are split into
    ``num_splits`` equal output-channel chunks. One smaller Conv3d runs per
    chunk, and the per-chunk outputs are concatenated along the channel axis
    and compared with the reference output.

    Args:
        num_splits: number of output-channel chunks; must divide the
            reference's 64 output channels.
        tolerance: absolute tolerance for ``torch.allclose``. It must be
            meaningful for float32, whose resolution is about 1e-7.

    Returns:
        True if the outputs match within ``tolerance``, else False.

    Raises:
        ValueError: if ``num_splits`` is not a positive divisor of 64.
    """
    torch.manual_seed(42)

    # Configuration.
    batch_size = 1
    in_channels = 6
    out_channels = 64  # output channels of the unsplit reference conv
    depth = 28
    height = 16
    width = 16
    kernel_size = [2, 4, 4]
    stride = 1
    padding = 0
    groups = 1
    dilation = [1, 1, 1]

    if num_splits <= 0 or out_channels % num_splits != 0:
        raise ValueError(
            f"out_channels ({out_channels}) must be divisible by num_splits ({num_splits})"
        )
    split_size = out_channels // num_splits

    # Random creation order (input, reference, chunks) is kept fixed for reproducibility.
    input_data = torch.randn(batch_size, in_channels, depth, height, width)
    original_conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
                              padding=padding, groups=groups, dilation=dilation)
    split_convs = nn.ModuleList([
        nn.Conv3d(in_channels, split_size, kernel_size,
                  stride=stride, padding=padding, groups=groups, dilation=dilation)
        for _ in range(num_splits)
    ])

    with torch.no_grad():
        original_output = original_conv(input_data)

        # Give chunk i output channels [i * split_size, (i + 1) * split_size).
        for conv, weight, bias in zip(
            split_convs,
            original_conv.weight.split(split_size, dim=0),
            original_conv.bias.split(split_size, dim=0),
        ):
            conv.weight.copy_(weight)
            conv.bias.copy_(bias)

        concatenated_output = torch.cat([conv(input_data) for conv in split_convs], dim=1)

    diff = torch.abs(original_output - concatenated_output)
    max_diff = torch.max(diff)
    mean_diff = torch.mean(diff)

    print(f"max_diff: {max_diff.item()}")
    print(f"mean_diff: {mean_diff.item()}")

    if torch.allclose(original_output, concatenated_output, atol=tolerance):
        print("verify conv3d oc split test pass! \n")
        return True
    print("verify conv3d oc split test fail! \n")
    return False

def verify_conv2d_from_3d_split_equivalence():
    """Compare ``Conv2dFrom3D`` against a reference ``nn.Conv3d``.

    The reference conv and a ``Conv2dFrom3D`` are built with the same
    configuration. The reference weights/bias are copied in through
    ``copy_weights_from_conv3d``, both modules run on the same random
    input, and the absolute error is printed.

    Returns:
        True if the outputs agree within ``atol=1e-5``, else False. This
        matches the contract of ``verify_conv2d_spatial_split_equivalence``.
    """
    torch.manual_seed(42)

    # Configuration.
    batch_size = 1
    in_channels = 4
    out_channels = 16
    depth = 16
    height = 32
    width = 32
    kernel_size = [1, 4, 4]
    stride = 1
    padding = 0

    # Random creation order (input, reference, replacement) is kept fixed.
    input_data = torch.randn(batch_size, in_channels, depth, height, width)

    original_conv = nn.Conv3d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
    )
    conv2d_replacement = Conv2dFrom3D(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
    )
    conv2d_replacement.copy_weights_from_conv3d(original_conv)

    # Inference only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        original_output = original_conv(input_data)
        conv2d_replacement_output = conv2d_replacement(input_data)

    diff = torch.abs(original_output - conv2d_replacement_output)
    max_diff = torch.max(diff)
    mean_diff = torch.mean(diff)

    print(f"max_diff: {max_diff.item()}")
    print(f"mean_diff: {mean_diff.item()}")

    tolerance = 1e-5
    if torch.allclose(original_output, conv2d_replacement_output, atol=tolerance):
        print("verify conv2d from 3d test pass! \n")
        return True
    print("verify conv2d from 3d test fail! \n")
    return False

def verify_conv2d_spatial_split_equivalence():
    """Compare ``Conv2dFrom3D_SpatialSplit`` against a reference ``nn.Conv3d``.

    The reference weight ``[out, in, kd, kh, kw]`` is permuted to
    ``[out, kd, in, kh, kw]`` and flattened to ``[out, kd * in, kh, kw]``. This
    matches the depth-into-channel folding of the replacement module. The
    bias is copied unchanged.

    Returns:
        True when both outputs agree within ``atol=1e-5``, otherwise False.
    """
    torch.manual_seed(42)

    n, c_in, c_out = 1, 4, 16
    d, h, w = 16, 32, 32
    ksize = [1, 4, 4]
    conv_kwargs = dict(stride=1, padding=0)

    sample = torch.randn(n, c_in, d, h, w)

    reference = nn.Conv3d(c_in, c_out, ksize, **conv_kwargs)
    expected = reference(sample)

    candidate = Conv2dFrom3D_SpatialSplit(c_in, c_out, ksize, **conv_kwargs)
    kd, kh, kw = ksize
    with torch.no_grad():
        folded_weight = reference.weight.data.permute(0, 2, 1, 3, 4).reshape(c_out, c_in * kd, kh, kw)
        candidate.conv2d.weight.data.copy_(folded_weight)
        candidate.conv2d.bias.data.copy_(reference.bias.data)
    actual = candidate(sample)

    abs_err = (expected - actual).abs()
    max_diff = abs_err.max().item()
    mean_diff = abs_err.mean().item()

    print(f"max_diff: {max_diff}")
    print(f"mean_diff: {mean_diff}")

    passed = torch.allclose(expected, actual, atol=1e-5)
    if not passed:
        print("verify conv2d spatial split test fail! \n")
        return False
    print("verify conv2d spatial split test pass! \n")
    return True



if __name__ == "__main__":
    # The output-channel split check is disabled; uncomment the call to run it.
    # verify_conv3d_oc_split_equivalence()
    for run_check in (
        verify_conv2d_from_3d_split_equivalence,
        verify_conv2d_spatial_split_equivalence,
    ):
        run_check()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值