1. Import Libraries
- Import the libraries used for deep learning, data handling, and visualization:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import os
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
2. Data Preprocessing and Visualization
- Set the dataset path: define the dataset root directory and the list of class names.
data_dir = './flower_images'
categories = ['Lilly', 'Lotus', 'Orchid', 'Sunflower', 'Tulip']
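The code that follows assumes that each name in `categories` is a subfolder of `data_dir`. The spelling must match exactly, including `Lilly`. The sketch below is a minimal sanity check for that assumption. `check_layout` is a hypothetical helper, not part of any library. It only compares folder names and does not inspect the images inside them.

```python
import os


def check_layout(data_dir, categories):
    """Compare the class subfolders on disk with the expected class list.

    Returns (missing, extra):
      missing: expected class names that have no folder under data_dir
      extra:   folders under data_dir that are not in the class list
    Hidden entries (names starting with '.') and plain files are ignored.
    """
    found = {
        name
        for name in os.listdir(data_dir)
        if not name.startswith('.') and os.path.isdir(os.path.join(data_dir, name))
    }
    expected = set(categories)
    missing = sorted(expected - found)
    extra = sorted(found - expected)
    return missing, extra
```

For a correctly laid-out dataset, `check_layout(data_dir, categories)` should return two empty lists. Anything else points to a typo in `categories` or an unexpected folder.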
- Show sample images: display the first image listed in each class folder, with every image resized to 150x150 pixels.
# One row of five panels, one per class
fig, axes = plt.subplots(1, 5, figsize=(12, 6))
for i, category in enumerate(categories):
    category_path = os.path.join(data_dir, category)
    image_files = os.listdir(category_path)
    # Take the first file listed in the class folder and resize it to 150x150
    image_path = os.path.join(category_path, image_files[0])
    img = Image.open(image_path)
    img = img.resize((150, 150))
    axes[i].imshow(img)
    axes[i].set_title(category)
    axes[i].axis('off')
plt.tight_layout()
plt.show()
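`os.listdir` returns entries in arbitrary order and may include non-image files. As a result, `image_files[0]` is not guaranteed to be the same image on every run, or even an image at all. The sketch below shows a deterministic alternative. `first_image` and the `IMAGE_EXTS` set are hypothetical choices that do not come from the original code; adjust the extensions to the dataset.

```python
import os

# Extensions treated as images (an assumption; adjust to the dataset).
IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}


def first_image(category_path):
    """Return the path of the alphabetically first image file in a class folder.

    Sorting makes the choice reproducible, because os.listdir makes no
    ordering guarantee. Files whose extension is not in IMAGE_EXTS
    (case-insensitive) are skipped.
    """
    candidates = sorted(
        name
        for name in os.listdir(category_path)
        if os.path.splitext(name)[1].lower() in IMAGE_EXTS
    )
    if not candidates:
        raise FileNotFoundError(f'no image files in {category_path}')
    return os.path.join(category_path, candidates[0])
```

Inside the plotting loop, `image_path = first_image(category_path)` would replace the two lines that build `image_path` from `image_files[0]`.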
3. Dataset Statistics
- Count the number of images in each class and plot the counts as a bar chart.
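# Count the files in each class folder
df = pd.DataFrame()
name = []
counts = []
for class_name in os.listdir(data_dir):
    class_path = os.path.join(data_dir, class_name)
    if not os.path.isdir(class_path):  # skip stray files such as .DS_Store
        continue
    name.append(class_name)
    counts.append(len(os.listdir(class_path)))
df['Name'] = name
df['Counts'] = counts
df.head()  # displays only in a notebook; use print(df.head()) in a script
plt.bar(df['Name'], df['Counts'],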