注意⚠️:请根据实际情况修改代码中的路径
一、根据类别 ID(category_id)筛选出“人”这一类别的标注,并生成新的 JSON 文件
import json
from tqdm import tqdm
def filter_coco_data(input_json, output_json, target_category_id=1):
    """Filter a COCO annotation file down to a single category.

    Keeps every annotation whose ``category_id`` equals
    ``target_category_id`` and every image referenced by at least one kept
    annotation. ``info`` and ``licenses`` are copied when present (empty
    defaults otherwise) and the ``categories`` list is copied unchanged.
    The result is written to ``output_json`` and a before/after summary is
    printed.

    Args:
        input_json: Path of the source COCO JSON file.
        output_json: Path of the JSON file to write; missing parent
            directories are created.
        target_category_id: Category id to keep (default 1).

    Raises:
        KeyError: If the input lacks ``annotations``, ``images`` or
            ``categories``.
    """
    import os  # local import: only needed here to create the output directory

    # The progress bar is cosmetic; fall back to plain iteration without tqdm.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable, **_kwargs):
            return iterable

    # JSON is UTF-8; do not rely on the platform default encoding.
    with open(input_json, 'r', encoding='utf-8') as f:
        data = json.load(f)

    valid_image_ids = set()
    filtered_annotations = []
    for ann in progress(data['annotations'], desc="过滤标注"):
        if ann['category_id'] == target_category_id:
            filtered_annotations.append(ann)
            valid_image_ids.add(ann['image_id'])

    # Keep only images that still have at least one annotation.
    filtered_images = [
        img for img in data['images']
        if img['id'] in valid_image_ids
    ]

    new_data = {
        "info": data.get('info', {}),
        "licenses": data.get('licenses', []),
        "categories": data['categories'],
        "images": filtered_images,
        "annotations": filtered_annotations,
    }

    # Create the target directory so the write cannot fail on a missing path.
    output_dir = os.path.dirname(output_json)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(new_data, f, indent=2)

    print(f"原始数据:{len(data['images'])} 图片, {len(data['annotations'])} 标注")
    print(f"过滤后:{len(filtered_images)} 图片, {len(filtered_annotations)} 标注")
# Script entry point: filter the train annotations down to category id 1.
if __name__ == "__main__":
    filter_coco_data(
        "D:\\HRNet\\coco2017\\annotations\\person_keypoints_train2017.json",
        "D:\\HRNet\\coco2025\\annotations\\train.json",
        1,
    )
二、根据新json提取对应的图片
import json
import os
import shutil
def extract_coco_images(json_path, src_img_dir, dst_dir):
    """Copy the images listed in a COCO JSON file into another directory.

    For every entry in the ``images`` list, the file named by its
    ``file_name`` is copied from ``src_img_dir`` to ``dst_dir``. Entries
    whose source file does not exist are reported and skipped. The final
    message reports the number of files actually copied.

    Args:
        json_path: Path of the COCO-format JSON file.
        src_img_dir: Directory holding the original images.
        dst_dir: Destination directory; created if it does not exist.

    Raises:
        KeyError: If the JSON lacks ``images`` or an entry lacks
            ``file_name``.
    """
    os.makedirs(dst_dir, exist_ok=True)

    # JSON is UTF-8; do not rely on the platform default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    copied = 0
    for img_info in data['images']:
        file_name = img_info['file_name']
        src_path = os.path.join(src_img_dir, file_name)
        dst_path = os.path.join(dst_dir, file_name)
        if not os.path.exists(src_path):
            print(f"警告:文件 {src_path} 不存在,已跳过")
            continue
        # file_name may include a sub-directory; make sure it exists in dst.
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        shutil.copy(src_path, dst_path)
        copied += 1

    # Report what was really copied, not the number of entries in the JSON.
    print(f"完成!共复制 {copied} 张图片到 {dst_dir}")
# Script entry point: copy the images referenced by the val2025 annotation file.
if __name__ == "__main__":
    extract_coco_images(
        "D:\\HRNet\\dataset\\coco2025\\annotations\\val2025.json",
        "D:\\HRNet\\coco2025\\val2025",
        "D:\\HRNet\\dataset\\coco2025\\val2025",
    )
三、从新 JSON 中取前 1000 张图片及其对应的标注,生成一个小 JSON 作为小数据集;之后使用第二节中的代码提取出这 1000 张图片即可
import json
def filter_coco_dataset(input_json_path, output_json_path, num_images=1000):
    """Write a smaller COCO JSON containing only the first ``num_images`` images.

    The first ``num_images`` entries of ``images`` are kept in their
    original order, together with every annotation whose ``image_id``
    belongs to one of them. ``info``, ``licenses`` and ``categories`` are
    copied when present (empty defaults otherwise).

    Args:
        input_json_path: Path of the source COCO JSON file.
        output_json_path: Path of the JSON file to write; missing parent
            directories are created.
        num_images: Number of leading images to keep (default 1000).

    Raises:
        ValueError: If ``num_images`` is negative.
        KeyError: If the input lacks ``images`` or ``annotations``.
    """
    import os  # local import: only needed here to create the output directory

    # A negative count would silently slice from the end of the list.
    if num_images < 0:
        raise ValueError(f"num_images must be non-negative, got {num_images}")

    # JSON is UTF-8; do not rely on the platform default encoding.
    with open(input_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    selected_images = data['images'][:num_images]
    selected_image_ids = {img['id'] for img in selected_images}
    selected_annotations = [
        ann for ann in data['annotations']
        if ann['image_id'] in selected_image_ids
    ]

    filtered_data = {
        "info": data.get("info", {}),
        "licenses": data.get("licenses", []),
        "categories": data.get("categories", []),
        "images": selected_images,
        "annotations": selected_annotations,
    }

    # Create the target directory so the write cannot fail on a missing path.
    output_dir = os.path.dirname(output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(filtered_data, f, indent=2)
# Script entry point: build a 1000-image subset of the val2025 annotations.
# Guarded so that importing this file does not rewrite files on disk.
if __name__ == "__main__":
    filter_coco_dataset(
        input_json_path='D:\\HRNet\\coco2025\\annotations\\val2025.json',
        output_json_path='D:\\HRNet\\dataset\\coco2025\\annotations\\val2025.json',
        num_images=1000,
    )