import json
import random
def balance_empty_targets(coco_json_file, output_file, *, ratio=1.0, seed=None):
    """Randomly drop unannotated images from a COCO-style JSON file.

    An image counts as "with targets" when at least one entry of
    ``data['annotations']`` has its ``image_id``. If the images without
    targets number more than ``int(len(images_with_targets) * ratio)``, a
    random subset of them is removed so that exactly that many remain.
    Annotated images are written first, followed by the surviving
    unannotated images in their original relative order. Every other
    top-level key, including ``annotations``, is written back unchanged.
    Removed images have no annotations, so no annotation is orphaned.

    Args:
        coco_json_file: Path of the input JSON. It must contain ``images``
            (each with an ``id``) and ``annotations`` (each with an
            ``image_id``).
        output_file: Path the balanced JSON is written to (overwritten).
        ratio: Maximum number of unannotated images kept per annotated
            image. The default ``1.0`` is the original 1:1 balance.
        seed: Optional seed for a private ``random.Random``. When ``None``
            the module-level ``random`` generator is used, as before.

    Raises:
        ValueError: If ``ratio`` is negative.
        KeyError: If a required key (``images``, ``annotations``, ``id``,
            ``image_id``) is missing.
    """
    if ratio < 0:
        raise ValueError(f'ratio must be non-negative, got {ratio!r}')

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(coco_json_file, 'r', encoding='utf-8') as file:
        data = json.load(file)

    annotated_ids = {anno['image_id'] for anno in data['annotations']}
    images_with_targets = [img for img in data['images'] if img['id'] in annotated_ids]
    images_without_targets = [img for img in data['images'] if img['id'] not in annotated_ids]

    num_to_keep = int(len(images_with_targets) * ratio)
    num_to_remove = len(images_without_targets) - num_to_keep
    if num_to_remove > 0:
        rng = random if seed is None else random.Random(seed)
        # Sample positions rather than the dicts themselves. Filtering by
        # dict equality ("img not in sampled_list") is O(n*k) and would also
        # drop every other image dict that happens to compare equal.
        removed_positions = set(rng.sample(range(len(images_without_targets)), num_to_remove))
        images_without_targets = [
            img
            for pos, img in enumerate(images_without_targets)
            if pos not in removed_positions
        ]

    data['images'] = images_with_targets + images_without_targets

    with open(output_file, 'w', encoding='utf-8') as file:
        json.dump(data, file)
if __name__ == '__main__':
    import argparse

    # Both paths are optional and default to the previously hard-coded
    # values, so running the script with no arguments behaves as before.
    parser = argparse.ArgumentParser(
        description='Randomly drop unannotated images from a COCO JSON file '
                    'so they do not outnumber annotated images.')
    parser.add_argument('coco_json_file', nargs='?', default='./train.json',
                        help='input COCO JSON file (default: %(default)s)')
    parser.add_argument('output_file', nargs='?', default='./balanced_coco.json',
                        help='where to write the balanced JSON (default: %(default)s)')
    args = parser.parse_args()
    balance_empty_targets(args.coco_json_file, args.output_file)