# Code
import torch
from matplotlib import pyplot as plt
import numpy as np
import random
from mpl_toolkits.mplot3d import Axes3D
# Dataset: synthetic linear-regression data, labels = X @ true_w + true_b + noise.
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
# Feature matrix of shape (num_examples, num_inputs) drawn from N(0, 1),
# converted from NumPy's float64 to float32.
features = torch.from_numpy(
    np.random.normal(0, 1, (num_examples, num_inputs))
).type(torch.float32)
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
# Add Gaussian noise with std 0.01. np.random.normal yields float64, so cast it
# to float32 explicitly: an in-place float32 += float64 add otherwise depends on
# implicit downcasting, which older PyTorch releases reject with a dtype error.
labels += torch.from_numpy(
    np.random.normal(0, 0.01, size=tuple(labels.shape))
).to(torch.float32)
# Uncomment to visualise the first feature against the labels:
# plt.scatter(features[:,0].numpy(),labels.numpy())
# plt.show()
# 批量数据样本
def data_iter(batch_size, features, labels):
num_examples