import os
import yaml
def _as_list(value):
    """Normalize one split entry from the dataset YAML into a list.

    ``None`` (missing key, or a key with no value such as ``test:``) becomes
    an empty list; a list is returned unchanged; any other scalar (typically
    a single path string) is wrapped in a one-element list.
    """
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def load_dataset_paths(yaml_path):
    """Read the train/val/test path entries from a dataset YAML file.

    Args:
        yaml_path: Path to the YAML file; it is read as UTF-8 text.

    Returns:
        A tuple ``(train_paths, val_paths, test_paths)`` where every element
        is a list. A split that is absent, or present with no value, yields
        an empty list; a single scalar value is wrapped in a one-element
        list. An empty YAML file yields three empty lists.
    """
    with open(yaml_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; treat it as "no splits"
        # instead of crashing on None.get below.
        data = yaml.safe_load(f) or {}
    return (
        _as_list(data.get("train")),
        _as_list(data.get("val")),
        _as_list(data.get("test")),
    )
if __name__ == '__main__':
    # Example: delete matching files under every train/val dataset directory
    # listed in the YAML file, printing each directory after it is processed.
    yaml_file = r"data_jiezhi_jita.yaml"
    train_paths, val_paths, test_paths = load_dataset_paths(yaml_file)

    # NOTE(review): suffix list kept exactly as in the original. 'xpng' matches
    # no ordinary .png file (possibly a deliberately disabled entry), while
    # 'jpeg' WILL delete .jpeg images, not just caches. Confirm this is
    # intended before running: os.remove is irreversible.
    suffixes = ('.cache', 'xpng', 'jpeg')

    # test_paths is intentionally not cleaned, matching the original script.
    for header, dirs in (("Train paths:", train_paths), ("\nVal paths:", val_paths)):
        print(header)
        for dir_a in dirs:
            # os.walk raises on None and silently yields nothing for a missing
            # directory or a plain file (e.g. a .txt image list); report both.
            if dir_a is None or not os.path.isdir(dir_a):
                print(" ", dir_a, "(not a directory, skipped)")
                continue
            for root, _subdirs, files in os.walk(dir_a):
                for name in files:
                    if name.endswith(suffixes):
                        os.remove(os.path.join(root, name))
            print(" ", dir_a)