PyTorch学习2:用PyTorch搭建AlexNet网络

本文介绍了一个基于PyTorch实现的AlexNet模型,详细展示了模型定义、训练流程及预测方法。训练时使用torchvision的ImageFolder加载按文件夹组织的花卉数据集,并实现了模型的保存与加载。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

强烈推荐B站UP主“霹雳吧啦Wz”,有兴趣的可以去看看。我的代码基本是照着他的代码自己重写了一遍,以加深印象。写这篇博客只是为了记录学习过程、防止遗忘,文中代码添加了一些注释,帮助理解。

1、模型

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    """AlexNet with half-width layers (48/128/192/192/128 conv channels, 2048-unit FC layers).

    ``features`` is five convolutions (the 1st, 2nd and 5th followed by max
    pooling); ``classifier`` is three fully connected layers, with dropout
    (p=0.5) in front of the first two.

    Args:
        num_classes: number of logits produced per sample by ``forward``.
        init_weights: if True, re-initialise weights with
            ``_initialize_weights`` instead of PyTorch's layer defaults.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()
        # Shape comments assume a 3x224x224 input.
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # -> 48x55x55
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # -> 48x27x27

            nn.Conv2d(48, 128, kernel_size=5, padding=2),  # -> 128x27x27
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # -> 128x13x13

            nn.Conv2d(128, 192, kernel_size=3, padding=1),  # -> 192x13x13
            nn.ReLU(inplace=True),

            nn.Conv2d(192, 192, kernel_size=3, padding=1),  # -> 192x13x13
            nn.ReLU(inplace=True),

            nn.Conv2d(192, 128, kernel_size=3, padding=1),  # -> 128x13x13
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # -> 128x6x6
        )
        # Force the map fed to the classifier to 6x6 so the fixed-size first
        # Linear layer also fits inputs other than 224x224. For 224x224 the
        # features are already 6x6 and this is an exact identity. The layer has
        # no parameters, so state_dict keys (and saved checkpoints) are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return raw (un-normalised) class scores of shape (batch, num_classes)."""
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch axis, flatten channel/height/width into one vector.
        x = torch.flatten(x, start_dim=1)
        return self.classifier(x)

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) conv weights, N(0, 0.01) linear weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                # Same guard as for Conv2d: a Linear built with bias=False has bias None.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

2、训练部分代码

import os
import json
import torch
import torch.nn as nn
from torchvision import transforms,datasets,utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet


def _build_transforms():
    """Return the preprocessing pipelines for the "train" and "val" splits."""
    # Map each channel from [0, 1] (ToTensor output) to [-1, 1].
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }


def _save_class_indices(class_to_idx, path="class_indices.json"):
    """Write the inverted ``class_to_idx`` mapping (index -> class name) as JSON.

    e.g. {'daisy': 0, 'dandelion': 1, ...} is written as {"0": "daisy", "1": "dandelion", ...}
    (JSON object keys are always strings).
    """
    cla_dict = {idx: name for name, idx in class_to_idx.items()}
    with open(path, "w", encoding="utf-8") as json_file:
        json.dump(cla_dict, json_file, indent=4)


def _train_one_epoch(net, loader, loss_function, optimizer, device, epoch, epochs):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    net.train()  # enable dropout
    running_loss = 0.0
    train_bar = tqdm(loader)
    for images, labels in train_bar:
        optimizer.zero_grad()  # clear gradients from the previous step
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()  # back-propagate to get each parameter's gradient
        optimizer.step()  # gradient-based parameter update

        batch_loss = loss.item()
        running_loss += batch_loss
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                 epochs,
                                                                 batch_loss)
    return running_loss / len(loader)


def _evaluate(net, loader, device):
    """Return how many samples in ``loader`` the network classifies correctly."""
    net.eval()  # disable dropout
    correct = 0
    with torch.no_grad():
        for val_images, val_labels in tqdm(loader):
            outputs = net(val_images.to(device))
            predict_y = torch.argmax(outputs, dim=1)
            correct += torch.eq(predict_y, val_labels.to(device)).sum().item()
    return correct


def main():
    """Train AlexNet on ``../../data_set/flower_data`` (relative to the CWD).

    Expects ImageFolder-style ``train`` and ``val`` subfolders. Writes
    ``class_indices.json`` (index -> class name) to the CWD and saves the
    weights with the best validation accuracy to ``./AlexNet.pth``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = _build_transforms()
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # e.g. data_root = D:\classic_nets on the author's machine
    image_path = os.path.join(data_root, "data_set", "flower_data")
    print(image_path)
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    print(train_num)
    _save_class_indices(train_dataset.class_to_idx)

    batch_size = 32
    # os.cpu_count() returns None when the count cannot be determined.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print("using {} dataloader workers every process".format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    print(val_num)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4,
                                                  shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation .".format(train_num, val_num))

    # One output per class found in the training folder (5 for flower_data).
    net = AlexNet(num_classes=len(train_dataset.class_to_idx), init_weights=True)
    # Move the parameters to the selected device; inputs are moved per batch.
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # Optimise all parameters with Adam, learning rate 0.0002.
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    for epoch in range(epochs):
        train_loss = _train_one_epoch(net, train_loader, loss_function,
                                      optimizer, device, epoch, epochs)
        val_accurate = _evaluate(net, validate_loader, device) / val_num
        print('[epoch %d] train_loss:%.3f val_accuracy:%.3f' %
              (epoch + 1, train_loss, val_accurate))

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print("Finished Training")


if __name__ == '__main__':
    main()

3、预测部分代码

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet

def main(num_classes,
         img_path=r"D:\classic_nets\data_set\flower_data\val\daisy\253426762_9793d43fcd.jpg",
         json_path='./class_indices.json',
         weights_path="./AlexNet.pth"):
    """Classify one image with trained AlexNet weights and display the result.

    Args:
        num_classes: number of outputs the checkpoint was trained with.
        img_path: image to classify.
        json_path: index -> class-name mapping written by the training script.
        weights_path: state_dict saved by the training script.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # load image (assert aborts the program if the condition is not met)
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    # convert("RGB") so grayscale / RGBA / palette files still give the three
    # channels Normalize and the first conv expect; `with` closes the file.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    plt.imshow(img)
    img = data_transform(img)
    # Add a batch dimension: the network expects a 4-D input.
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, 'r', encoding="utf-8") as json_file:
        class_indict = json.load(json_file)

    # create model
    model = AlexNet(num_classes=num_classes).to(device)

    # load model weights; map_location lets a checkpoint saved from a GPU
    # model load on a CPU-only machine.
    assert os.path.exists(weights_path), "file:weights path doesn't exist"
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # Take the single batch element explicitly; squeeze() would also drop
        # the class axis when num_classes == 1.
        output = model(img.to(device))[0].cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(predict))

    # JSON object keys are strings, hence str(predict_cla).
    print_result = "class:{} prob:{:.3f}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_result)
    print(print_result)
    plt.show()


if __name__ == '__main__':
    main(num_classes=5)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值