代码实现
工具类代码
import torch
from matplotlib import pyplot as plt
def plot_curve(ls):
    """Plot one or more groups of value series against their step index.

    Args:
        ls: An iterable of dicts. Each dict maps a legend label to a
            sequence of values; each sequence is plotted against
            ``range(len(values))``. All series coming from the same dict
            share one colour, picked from the ``'Set1'`` colormap by the
            dict's position in ``ls``.

    A new figure is created, labelled ('step' / 'value'), given a legend,
    and displayed with ``plt.show()``. Nothing is returned.
    """
    plt.figure()
    palette = plt.get_cmap('Set1')
    for idx, data in enumerate(ls):
        # 'Set1' has a fixed number of colours (palette.N). An integer index
        # >= palette.N maps to the single "over" colour, which would make
        # every later group indistinguishable, so wrap around instead.
        color = palette(idx % palette.N)
        for key, value in data.items():
            plt.plot(range(len(value)), value, color=color, label=key)
    plt.legend()
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()
def plot_image(img, label, name):
    """Show up to six images from a batch in a 2x3 grid, each with a title.

    Args:
        img: Indexable batch of images. ``img[i][0]`` must be a 2-D array
            accepted by ``plt.imshow``. NOTE(review): presumably a
            (N, C, H, W) tensor whose channel 0 is plotted — confirm.
        label: Indexable of per-image labels; each element must support
            ``.item()`` (e.g. a scalar tensor).
        name: Prefix used in every subplot title, formatted as
            ``"<name>: <label>"``.

    A new figure is created and displayed with ``plt.show()``. Nothing is
    returned.
    """
    plt.figure()
    # Plot at most 6 images, but never index past the end of a short batch
    # (the original hard-coded 6 raised IndexError for smaller inputs).
    count = min(6, len(img), len(label))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        # Inverts a (x - 0.1307) / 0.3081 normalisation so pixels display in
        # their original range (presumably the MNIST mean/std — confirm).
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists; calling it inside
    # the loop ran before the final title was set and so ignored it.
    plt.tight_layout()
    plt.show()
def one_hot(label, depth=10):
out = torch.zeros