OHEM在线难例挖掘原理及在代码中应用
OHEM原理
OHEM(Online Hard Example Mining)在线难例挖掘是一种用于优化神经网络训练的方法。通过在每个迭代中选择最难的样本进行训练,来提高模型的性能。在代码中可以通过使用损失函数和自定义采样器来实现。在传统的训练过程中,模型会在训练集中遇到大量易于分类的样本,而只有少量的难以分类的样本。这样一来,模型就会倾向于预测易于分类的样本,而忽略难以分类的样本。这样会导致模型无法很好地泛化到测试集上。
OHEM通过挖掘在线难例实现强化模型对难例的学习。具体来说,OHEM在每个batch的训练中选择一定数量(通常为batch size的1/2)的难例样本,这些难例样本的损失函数被优先考虑。因此,模型会更加关注难以分类的样本,在训练过程中逐渐学会处理难例样本的能力,提高模型的泛化性能。
应用
在自己的代码中应用OHEM,可以通过以下步骤:
-
定义一个损失函数,例如交叉熵损失。
-
在每个batch的训练过程中,计算所有样本的损失值,并按照损失值从大到小排序。
-
选择一定数量的样本作为难例样本,例如选择损失值排名前50%的样本。
-
将难例样本的损失函数乘以一个权重(例如2),以增加对难例样本的惩罚。
-
将难例样本和非难例样本的损失函数加权平均,得到本batch的总损失值。
-
根据总损失值更新模型参数。
PyTorch代码示例1
import torch.nn.functional as F
import torch.optim as optim
# 定义损失函数
loss_fn = F.cross_entropy
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs):
for i, (inputs, labels) in enumerate(train_loader):
# 前向传播
outputs = model(inputs)
# 计算所有样本的损失值
loss = loss_fn(outputs, labels)
# 按照损失值排序
_, indices = torch.sort(loss, descending=True)
# 选择难例样本
num_hard = batch_size // 2
hard_indices = indices[:num_hard]
# 计算难例样本的损失函数,并乘以权重
hard_loss = loss_fn(outputs[hard_indices], labels[hard_indices]) * 2
# 将难例样本和非难例样本的损失函数加权平均
total_loss = (loss.mean() * (batch_size - num_hard) + hard_loss) / batch_size
# 反向传播和更新参数
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
在以上代码中,我们首先定义了一个交叉熵损失函数,然后在每个batch的训练过程中,按照损失值从大到小排序,并选择损失值排名前50%的样本作为难例样本。难例样本的损失函数乘以了一个权重2,以增加对难例样本的惩罚。最终,我们将难例样本和非难例样本的损失函数加权平均得到本batch的总损失值,并根据总损失值更新模型参数。
PyTorch代码示例2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.fc = nn.Linear(320, 10)
def forward(self, x):
x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x), 2))
x = nn.functional.relu(nn.functional.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 320)
x = self.fc(x)
return x
# 定义OHEM损失函数
class OHMELoss(nn.Module):
def __init__(self, ratio=3):
super(OHMELoss, self).__init__()
self.ratio = ratio
def forward(self, input, target):
loss = nn.functional.cross_entropy(input, target, reduction='none')
num_samples = len(loss)
num_hard_samples = int(num_samples / self.ratio)
_, indices = torch.topk(loss, num_hard_samples)
ohem_loss = torch.mean(loss[indices])
return ohem_loss
# 加载数据集
train_dataset = MNIST(root='data', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 初始化模型和损失函数
model = Net()
criterion = OHMELoss()
# 训练模型
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 10
for epoch in range(epochs):
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
# 测试模型
test_dataset = MNIST(root='data', train=False, transform=ToTensor())
test_loader = DataLoader(test_dataset, batch_size=1000)
model.eval()
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
print('Test Accuracy:', correct / len(test_loader.dataset))
在代码中,我们首先定义了模型,并使用OHMELoss
作为损失函数。OHMELoss
定义中的ratio=3
表示每个迭代中选择三倍于正常的样本数量进行训练。
在训练过程中,我们使用torch.topk
函数选择最难的样本进行训练。在测试过程中,我们使用model.eval()
将模型设为评估模式,并计算模型的准确率。
这个示例展示了如何在PyTorch中使用OHEM进行训练,但具体的实现方式可能因应用场景而异。
PyTorch代码示例3
参考链接:https://blog.youkuaiyun.com/hxxjxw/article/details/119333414
import torch
import torch.nn.functional as F
import torch.nn as nn
smooth_l1_sigma = 1.0
smooth_l1_loss = nn.SmoothL1Loss(reduction='none') # reduce=False
def ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target):
""" Arguments:
batch_size (int): number of sampled rois for bbox head training
loc_pred (FloatTensor): [R, 4], location of positive rois
loc_target (FloatTensor): [R, 4], location of positive rois
pos_mask (FloatTensor): [R], binary mask for sampled positive rois
cls_pred (FloatTensor): [R, C]
cls_target (LongTensor): [R]
Returns:
cls_loss, loc_loss (FloatTensor)
"""
ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target).sum(dim=1)
# 这里先暂存下正常的分类loss和回归loss
print(ohem_cls_loss.shape, ohem_loc_loss.shape)
loss = ohem_cls_loss + ohem_loc_loss
# 然后对分类和回归loss求和
sorted_ohem_loss, idx = torch.sort(loss, descending=True)
# 再对loss进行降序排列
keep_num = min(sorted_ohem_loss.size()[0], batch_size)
# 得到需要保留的loss数量
if keep_num < sorted_ohem_loss.size()[0]:
# 这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
keep_idx_cuda = idx[:keep_num] # 保留到需要keep的数目
ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
ohem_loc_loss = ohem_loc_loss[keep_idx_cuda] # 分类和回归保留相同的数目
cls_loss = ohem_cls_loss.sum() / keep_num
loc_loss = ohem_loc_loss.sum() / keep_num # 然后分别对分类和回归loss求均值
return cls_loss, loc_loss
if __name__ == '__main__':
batch_size = 4
C = 6
loc_pred = torch.randn(8, 4)
loc_target = torch.randn(8, 4)
cls_pred = torch.randn(8, C)
cls_target = torch.Tensor([1, 1, 2, 3, 5, 3, 2, 1]).type(torch.long)
cls_loss, loc_loss = ohem_loss(batch_size, cls_pred, cls_target, loc_pred, loc_target)
print(cls_loss, '--', loc_loss)