文章目录
- 1. 模块导入与全局变量设置
- 2. 创建目标目录结构
- 3. 复制猫的图片到各子集
- 4. 复制狗的图片到各子集
- 5. 验证并打印各子集样本数
- **注意事项**
- 总结
1. 模块导入与全局变量设置
# -*- coding: utf-8 -*-
"""
实现将原始的猫狗数据集划分为小的训练集、验证集和测试集
保存到对应目录
PyTorch Programming & Deep Learning
@author: Mike Yuan, Copyright 2020~2021
"""
import os
import shutil
from pathlib import Path`
# 训练集猫或狗的样本数
NUM_TRAIN_EXAMPLES = 1000
# 验证集猫或狗的样本数
NUM_VALID_EXAMPLES = 500
# 测试集猫或狗的样本数
NUM_TEST_EXAMPLES = 500
# 原始的猫狗数据集路径
ORIGINAL_DATASET_DIR = '../datasets/kaggledogvscat/original_train'
# 存储较小数据集的目录
BASE_DIR = '../datasets/kaggledogvscat/small'
- 功能:
- 导入文件操作和路径处理模块。
- 定义各子集(训练、验证、测试)的样本数量和路径常量。
- 注意:
ORIGINAL_DATASET_DIR
是原始数据集的根目录,假设其中包含cat.0.jpg
到cat.12499.jpg
和dog.0.jpg
到dog.12499.jpg
的文件。NUM_*_EXAMPLES
需确保总样本数不超过原始数据集的总量(例如猫狗各至少 2000 张)。
2. 创建目标目录结构
# 检查目标目录是否存在,若存在则报错退出
if Path(BASE_DIR).exists():
print("错误:目标目录已存在!")
os._exit(-1)
os.mkdir(BASE_DIR)
# 创建训练集、验证集、测试集目录
train_dir = os.path.join(BASE_DIR, 'train')
validation_dir = os.path.join(BASE_DIR, 'validation')
test_dir = os.path.join(BASE_DIR, 'test')
os.mkdir(train_dir)
os.mkdir(validation_dir)
os.mkdir(test_dir)
# 在每个子集目录下创建猫和狗的子目录
train_cats_dir = os.path.join(train_dir, 'cats')
os.mkdir(train_cats_dir)
train_dogs_dir = os.path.join(train_dir, 'dogs')
os.mkdir(train_dogs_dir)
validation_cats_dir = os.path.join(validation_dir, 'cats')
os.mkdir(validation_cats_dir)
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
os.mkdir(validation_dogs_dir)
test_cats_dir = os.path.join(test_dir, 'cats')
os.mkdir(test_cats_dir)
test_dogs_dir = os.path.join(test_dir, 'dogs')
os.mkdir(test_dogs_dir)
- 功能:
- 创建嵌套目录结构,例如
small/train/cats
和small/train/dogs
。
- 创建嵌套目录结构,例如
- 注意:
- 如果
BASE_DIR
已存在,程序会直接退出以避免覆盖。 - 目录层级需与后续代码中的路径操作严格匹配。
- 如果
3. 复制猫的图片到各子集
# 复制猫图片到训练集
fnames = ['cat.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES)]
for fname in fnames:
src = os.path.join(ORIGINAL_DATASET_DIR, fname)
dst = os.path.join(train_cats_dir, fname)
shutil.copyfile(src, dst)
# 复制猫图片到验证集
fnames = ['cat.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES, NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES)]
for fname in fnames:
src = os.path.join(ORIGINAL_DATASET_DIR, fname)
dst = os.path.join(validation_cats_dir, fname)
shutil.copyfile(src, dst)
# 复制猫图片到测试集
fnames = ['cat.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES,
NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES + NUM_TEST_EXAMPLES)]
for fname in fnames:
src = os.path.join(ORIGINAL_DATASET_DIR, fname)
dst = os.path.join(test_cats_dir, fname)
shutil.copyfile(src, dst)
- 功能:
- 将原始数据集中猫的图片按索引范围复制到训练集、验证集和测试集目录。
- 索引逻辑:
- 训练集:
0~999
- 验证集:
1000~1499
- 测试集:
1500~1999
- 训练集:
- 注意:
- 假设原始数据集中猫的图片文件名是连续且按顺序命名的。
- 若原始数据索引不连续或文件名格式不符,会导致复制失败。
4. 复制狗的图片到各子集
# 复制狗图片到训练集
fnames = ['dog.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES)]
for fname in fnames:
src = os.path.join(ORIGINAL_DATASET_DIR, fname)
dst = os.path.join(train_dogs_dir, fname)
shutil.copyfile(src, dst)
# 复制狗图片到验证集
fnames = ['dog.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES, NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES)]
for fname in fnames:
src = os.path.join(ORIGINAL_DATASET_DIR, fname)
dst = os.path.join(validation_dogs_dir, fname)
shutil.copyfile(src, dst)
# 复制狗图片到测试集
fnames = ['dog.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES,
NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES + NUM_TEST_EXAMPLES)]
for fname in fnames:
src = os.path.join(ORIGINAL_DATASET_DIR, fname)
dst = os.path.join(test_dogs_dir, fname)
shutil.copyfile(src, dst)
- 逻辑与猫的处理完全一致,仅将
cat
替换为dog
。
5. 验证并打印各子集样本数
print('训练集中猫的图片总数:', len(os.listdir(train_cats_dir)))
print('训练集中狗的图片总数:', len(os.listdir(train_dogs_dir)))
print('验证集中猫的图片总数:', len(os.listdir(validation_cats_dir)))
print('验证集中狗的图片总数:', len(os.listdir(validation_dogs_dir)))
print('测试集中猫的图片总数:', len(os.listdir(test_cats_dir)))
print('测试集中狗的图片总数:', len(os.listdir(test_dogs_dir)))
- 功能:输出各子集的样本数量,验证复制操作是否成功。
- 期望输出:
训练集中猫的图片总数:1000 训练集中狗的图片总数:1000 验证集中猫的图片总数:500 验证集中狗的图片总数:500 测试集中猫的图片总数:500 测试集中狗的图片总数:500
注意事项
-
原始数据集结构:
- 原始目录
original_train
必须包含按cat.{i}.jpg
和dog.{i}.jpg
命名的文件,且索引从0开始连续。 - 若文件名格式不符(如包含其他字符或索引不连续),代码会报错。
- 原始目录
-
目标目录冲突:
- 如果
BASE_DIR
已存在,程序会直接退出。运行前需手动删除旧目录或修改BASE_DIR
路径。
- 如果
-
样本数量限制:
- 确保
NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES + NUM_TEST_EXAMPLES
不超过原始数据集的总样本数(例如猫狗各至少 2000 张)。
- 确保
-
跨平台兼容性:
- 路径分隔符在 Windows 和 Linux/macOS 中不同,建议使用
os.path.join
(代码已正确处理)。
- 路径分隔符在 Windows 和 Linux/macOS 中不同,建议使用
-
文件权限:
- 需确保程序有权限在目标路径创建目录和写入文件。
-
扩展性:
- 当前代码仅支持按固定索引划分数据集,若需随机划分,需引入随机抽样逻辑(例如使用
random.sample
)。
- 当前代码仅支持按固定索引划分数据集,若需随机划分,需引入随机抽样逻辑(例如使用
总结
此代码实现了以下功能:
- 创建标准化的目录结构。
- 按固定索引范围划分猫狗数据集。
- 复制文件到对应目录。
- 验证划分结果。
适用于需要快速构建小型数据集的场景,但需严格确保原始数据集的命名和结构符合要求。