1、用PyTorch写一个最简的网络模型,两个线性层,够简单吧,让我么看看她能做点什么?
INPUT_SIZE = 784
HIDDEN_SIZE = 500
NUM_CLASSES = 10
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(HIDDEN_SIZE, NUM_CLASSES)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
2、看看网络的结构,确实很简单:
3、用MNIST数据训练一下,这么寒酸的网路,训练10轮就够了:
4、花2.27秒用10000条测试数据跑一下,正确率97.46,这也太假了。
5、用OnnxRuntime推理,正确率是一样的,时间不到1秒了,好像还有点用。
6、用.net随便写个画图程序来测一测,居然识别出来了。
7、结论:AI确实好用,换传统算法实现还挺麻烦;感谢PyTorch,继续学习。