环境：
conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 -c pytorch
每次看起来都很简单，自己动手写起来才知道哪里不会。
import os
import argparse
import sys
# import logging
# import nni
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# from nni.utils import merge_parameter
from torchvision import datasets, transforms
class LSTMs(nn.Module):
    """Sequence classifier: a stacked LSTM followed by a linear head.

    The LSTM reads a batch of sequences (batch_first layout). Only the output
    at the last time step is fed to the linear layer, which produces one
    logit per class.

    Args:
        hidden_dim: size of the LSTM hidden state.
        n_layer: number of stacked LSTM layers.
        input_size: number of features per time step. Defaults to 28
            (presumably one 28-pixel image row per step — confirm for other
            datasets).
        n_classes: number of output logits. Defaults to 10.
    """

    def __init__(self, hidden_dim, n_layer, input_size=28, n_classes=10):
        super(LSTMs, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_size, hidden_dim, n_layer, batch_first=True)
        self.linear = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to logits of shape (batch, n_classes)."""
        out, _ = self.lstm(x)
        # batch_first=True puts time on axis 1; keep only the last step's output.
        return self.linear(out[:, -1, :])
def train(args, net, device, train_loader, optimizer, epoch):
    """Run one training epoch of ``net`` over ``train_loader``.

    After every batch, the running mean loss and the cumulative accuracy
    are reported through ``progress_bar``.

    Args:
        args: unused; kept for call-site compatibility.
        net: model mapping a batch of inputs to class logits.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, label)`` batches that supports ``len()``.
        optimizer: optimizer stepping ``net``'s parameters.
        epoch: unused; kept for call-site compatibility.
    """
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        # Drop only the singleton channel axis. This assumes (B, 1, H, W)
        # images — confirm for other datasets. A bare squeeze() would also
        # drop the batch axis when a batch holds a single sample.
        data = data.squeeze(1)
        optimizer.zero_grad()
        out = net(data)
        # cross_entropy applies log_softmax internally, so it takes raw logits.
        loss = F.cross_entropy(out, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total += label.size(0)
        correct += out.argmax(dim=1).eq(label).sum().item()
        # Report the mean loss over all batches seen so far, not the
        # current batch loss divided by the batch count.
        progress_bar(batch_idx, len(train_loader), 'Loss: %.4f | Acc: %.4f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
def test(args, net, device, test_loader):
net.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for index,(data,label) in enumerate(test_loader):
data, label = data.to(device),label.to(device)
data = data.squeeze()
out = net(data)
# test_loss += F.nll_loss(out, label, reduction='sum').item()
test_loss += torch.nn.functional.cross_entropy(out, label, reduction='sum').item()
pred = out.argmax(dim=1,keepdim=True)
correct += pred.eq(label.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100.*correct/len(test_loader.dataset)
print("Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss, correct, len(test_loader.dataset), accuracy))
return accuracy
# Column width of the terminal attached to stdout. Presumably this is read by
# the progress-bar code further down the file — confirm.
# os.get_terminal_size() raises OSError when stdout is not a terminal (pipes,
# redirected files, some IDE consoles), so fall back to a fixed width then.
try:
    term_width = os.get_terminal_size().columns
except (OSError, ValueError):
    term_width = 200
def format_time(seconds):
    """Render a duration in seconds as a compact string such as ``'1h1m'``.

    The duration is split into days (D), hours (h), minutes (m), seconds (s)
    and milliseconds (ms), each truncated toward zero. Only the first two
    non-zero units, from largest to smallest, are kept. If every unit is
    zero, ``'0ms'`` is returned.
    """
    remaining = seconds
    day_count = int(remaining / 3600 / 24)
    remaining = remaining - day_count * 3600 * 24
    hour_count = int(remaining / 3600)
    remaining = remaining - hour_count * 3600
    minute_count = int(remaining / 60)
    remaining = remaining - minute_count * 60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    milli_count = int(remaining * 1000)

    units = (
        (day_count, 'D'),
        (hour_count, 'h'),
        (minute_count, 'm'),
        (whole_seconds, 's'),
        (milli_count, 'ms'),
    )
    # Keep the two most significant units that are actually present.
    pieces = [str(value) + suffix for value, suffix in units if value > 0][:2]
    return ''.join(pieces) or '0ms'
# Full length of the progress bar (presumably in characters); progress_bar
# scales it by current/total to get the filled portion.
TOTAL_BAR_LENGTH = 95.0
# Wall-clock timestamps (time.time()) shared with progress_bar; both start at
# the moment this module is imported.
last_time = begin_time = time.time()
def progress_bar(current, total, msg=None):
global last_time, begin_time
if current == 0:
begin_time = time.time() # Reset for new bar.
cur_len = int(TOTAL_BAR_LENGTH*current/total)
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1