train_curve = list()
def train_net(net, device, data_path, epochs=100, batch_size=4, lr=0.01):
# 加载训练集
isbi_dataset = ISBI_Loader(data_path)
train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
batch_size=batch_size,
shuffle=True)
# 定义RMSprop算法
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
# 定义Loss算法
criterion = nn.BCEWithLogitsLoss()
# best_loss统计,初始化为正无穷
best_loss = float('inf')
# 训练epochs次
step = 0
for epoch in range(epochs):
print('Epoch {}/{}'.format(epoch, epochs))
print('-' * 10)
dt_size = len(train_loader.dataset)
epoch_loss = 0
# 训练模式
net.train()
# 按照batch_size开始训练
for image, label in train_loader:
step += 1
# 将数据拷贝到device中
image = image.to(device=device, dtype=torch.float32)
记录训练过程
最新推荐文章于 2022-08-31 10:31:19 发布