DL学习1--从零开始实现softmax回归
使用Fashion-MNIST数据集编写了一个softmax回归分类器。
from torch.utils import data
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn
from d2l import torch as d2l
import tqdm
def load_data_fashion_mnist(batch_size, resize=None, root="../data"):
    """Load Fashion-MNIST and wrap the train/test splits in DataLoaders.

    Args:
        batch_size: Number of samples per mini-batch for both loaders.
        resize: Optional target size; when truthy, a ``transforms.Resize``
            step is applied before ``ToTensor``.
        root: Directory passed to torchvision as the dataset root
            (defaults to the previously hard-coded ``"../data"``).

    Returns:
        ``(train_loader, test_loader)``; the training loader shuffles,
        the test loader does not.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Resize must run on the PIL image, i.e. before conversion to a tensor.
        steps.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(steps)
    # download=True is passed so torchvision can fetch the files into `root`.
    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True),
            data.DataLoader(mnist_test, batch_size, shuffle=False))
# Mini-batch size shared by both data loaders.
batch_size = 64
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

# Sanity check: print the shapes of the first training batch only.
for X, y in train_iter:
    print(X.shape, y.shape)
    break

# Model dimensions: each 1x28x28 image is flattened to 784 features; 10 outputs.
input_shape = 1 * 28 * 28
output_shape = 10

# Trainable parameters of the softmax-regression model:
# weights drawn from N(0, 0.1^2), biases initialised to zero.
w = torch.normal(0, 0.1, size=(input_shape, output_shape), requires_grad=True)
b = torch.zeros(output_shape, requires_grad=True)
def net(X):
    """Forward pass of the softmax-regression model.

    Flattens ``X`` into rows of ``w.shape[0]`` features, applies the affine
    map ``rows @ w + b`` using the module-level parameters, and returns the
    row-wise softmax probabilities.
    """
    rows = X.reshape(-1, w.shape[0])
    logits = rows @ w + b
    return softmax(logits)
def softmax(y_ht):
    """Return the softmax of ``y_ht`` normalised along dimension 1.

    The row-wise maximum is subtracted before exponentiating. Softmax is
    invariant to this shift, but it keeps ``torch.exp`` from overflowing to
    ``inf`` (which previously produced ``inf / inf = NaN``) on large logits.

    Args:
        y_ht: Tensor of logits with at least two dimensions.

    Returns:
        Tensor of the same shape whose slices along dim 1 sum to 1.
    """
    # Detached: the shift is a numerical constant and needs no gradient.
    row_max = y_ht.max(dim=1, keepdim=True).values.detach()
    y_hat = torch.exp(y_ht - row_max)
    tmp = y_hat.sum(1, keepdim=True)
    return y_hat / tmp
def cross_entropy(y_hat, y):
    """Per-sample negative log-likelihood of the true class.

    Args:
        y_hat: 2-D tensor of predicted probabilities, one row per sample.
        y: 1-D tensor of integer class labels, one per row of ``y_hat``.

    Returns:
        1-D tensor ``-log(y_hat[i, y[i]])`` for each row ``i``.
    """
    row_idx = torch.arange(y_hat.shape[0])
    picked = y_hat[row_idx, y]
    return -torch.log(picked)
def accuracy(y_hat, y):
    """Return the NUMBER of correct predictions (as a float), not a ratio.

    If ``y_hat`` is 2-D or higher with more than one column, it is treated as
    per-class scores and reduced with ``argmax`` over dim 1. Otherwise it is
    compared to ``y`` directly after casting to ``y``'s dtype.
    """
    has_class_scores = y_hat.dim() > 1 and y_hat.shape[1] > 1
    predicted = y_hat.argmax(axis=1) if has_class_scores else y_hat
    hits = predicted.type(y.dtype) == y
    return float(hits.type(y.dtype).sum())
def sgd(pamars, lr, batch_size):
    """In-place mini-batch SGD step.

    For each tensor ``p`` in ``pamars``: ``p -= lr * p.grad / batch_size``,
    then reset ``p.grad`` to zero. Runs under ``torch.no_grad()`` so the
    update itself is not recorded by autograd.
    """
    with torch.no_grad():
        for param in pamars:
            param.sub_(lr * param.grad / batch_size)
            param.grad.zero_()
def updater(batch_size):
    """Apply one ``sgd`` step to the module-level ``w`` and ``b`` using the global ``lr``."""
    model_params = [w, b]
    return sgd(model_params, lr, batch_size)
# train
lr = 0.01
epochs = 20
for epoch in tqdm.tqdm(range(epochs)):
    # Only nn.Module models have train()/eval(); the hand-written `net` above
    # is a plain function, so these checks are no-ops for it.
    if isinstance(net, torch.nn.Module):
        net.train()
    true_data, data_len = 0, 0  # correct predictions / samples seen (train)
    true_data_t, data_len_t = 0, 0  # correct predictions / samples seen (test)
    for X, y in train_iter:
        preb = net(X)
        l = cross_entropy(preb, y)
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            l.sum().backward()
            updater.step()
        else:
            # Custom updater: backprop the summed loss and pass the batch size;
            # the sgd helper divides by it and zeroes the gradients.
            l.sum().backward()
            updater(X.shape[0])
        true_data += accuracy(preb, y)
        data_len += len(X)
    # Evaluate on the test set; no gradients are needed here.
    if isinstance(net, torch.nn.Module):
        net.eval()
    with torch.no_grad():
        for X, y in test_iter:
            true_data_t += accuracy(net(X), y)
            data_len_t += len(X)
    # tqdm.write prints above the progress bar instead of splicing text
    # into it, which a plain print() inside a tqdm loop does.
    tqdm.tqdm.write(
        f"epoch = {epoch + 1}  train acc= {true_data / data_len}"
        f"  test acc= {true_data_t / data_len_t}"
    )
运行结果
E:\anaconda3\envs\torch1.13\python.exe C:/Users/Jie/Desktop/renwu/LiMu-AI/py/1.softmax回归实现.py
torch.Size([64, 1, 28, 28]) torch.Size([64])
5%|▌ | 1/20 [00:11<03:38, 11.49s/it]epoch = 1 train acc= 0.6595166666666666 test acc= 0.7441
10%|█ | 2/20 [00:22<03:23, 11.30s/it]epoch = 2 train acc= 0.7742666666666667 test acc= 0.7789
15%|█▌ | 3/20 [00:33<03:11, 11.26s/it]epoch = 3 train acc= 0.7963333333333333 test acc= 0.7943
20%|██ | 4/20 [00:45<03:00, 11.25s/it]epoch = 4 train acc= 0.8083833333333333 test acc= 0.8007
epoch = 5 train acc= 0.8153833333333333 test acc= 0.8067
30%|███ | 6/20 [01:07<02:37, 11.21s/it]epoch = 6 train acc= 0.8207666666666666 test acc= 0.8139
35%|███▌ | 7/20 [01:18<02:25, 11.21s/it]epoch = 7 train acc= 0.8238333333333333 test acc= 0.8146
epoch = 8 train acc= 0.8278833333333333 test acc= 0.8161
45%|████▌ | 9/20 [01:41<02:03, 11.26s/it]epoch = 9 train acc= 0.8303333333333334 test acc= 0.8196
50%|█████ | 10/20 [01:52<01:52, 11.26s/it]epoch = 10 train acc= 0.8333666666666667 test acc= 0.824
epoch = 11 train acc= 0.8350333333333333 test acc= 0.823
60%|██████ | 12/20 [02:14<01:29, 11.23s/it]epoch = 12 train acc= 0.8369 test acc= 0.8237
65%|██████▌ | 13/20 [02:25<01:18, 11.19s/it]epoch = 13 train acc= 0.8387166666666667 test acc= 0.8254
epoch = 14 train acc= 0.83915 test acc= 0.8264
75%|███████▌ | 15/20 [02:48<00:56, 11.25s/it]epoch = 15 train acc= 0.8408833333333333 test acc= 0.8266
epoch = 16 train acc= 0.8413 test acc= 0.8278
85%|████████▌ | 17/20 [03:11<00:34, 11.35s/it]epoch = 17 train acc= 0.8427666666666667 test acc= 0.8269
90%|█████████ | 18/20 [03:22<00:22, 11.38s/it]epoch = 18 train acc= 0.84355 test acc= 0.828
95%|█████████▌| 19/20 [03:34<00:11, 11.31s/it]epoch = 19 train acc= 0.8446166666666667 test acc= 0.8302
epoch = 20 train acc= 0.84585 test acc= 0.8297
100%|██████████| 20/20 [03:45<00:00, 11.27s/it]
Process finished with exit code 0