选择案例
在给出的课程实践选题中选择猫狗识别项目:
下载相应数据集
项目实施
1.环境搭建
A.选择项目工具
将jupyter作为本项目的实施工具。
B.安装pytorch框架和所需的第三方库
项目所用库
import os
import sys
import time
import argparse
import itertools
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import intel_extension_for_pytorch as ipex
import pandas as pd
from torch import nn
from torch import optim
from torch.autograd import Variable
from torchvision import models
from matplotlib.patches import Rectangle
from sklearn.metrics import confusion_matrix, accuracy_score, balanced_accuracy_score
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
C.配置cuda环境
由于模型在GPU上训练更快,所以需要nvidia的GPU和cuda环境。
打开cmd,输入nvidia-smi查看显卡支持的最大cuda版本
然后去网络上选择合适版本的cuda已经配套的cudnn进行安装
本项目选择cuda11.8
安装完cuda后对cudnn文件进行配置并设置好对应的4个环境变量
打开jupyterlab,运行以下代码
import torch
print(torch.__version__)
print("是否可用:", torch.cuda.is_available()) # 查看GPU是否可用
print("GPU数量:", torch.cuda.device_count()) # 查看GPU数量
print("torch方法查看CUDA版本:", torch.version.cuda) # torch方法查看CUDA版本
print("GPU索引号:", torch.cuda.current_device()) # 查看GPU索引号
结果:
2.数据集划分
#数据分类,选择猫和狗各13000张作训练集,各100张作测试集
import os
import shutil
root_path = ''
animal = ['cat','dog']
style = ['train','test']
train_data_path = os.path.join(root_path,'train_data')
if not os.path.exists(train_data_path):
os.makedirs(train_data_path)
for f in style:
folder_path = os.path.join(train_data_path,f)
if not os.path.exists(folder_path):
os.makedirs(folder_path)
for name in animal:
under_path = os.path.join(folder_path,name)
if not os.path.exists(under_path):
os.makedirs(under_path)
for filename in os.listdir('train'):
file_path = os.path.join('train',filename)
if os.path.isfile(file_path):
if filename.startswith(name):
contents = filename.split('.')
num = int(contents[1])
if f == 'test' and num >12399:
shutil.copy(file_path,under_path)
elif f == 'train' and num < 13000:#数据集是按顺序排列的所以可以使用这种方法
shutil.copy(file_path,under_path)
3.数据处理
#数据处理
transform = transforms.Compose([
transforms.RandomResizedCrop((224,224)),
transforms.RandomRotation(20),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor()
])#对图片进行变换和数据增强
root = 'train_data'