PyTorch 实现 UNet
将其分为双卷积、下采样、上采样和输出四个部分。
PyTorch 实现代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""conv->BN->relu * 2"""
def __init__(self, in_channels, out_channels, mid_channels = None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace = True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(sel