1:作业简介:猫狗大战
1.1问题描述:
猫狗问题是一个经典的机器学习分类问题,目标是通过计算机算法自动识别图像中是猫还是狗。
1.2预期解决方案:
通过训练一个机器学习模型,使其在给定一张图像时能够准确地预测图像中是猫还是狗。模型应该能够推广到未见过的图像,并在测试数据上表现良好。
1.3数据集
训练集:
链接:https://pan.baidu.com/s/1bsJZmnR5I38rucyB1V6qEA
提取码:asd2
测试集:
链接:https://pan.baidu.com/s/1uao8yPk2PQFlWJ6tJAtOeA
提取码:asd3
1.4图像展示
2:数据预处理
2.1数据集结构
本项目数据集共由三部份组成,分别包含train、test1和test文件夹,其中test1文件夹中数据,为不带标签的测试集,test中为带标签的测试集。
其中train文件夹中包含25000张带有标签的猫狗图片,用作训练集。test1文件夹中包含12500张没有标签的图片。
2.2探索性数据分析
先导入接下来将要使用的包
import os
import sys
import time
import argparse
import itertools
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import intel_extension_for_pytorch as ipex
import pandas as pd
from torch import nn
from torch import optim
from torch.autograd import Variable
from torchvision import models
from matplotlib.patches import Rectangle
from sklearn.metrics import confusion_matrix, accuracy_score, balanced_accuracy_score
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
from torch.utils.data.dataset import Subset
查看数据集的大小以及数据集中文件名称
#返回指定路径下的所有文件和目录的名称列表
train_path = '../data/train'
test_path = '../data/test1'
train_file_names = os.listdir(train_path)
test_file_names = os.listdir(test_path)
print("训练集大小:{}".format(len(train_file_names)))
print("测试集大小:{}".format(len(test_file_names)))
print("训练集样例:{}".format(train_file_names[0:5]))#训练集文件名:标签+序号
print("测试集样例:{}".format(test_file_names[0:5])) #测试集文件名:序号
查看train文件夹中cat和dog的数量并提取三个不同的猫狗图片进行展示:
train_path = '../data/train'
image_files = [file for file in os.listdir(train_path) if file.lower().endswith( '.jpg')]
# 猫和狗的路径
cat_imgs = [file for file in image_files if file.lower().startswith('cat') and len(file) >= 3]
dog_imgs = [file for file in image_files if file.lower().startswith('dog') and len(file) >= 3]
print(f'猫的数量为: {len(cat_imgs)}')
print(f'狗的数量为: {len(dog_imgs)}')
# 随机不重样的抽选3个猫 3个狗
select_cat = np.random.choice(cat_imgs, 3, replace = False)
select_dog = np.random.choice(dog_imgs, 3, replace = False)
# 使用plt打印出来
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for i in range(6):
if i < 3:
fp = os.path.join(train_path, select_cat[i])
label = 'CAT'
else:
fp = os.path.join(train_path, select_dog[i-3])
label = 'DOG'
# 加载图像
img = Image.open(fp)
# 显示图像
axes[i // 3, i % 3].imshow(img)
axes[i // 3, i % 3].set_title(label)
axes[i // 3, i % 3].axis('off')
# 显示图像
plt.show()
2.3自定义数据集
为了更好的提取图片,我就将train数据集中的猫狗图片进行分类并打好标签
class SelfDataset(Dataset):
#初始化接受数据的路径 和数据转换的函数
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.data = self.load_data()
def load_data(self):
data = []
for file_name in os.listdir(self.root_dir):
file_path = os.path.join(self.root_dir, file_name)
if os.path.isfile(file_path) and file_name.lower().endswith('.jpg'):
class_label = 0 if file_name.lower().startswith('cat') else 1
data.append((file_path, class_label))
return data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_path, label = self.data[idx]
img = Image.open