在PyTorch编程技术与深度学习中的`dogs_cats_split.py`:实现将原始的猫狗数据集划分为小的训练集、验证集和测试集保存到对应目录

文章目录

      • 1. 模块导入与全局变量设置
      • 2. 创建目标目录结构
      • 3. 复制猫的图片到各子集
      • 4. 复制狗的图片到各子集
      • 5. 验证并打印各子集样本数
      • **注意事项**
      • 总结

1. 模块导入与全局变量设置

# -*- coding: utf-8 -*-
"""
实现将原始的猫狗数据集划分为小的训练集、验证集和测试集
保存到对应目录

PyTorch Programming & Deep Learning
@author: Mike Yuan, Copyright 2020~2021
"""

import os
import shutil
from pathlib import Path`
# 训练集猫或狗的样本数
NUM_TRAIN_EXAMPLES = 1000
# 验证集猫或狗的样本数
NUM_VALID_EXAMPLES = 500
# 测试集猫或狗的样本数
NUM_TEST_EXAMPLES = 500

# 原始的猫狗数据集路径
ORIGINAL_DATASET_DIR = '../datasets/kaggledogvscat/original_train'

# 存储较小数据集的目录
BASE_DIR = '../datasets/kaggledogvscat/small'
  • 功能
    • 导入文件操作和路径处理模块。
    • 定义各子集(训练、验证、测试)的样本数量和路径常量。
  • 注意
    • ORIGINAL_DATASET_DIR 是原始数据集的根目录,假设其中包含 cat.0.jpgcat.12499.jpgdog.0.jpgdog.12499.jpg 的文件。
    • NUM_*_EXAMPLES 需确保总样本数不超过原始数据集的总量(例如猫狗各至少 2000 张)。

2. 创建目标目录结构

# 检查目标目录是否存在,若存在则报错退出
if Path(BASE_DIR).exists():
    print("错误:目标目录已存在!")
    os._exit(-1)
os.mkdir(BASE_DIR)

# 创建训练集、验证集、测试集目录
train_dir = os.path.join(BASE_DIR, 'train')
validation_dir = os.path.join(BASE_DIR, 'validation')
test_dir = os.path.join(BASE_DIR, 'test')
os.mkdir(train_dir)
os.mkdir(validation_dir)
os.mkdir(test_dir)

# 在每个子集目录下创建猫和狗的子目录
train_cats_dir = os.path.join(train_dir, 'cats')
os.mkdir(train_cats_dir)
train_dogs_dir = os.path.join(train_dir, 'dogs')
os.mkdir(train_dogs_dir)

validation_cats_dir = os.path.join(validation_dir, 'cats')
os.mkdir(validation_cats_dir)
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
os.mkdir(validation_dogs_dir)

test_cats_dir = os.path.join(test_dir, 'cats')
os.mkdir(test_cats_dir)
test_dogs_dir = os.path.join(test_dir, 'dogs')
os.mkdir(test_dogs_dir)
  • 功能
    • 创建嵌套目录结构,例如 small/train/catssmall/train/dogs
  • 注意
    • 如果 BASE_DIR 已存在,程序会直接退出以避免覆盖。
    • 目录层级需与后续代码中的路径操作严格匹配。

3. 复制猫的图片到各子集

# 复制猫图片到训练集
fnames = ['cat.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES)]
for fname in fnames:
    src = os.path.join(ORIGINAL_DATASET_DIR, fname)
    dst = os.path.join(train_cats_dir, fname)
    shutil.copyfile(src, dst)

# 复制猫图片到验证集
fnames = ['cat.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES, NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES)]
for fname in fnames:
    src = os.path.join(ORIGINAL_DATASET_DIR, fname)
    dst = os.path.join(validation_cats_dir, fname)
    shutil.copyfile(src, dst)

# 复制猫图片到测试集
fnames = ['cat.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES, 
                                             NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES + NUM_TEST_EXAMPLES)]
for fname in fnames:
    src = os.path.join(ORIGINAL_DATASET_DIR, fname)
    dst = os.path.join(test_cats_dir, fname)
    shutil.copyfile(src, dst)
  • 功能
    • 将原始数据集中猫的图片按索引范围复制到训练集、验证集和测试集目录。
    • 索引逻辑
      • 训练集:0~999
      • 验证集:1000~1499
      • 测试集:1500~1999
  • 注意
    • 假设原始数据集中猫的图片文件名是连续且按顺序命名的。
    • 若原始数据索引不连续或文件名格式不符,会导致复制失败。

4. 复制狗的图片到各子集

# 复制狗图片到训练集
fnames = ['dog.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES)]
for fname in fnames:
    src = os.path.join(ORIGINAL_DATASET_DIR, fname)
    dst = os.path.join(train_dogs_dir, fname)
    shutil.copyfile(src, dst)

# 复制狗图片到验证集
fnames = ['dog.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES, NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES)]
for fname in fnames:
    src = os.path.join(ORIGINAL_DATASET_DIR, fname)
    dst = os.path.join(validation_dogs_dir, fname)
    shutil.copyfile(src, dst)

# 复制狗图片到测试集
fnames = ['dog.{}.jpg'.format(i) for i in range(NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES, 
                                             NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES + NUM_TEST_EXAMPLES)]
for fname in fnames:
    src = os.path.join(ORIGINAL_DATASET_DIR, fname)
    dst = os.path.join(test_dogs_dir, fname)
    shutil.copyfile(src, dst)
  • 逻辑与猫的处理完全一致,仅将 cat 替换为 dog

5. 验证并打印各子集样本数

print('训练集中猫的图片总数:', len(os.listdir(train_cats_dir)))
print('训练集中狗的图片总数:', len(os.listdir(train_dogs_dir)))
print('验证集中猫的图片总数:', len(os.listdir(validation_cats_dir)))
print('验证集中狗的图片总数:', len(os.listdir(validation_dogs_dir)))
print('测试集中猫的图片总数:', len(os.listdir(test_cats_dir)))
print('测试集中狗的图片总数:', len(os.listdir(test_dogs_dir)))
  • 功能:输出各子集的样本数量,验证复制操作是否成功。
  • 期望输出
    训练集中猫的图片总数:1000
    训练集中狗的图片总数:1000
    验证集中猫的图片总数:500
    验证集中狗的图片总数:500
    测试集中猫的图片总数:500
    测试集中狗的图片总数:500
    

注意事项

  1. 原始数据集结构

    • 原始目录 original_train 必须包含按 cat.{i}.jpgdog.{i}.jpg 命名的文件,且索引从0开始连续。
    • 若文件名格式不符(如包含其他字符或索引不连续),代码会报错。
  2. 目标目录冲突

    • 如果 BASE_DIR 已存在,程序会直接退出。运行前需手动删除旧目录或修改 BASE_DIR 路径。
  3. 样本数量限制

    • 确保 NUM_TRAIN_EXAMPLES + NUM_VALID_EXAMPLES + NUM_TEST_EXAMPLES 不超过原始数据集的总样本数(例如猫狗各至少 2000 张)。
  4. 跨平台兼容性

    • 路径分隔符在 Windows 和 Linux/macOS 中不同,建议使用 os.path.join(代码已正确处理)。
  5. 文件权限

    • 需确保程序有权限在目标路径创建目录和写入文件。
  6. 扩展性

    • 当前代码仅支持按固定索引划分数据集,若需随机划分,需引入随机抽样逻辑(例如使用 random.sample)。

总结

此代码实现了以下功能:

  1. 创建标准化的目录结构。
  2. 按固定索引范围划分猫狗数据集。
  3. 复制文件到对应目录。
  4. 验证划分结果。

适用于需要快速构建小型数据集的场景,但需严格确保原始数据集的命名和结构符合要求。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值