问题:
如何在显存不足的情况下,增加batch-size?
换言之,如何增加batch-size而无需扩大显存?
思路:
将batch数据,分为多个mini-batch,对mini-batch计算loss,再求和,进行反向传播。
这样内存只占用mini-batch大小的数据,用时间换空间。
pytorch实现:
import torch
from sklearn import metrics
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# 简单的TextRNN模型
class TextRNN(nn.Module):
def __init__(self, num_words, num_classes, embedding_dim, hidden_dim, dropout):
super(TextRNN, self).__init__()
self.embed = nn.Embedding(num_embeddings=num_words + 1, embedding_dim=embedding_dim, padding_idx=num_words)
self.encode = nn.GRU(embedding_dim, 200, batch_first=True, bidirectional=True)
self.mlp =