%matplotlib inline
import random
import torch
from d2l import torch as d2l
# 1. Generate the dataset (生成数据集)
def synthetic_data(w, b, num_examples, noise_std=0.01):
    """Generate a synthetic linear-regression dataset ``y = Xw + b + noise``.

    Args:
        w: Weight vector (tensor or sequence of numbers); ``len(w)`` sets the
            number of feature columns. Presumably 1-D -- confirm with callers.
        b: Scalar bias added to every target.
        num_examples: Number of rows to generate.
        noise_std: Standard deviation of the Gaussian noise added to each
            target. Defaults to 0.01 (the value previously hard-coded);
            pass 0 for noise-free targets.

    Returns:
        A ``(x, y)`` pair where ``x`` has shape ``(num_examples, len(w))``
        and ``y`` is reshaped to a column of shape ``(num_examples, 1)``.

    Raises:
        ValueError: If ``noise_std`` is negative.
    """
    if noise_std < 0:
        raise ValueError(f"noise_std must be non-negative, got {noise_std}")
    # Features are drawn i.i.d. from a standard normal distribution.
    x = torch.normal(0, 1, (num_examples, len(w)))
    # Match w's dtype to x so integer tensors / plain lists do not make
    # matmul fail with a dtype mismatch; a float32 w passes through as-is.
    w = torch.as_tensor(w, dtype=x.dtype)
    y = torch.matmul(x, w) + b
    if noise_std > 0:
        y += torch.normal(0, noise_std, y.shape)
    # Column vector so each label lines up with its feature row.
    return x, y.reshape((-1, 1))
# Ground-truth parameters the synthetic data is generated from.
true_w = torch.tensor([2.0, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, num_examples=1000)
# Shapes of the feature matrix and label column (echoed as notebook output).
(features.shape, labels.shape)
# Out: (torch.Size([1000, 2]), torch.Size([1000, 1]))
# Plot the second feature column against the labels, using tiny markers.
d2l.set_figsize()
d2l.plt.scatter(
    features[:, 1].detach().numpy(),
    labels.detach().numpy(),
    s=1,
)
# Out: <matplotlib.collections.PathCollection at 0x1d9864c36a0>

# 2. Read the dataset (读取数据集)
def data_iter(batch_size,features,labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices)
for i in range