本次小记,提供了一份基于pytorch的完整的深层神经网络模型的代码,包含训练模型代码和测试模型代码。除此之外,对代码中不容易理解的代码和函数进行了讲解。
本代码的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0+cu118
模型训练的详细代码如下:
# 定义数据处理函数
def data_treating(data):
data = np.array(data, dtype='float32') / 255
data = (data - 0.5) / 0.5 # 标准化
data = data.reshape((-1,)) # 改变数组维度
data = torch.from_numpy(data) # 数据转换成tensor类型
return data
# 载入数据集
train_set = mnist.MNIST('./Mnist_data', train=True, transform=data_treating)
test_set = mnist.MNIST('./Mnist_data', train=False, transform=data_treating)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
classify_net = nn.Sequential(
nn.Linear(784, 400),
nn.ReLU(),
nn.Linear(400, 200),
nn.ReLU(),
nn.Linear(200, 100),
nn.ReLU(),
nn.Linear(100, 50),
nn.ReLU(),
nn.