进行训练之前要先进行标注及数据增强,标注工具写在另一篇,首先讲数据增强。
数据增强
进行简单的色彩变换和位置变换,代码如下:
from PIL import Image, ImageEnhance
import numpy as np
import os
import glob
import json
import torch
import torchvision.transforms as transforms
from TrainingTools.ImageResizeTool import input_path, output_path
def color(flag, input_directory, output_directory):
    """Write a randomly color-jittered copy of every annotated image in a folder.

    For each ``.jpg`` / ``.png`` file in ``input_directory`` the brightness,
    contrast and saturation are scaled by independent random factors. The
    result is saved as ``<stem>-<flag>.jpg`` in ``output_directory`` together
    with a copy of the sibling ``<stem>.json`` annotation whose
    ``"imagepath"`` entry is rewritten to the new image name.

    Args:
        flag: Suffix appended to every generated file name (e.g. ``"col1"``).
        input_directory: Folder holding the images and their ``.json`` files.
        output_directory: Destination folder; created if it does not exist.

    Raises:
        FileNotFoundError: If an image has no sibling ``.json`` annotation.
    """
    os.makedirs(output_directory, exist_ok=True)

    for filename in os.listdir(input_directory):
        if not filename.endswith((".jpg", ".png")):
            continue

        stem = os.path.splitext(filename)[0]
        new_name = stem + "-" + flag + ".jpg"
        image_path = os.path.join(input_directory, filename)

        # The context manager releases the source file handle as soon as the
        # enhanced copy (a new in-memory image) has been produced.
        with Image.open(image_path) as original_image:
            # Enhancement strengths; tune the ranges as needed. The draw order
            # (brightness, contrast, saturation) is kept so seeded runs
            # reproduce earlier outputs.
            brightness_factor = np.random.uniform(0.6, 1.8)
            contrast_factor = np.random.uniform(0.7, 1.7)
            saturation_factor = np.random.uniform(0.7, 1.7)

            # The output is always JPEG, which cannot store alpha or palette
            # images (typical for PNG input), so those are converted to RGB.
            if original_image.mode in ("L", "RGB", "CMYK"):
                enhanced_image = original_image
            else:
                enhanced_image = original_image.convert("RGB")

            enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(brightness_factor)
            enhanced_image = ImageEnhance.Contrast(enhanced_image).enhance(contrast_factor)
            enhanced_image = ImageEnhance.Color(enhanced_image).enhance(saturation_factor)

        # Copy the annotation, pointing it at the augmented image.
        json_file = os.path.join(input_directory, stem + ".json")
        with open(json_file, 'r') as fj:
            data = json.load(fj)
        data["imagepath"] = new_name

        saved_json = os.path.join(output_directory, stem + "-" + flag + ".json")
        with open(saved_json, 'w') as fj:
            json.dump(data, fj, ensure_ascii=False, indent=2)

        # Save the enhanced image next to its annotation.
        output_path = os.path.join(output_directory, new_name)
        enhanced_image.save(output_path)
        print(output_path)

    print("数据增强完成。")
def position(flag, data_folder, output_folder):
    """Create a randomly center-cropped, re-scaled copy of each annotated image.

    For every ``*.json`` file in ``data_folder`` the sibling ``.jpg`` image is
    center-cropped to a random height between 50% and 100% of the original
    (keeping the aspect ratio) and resized back to the original size. The
    first entry of ``data["keypoints"]`` is treated as a flat list of
    ``x, y, v`` triples and is shifted and scaled to match. A crop is only
    accepted when every keypoint stays inside it; after 20 rejected crops the
    file is skipped.

    NOTE(review): only ``data["keypoints"][0]`` is transformed and written
    back, and triples with ``v == 0`` are bounds-checked like any other --
    confirm this matches the annotation format.

    Args:
        flag: Suffix appended to every generated file name (e.g. ``"posi"``).
        data_folder: Folder holding the ``.json`` files and ``.jpg`` images.
        output_folder: Destination folder; created if it does not exist.
    """
    max_attempts = 20
    os.makedirs(output_folder, exist_ok=True)

    for json_file_path in glob.glob(os.path.join(data_folder, '*.json')):
        # Annotation and image are loaded once; only the crop is re-drawn
        # on each retry.
        with open(json_file_path, 'r') as json_file:
            data = json.load(json_file)
        source_keypoints = data["keypoints"][0]

        stem = os.path.splitext(os.path.basename(json_file_path))[0]
        image_path = os.path.splitext(json_file_path)[0] + ".jpg"

        with Image.open(image_path) as image:
            image_width, image_height = image.size
            aspect_ratio = image_width / image_height

            keypoints = None
            for attempt in range(1, max_attempts + 1):
                # max(1, ...) keeps the crop non-empty for tiny images.
                crop_height = max(1, int(torch.randint(
                    int(image_height * 0.5), image_height, (1,)).item()))
                crop_width = max(1, int(crop_height * aspect_ratio))

                # Integer offsets are shared by the pixel crop and the
                # keypoint shift so labels cannot drift by half a pixel.
                left = (image_width - crop_width) // 2
                top = (image_height - crop_height) // 2

                shifted = list(source_keypoints)
                for i in range(0, len(shifted), 3):
                    shifted[i] -= left
                    shifted[i + 1] -= top

                # Accept the crop only if no keypoint falls outside it.
                if all(0 <= shifted[i] < crop_width and 0 <= shifted[i + 1] < crop_height
                       for i in range(0, len(shifted), 3)):
                    keypoints = shifted
                    break
                print(attempt)

            if keypoints is None:
                print(f"Skipped {json_file_path}: no valid crop after {max_attempts} attempts.")
                continue

            cropped_image = image.crop((left, top, left + crop_width, top + crop_height))
            resized_image = cropped_image.resize((image_width, image_height))

        # crop_width is truncated to an int, so the horizontal and vertical
        # scale factors are not exactly equal; scale each axis by its own.
        x_ratio = image_width / crop_width
        y_ratio = image_height / crop_height
        for i in range(0, len(keypoints), 3):
            keypoints[i] *= x_ratio
            keypoints[i + 1] *= y_ratio

        # Update the annotation to describe the augmented image.
        data["keypoints"] = [keypoints]
        data["imagepath"] = os.path.splitext(data["imagepath"])[0] + "-" + flag + ".jpg"

        # Save the augmented image.
        transformed_image_path = os.path.join(output_folder, stem + "-" + flag + ".jpg")
        resized_image.save(transformed_image_path)

        # Save the matching annotation.
        transformed_json_file_path = os.path.join(output_folder, stem + "-" + flag + ".json")
        with open(transformed_json_file_path, 'w') as transformed_json_file:
            json.dump(data, transformed_json_file, indent=2)
        print(transformed_image_path)

    print("Data augmentation completed.")
# Tag appended to every generated file name: "col1" / "col2" for color runs.
flag = "col1"
# Folder with the annotated originals, and the folder receiving augmented copies.
input_path, output_path = ".../mmpose/data/coco/ori", ".../mmpose/data/coco/col1"
color(flag, input_path, output_path)
# For the positional pass, set flag = "posi" and call this instead:
# position(flag, input_path, output_path)
一般会进行一次