### UNet 模型在图像分类中的应用
UNet 是一种经典的卷积神经网络架构,最初设计用于生物医学图像分割任务。尽管其主要用途是图像分割,但通过适当的调整,也可以将其应用于图像分类任务。UNet 的核心思想在于利用编码器-解码器结构,结合跳跃连接(skip connections),保留空间信息并实现精确的像素级预测。对于图像分类任务,可以在解码器输出之后接全局池化和分类输出层,把逐像素的预测聚合为整幅图像的类别得分,再配合交叉熵等分类损失进行训练。
#### 构建 UNet 网络模型
首先,定义 UNet 中的各个组件,包括卷积层、下采样和上采样模块。以下是一个简化版的 PyTorch 实现:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.dropout1 = nn.Dropout2d(p=0.5)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.dropout2 = nn.Dropout2d(p=0.5)
self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.dropout1(x)
x = self.leaky_relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.dropout2(x)
x = self.leaky_relu(x)
return x
class DownSample(nn.Module):
def __init__(self, in_channels, out_channels):
super(DownSample, self).__init__()
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.double_conv = DoubleConv(in_channels, out_channels)
def forward(self, x):
x = self.max_pool(x)
x = self.double_conv(x)
return x
class UpSample(nn.Module):
def __init__(self, in_channels, out_channels):
super(UpSample, self).__init__()
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.double_conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
        x1 = self.up(x1)
        # 上采样后的特征图尺寸可能与跳跃连接的特征图相差若干像素,先补齐再拼接
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # 沿通道维拼接跳跃连接特征与上采样特征
        x = torch.cat([x2, x1], dim=1)
return self.double_conv(x)
class UNet(nn.Module):
def __init__(self, in_channels=3, num_classes=2):
super(UNet, self).__init__()
self.down_conv1 = DoubleConv(in_channels, 64)
self.down_conv2 = DownSample(64, 128)
self.down_conv3 = DownSample(128, 256)
self.down_conv4 = DownSample(256, 512)
self.down_conv5 = DownSample(512, 1024)
self.up_conv4 = UpSample(1024, 512)
self.up_conv3 = UpSample(512, 256)
self.up_conv2 = UpSample(256, 128)
self.up_conv1 = UpSample(128, 64)
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        # 分类头:全局平均池化,把逐像素的类别得分聚合为整幅图像的类别得分
        self.global_pool = nn.AdaptiveAvgPool2d(1)
def forward(self, x):
conv1 = self.down_conv1(x)
conv2 = self.down_conv2(conv1)
conv3 = self.down_conv3(conv2)
conv4 = self.down_conv4(conv3)
x = self.down_conv5(conv4)
x = self.up_conv4(x, conv4)
x = self.up_conv3(x, conv3)
x = self.up_conv2(x, conv2)
x = self.up_conv1(x, conv1)
        x = self.final_conv(x)         # (N, num_classes, H, W) 逐像素类别得分
        x = self.global_pool(x)        # (N, num_classes, 1, 1)
        return torch.flatten(x, 1)     # (N, num_classes),可直接用于交叉熵分类损失
```
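上述 `UNet` 在 `final_conv` 之后做了全局平均池化,因此前向输出是整幅图像的类别得分而不是逐像素的分割图。可以用随机张量做一次简单的形状自检,确认输出形状满足分类任务的要求(输入尺寸 256×256 与后文的数据预处理保持一致):
```python
# 形状自检:输入 (N, 3, 256, 256),期望输出 (N, num_classes)
model = UNet(in_channels=3, num_classes=2)
dummy = torch.randn(2, 3, 256, 256)
logits = model(dummy)
print(logits.shape)  # 预期输出 torch.Size([2, 2])
```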
#### 数据预处理与加载
使用 PyTorch 的 `torchvision.transforms` 对图像做统一的尺寸调整和张量转换(如需数据增强,可在此基础上加入随机翻转、裁剪等变换),并通过 `DataLoader` 加载数据集:
```python
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
class CustomDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image_path = self.image_paths[idx]
image = Image.open(image_path).convert('RGB')
        label = int(os.path.basename(os.path.dirname(image_path)))  # 假设父文件夹名即为整数标签,例如 0、1
if self.transform:
image = self.transform(image)
return image, label
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])
# image_paths 为图像文件路径列表,构建方式可参考下方示例
dataset = CustomDataset(image_paths, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
```
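上面的 `image_paths` 以及后文测试时用到的 `test_dataloader` 需要根据自己的数据自行构建。下面给出一个假设性的示例:假设训练图像按 `train_data/<标签>/<图片>.jpg`、测试图像按 `test_data/<标签>/<图片>.jpg` 的目录结构存放(目录名为整数标签),其中的目录名仅为示意:
```python
import glob

# 训练集:收集所有图像路径,目录名即整数标签(与 CustomDataset 中的约定一致)
image_paths = sorted(glob.glob('train_data/*/*.jpg'))

# 测试集:按同样的方式构建,供后文评估模型时使用
test_image_paths = sorted(glob.glob('test_data/*/*.jpg'))
test_dataset = CustomDataset(test_image_paths, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)
```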
#### 训练模型
接下来,定义损失函数和优化器,并编写训练循环:
```python
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(num_classes=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
epoch_loss = running_loss / len(dataset)
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
```
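训练结束后,通常会把模型参数保存下来,便于后续评估或部署时直接加载(文件名 `unet_cls.pth` 仅为示意):
```python
# 保存模型参数(state_dict)而不是整个模型对象
torch.save(model.state_dict(), 'unet_cls.pth')

# 需要时可按如下方式恢复:
# model = UNet(num_classes=2).to(device)
# model.load_state_dict(torch.load('unet_cls.pth', map_location=device))
```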
#### 测试模型
最后,在测试集上评估模型性能(这里的 `test_dataloader` 可按前面数据加载部分的示例构建):
```python
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_dataloader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {100 * correct / total:.2f}%')
```
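除了在测试集上统计准确率,也可以对单张图像进行预测。下面是一个简单的推理示例,复用前面定义的 `transform` 与 `device`,其中图像路径 `some_image.jpg` 仅为示意:
```python
from PIL import Image

def predict_image(model, image_path, transform, device):
    """对单张图像做分类预测,返回预测的类别索引。"""
    model.eval()
    image = Image.open(image_path).convert('RGB')
    x = transform(image).unsqueeze(0).to(device)  # 增加 batch 维度 -> (1, 3, 256, 256)
    with torch.no_grad():
        logits = model(x)
    return logits.argmax(dim=1).item()

print(predict_image(model, 'some_image.jpg', transform, device))
```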