文章目录
前言
记录如何对变化检测数据集进行划分：将训练数据集划分为训练集和验证集。代码主要使用了 shutil.copy 函数。
import os
import json
import numpy as np
import shutil
# Dataset paths. Both are raw strings so the Windows backslashes stay literal:
# in a normal string "\V", "\S", "\B", "\D" are invalid escape sequences
# (SyntaxWarning on Python 3.12+, slated to become an error).
dataset_root = r"D:\VScode\SSCD\Bi-SRNet\DATA_SCD\train"
output_root = r"D:\VScode\SSCD\Bi-SRNet\DATA_SCD"

# Source folders inside the training set: A images, B images and labels.
images_A = os.path.join(dataset_root, "A")
images_B = os.path.join(dataset_root, "B")
label = os.path.join(dataset_root, "label")
def get_all_files_in_dir(path):
    """Return every file under *path* as a path relative to *path*, sorted.

    The tree is walked recursively with ``os.walk``. For a flat directory the
    result is simply the sorted file names, which callers can join back onto
    *path* (or onto a parallel folder) with ``os.path.join``.

    Args:
        path: Directory to scan.

    Returns:
        A sorted list of relative file paths. It is empty when *path* is
        empty or does not exist (``os.walk`` yields nothing in that case).
    """
    # Collect from every visited directory. The previous version returned the
    # loop variable ``files``, i.e. only the last directory walked, and raised
    # UnboundLocalError when nothing was walked at all.
    return sorted(
        os.path.relpath(os.path.join(root, name), path)
        for root, _dirs, files in os.walk(path)
        for name in files
    )
# Names of the A images. Each one is expected to have a B image with the same
# name in images_B, and a label in the label folder (.tif/.tiff names map to
# .png; any other name is used unchanged). Sorting makes the order
# independent of the filesystem so the seeded shuffle below is reproducible.
A_list = sorted(get_all_files_in_dir(images_A))
if not A_list:
    raise SystemExit(f"No images found in {images_A}")

# Output folders for the new split.
output_train = os.path.join(output_root, "train1")
output_val = os.path.join(output_root, "val1")

# Shuffle with a fixed seed so re-running the script yields the same split.
SPLIT_SEED = 0
np.random.RandomState(SPLIT_SEED).shuffle(A_list)

# Train / validation ratios (no test split is produced).
train_ratio, val_ratio = 0.8, 0.2

# Split sizes. The validation set takes *all* remaining images. Flooring both
# products separately would silently drop images, e.g. 7 images gives
# int(5.6) + int(1.4) = 6.
num_images = len(A_list)
num_train = int(num_images * train_ratio)
num_val = num_images - num_train

train_images = A_list[:num_train]
val_images = A_list[num_train:]

# Destination sub-folders for each split. makedirs also creates the parents,
# including output_root itself.
train_A = os.path.join(output_train, "A")
val_A = os.path.join(output_val, "A")
train_B = os.path.join(output_train, "B")
val_B = os.path.join(output_val, "B")
train_label = os.path.join(output_train, "label")
val_label = os.path.join(output_val, "label")
for folder in (train_A, val_A, train_B, val_B, train_label, val_label):
    os.makedirs(folder, exist_ok=True)


def _label_name(img):
    """Return the label file name for image *img*: a .tif/.tiff suffix becomes .png."""
    stem, ext = os.path.splitext(img)
    # Replace only the extension. str.replace("tif", "png") would also rewrite
    # "tif" inside the stem ("motif_01.tif" -> "mopng_01.png") and turn
    # ".tiff" into ".pngf".
    return stem + ".png" if ext.lower() in (".tif", ".tiff") else img


def _copy_split(images, dst_A, dst_B, dst_label):
    """Copy the A image, B image and label of every name in *images* into the dst folders."""
    for img in images:
        lbl = _label_name(img)
        shutil.copy(os.path.join(images_A, img), os.path.join(dst_A, img))
        shutil.copy(os.path.join(images_B, img), os.path.join(dst_B, img))
        shutil.copy(os.path.join(label, lbl), os.path.join(dst_label, lbl))


# Copy the files into the corresponding split folders.
_copy_split(train_images, train_A, train_B, train_label)
_copy_split(val_images, val_A, val_B, val_label)
print("数据集划分完成!")