sandglass block代码

import torch
import torch.nn.functional as F
from typing import Callable


class DepthwiseSeperableConvolutionalBlock(torch.nn.Module):
    def __init__(self, in_channels, kernel_size, out_channels, padding, stride, activation: Callable | None = None, bias: bool =True) -> None:
        super().__init__()
        self.activation = activation
        self.depthwise_layer = torch.nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, padding=padding, stride=stride, bias=bias)
        self.pointwise_layer = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)

        self.batch_normalization_a = torch.nn.BatchNorm2d(num_features=in_channels)
        self.batch_normalization_b = torch.nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        depthwise_output = self.batch_normalization_a(self.depthwise_layer(x))
        if self.activation:
            pointwise_output = self.activation(self.batch_normalization_b(self.pointwise_layer(depthwise_output)))
        else:
            pointwise_output = self.batch_normalization_b(self.pointwise_layer(depthwise_output))
        return pointwise_output

class SandglassBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, stride, kernel_size, reduction_factor, bias=True) -> None:
        super().__init__()
        self.depthwise_block_a = DepthwiseSeperableConvolutionalBlock(in_channels=in_channels, kernel_size=kernel_size, out_channels=in_channels, padding=1, stride=1, bias=bias, activation=F.relu6)
        self.reduction_pointwise_conv_layer = torch.nn.Conv2d(in_channels=in_channels, out_channels=in_channels//reduction_factor, kernel_size=1, stride=1, padding=0, bias=bias)
        self.expansion_pointwise_conv_layer = torch.nn.Conv2d(in_channels=in_channels//reduction_factor, out_channels=out_channels, kernel_size=1, stride=1, padding=0, bias=bias)
        self.depthwise_block_b = DepthwiseSeperableConvolutionalBlock(in_channels=out_channels, kernel_size=kernel_size, out_channels=out_channels, padding=1, stride=stride, bias=bias)

        self.shortcut: bool = in_channels == out_channels
        self.batch_normalization_reduction = torch.nn.BatchNorm2d(num_features=in_channels//reduction_factor)
        self.batch_normalization_expansion = torch.nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        residual = x
        depthwise_a = self.depthwise_block_a(x)
        reduction = self.batch_normalization_reduction(self.reduction_pointwise_conv_layer(depthwise_a))
        expansion = F.relu6(self.batch_normalization_expansion(self.expansion_pointwise_conv_layer(reduction)))
        depthwise_b = self.depthwise_block_b(expansion)
        if self.shortcut:
            depthwise_b = depthwise_b + residual
        return depthwise_b

sandglass_block = SandglassBlock(32, 96, 2, 3, 2)

input_tensor = torch.rand(32, 32, 112, 112)

sandglass_output = sandglass_block(input_tensor)
print(sandglass_output.size())

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值