mmpose框架进行人体姿态识别模型HRNet训练

进行训练之前要先进行标注及数据增强,标注工具写在另一篇,首先讲数据增强。

数据增强

进行简单的色彩变换和位置变换,代码如下:

from PIL import Image, ImageEnhance
import numpy as np
import os
import glob
import json
import torch
import torchvision.transforms as transforms

from TrainingTools.ImageResizeTool import input_path, output_path


def color(flag, input_directory, output_directory):
    # 创建输出目录
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # 遍历输入目录中的图片文件
    for filename in os.listdir(input_directory):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            # 读取图片
            image_path = os.path.join(input_directory, filename)
            original_image = Image.open(image_path)

            # 色彩增强参数,可以根据需求进行调整
            brightness_factor = np.random.uniform(0.6, 1.8)  # 亮度
            contrast_factor = np.random.uniform(0.7, 1.7)    # 对比度
            saturation_factor = np.random.uniform(0.7, 1.7)  # 饱和度

            # 对图像进行色彩增强
            enhancer = ImageEnhance.Brightness(original_image)
            enhanced_image = enhancer.enhance(brightness_factor)

            enhancer = ImageEnhance.Contrast(enhanced_image)
            enhanced_image = enhancer.enhance(contrast_factor)

            enhancer = ImageEnhance.Color(enhanced_image)
            enhanced_image = enhancer.enhance(saturation_factor)

            new_name = os.path.splitext(filename)[0] + "-" + flag + ".jpg"

            json_file = os.path.join(input_directory, os.path.splitext(filename)[0] +".json")
            with open(json_file, 'r') as fj:
                data = json.load(fj)
                data["imagepath"] = new_name
            saved_json = os.path.join(output_directory, os.path.splitext(new_name)[0] + ".json")
            json.dump(data, open(saved_json, 'w'), ensure_ascii=False, indent=2)

            # 保存增强后的图片
            output_path = os.path.join(output_directory, new_name)
            enhanced_image.save(output_path)
            print(output_path)

    print("数据增强完成。")

def position(flag, data_folder, output_folder):
    os.makedirs(output_folder, exist_ok=True)

    # 获取文件夹内所有JSON文件路径
    json_files = glob.glob(os.path.join(data_folder, '*.json'))

    # 定义图像变换
    transform = transforms.RandomAffine(degrees=30, translate=(0.2, 0.2), scale=(0.8, 1.2), shear=10)

    for json_file_path in json_files:
        flage = 0
        while True:
            # 读取JSON文件
            with open(json_file_path, 'r') as json_file:
                data = json.load(json_file)

            # 图像路径
            image_path = os.path.splitext(json_file_path)[0] + ".jpg"

            # 读取图像并获取图像宽高
            image = Image.open(image_path)
            image_width, image_height = image.size

            aspect_ratio = image_width / image_height
            crop_height = int(torch.randint(int(image_height * 0.5), image_height, (1,)).item())
            crop_width = int(crop_height * aspect_ratio)

            transform = transforms.CenterCrop((crop_height, crop_width))
            transformed_image = transform(image)

            # 获取变换后图像的宽高
            transformed_image_width, transformed_image_height = transformed_image.size

            crop_zero_point = ((image_width - transformed_image_width) / 2, (image_height - transformed_image_height) / 2)

            keypoints = data["keypoints"][0]
            for i in range(0, len(keypoints), 3):
                x, y, v = keypoints[i], keypoints[i + 1], keypoints[i + 2]
                keypoints[i], keypoints[i + 1], keypoints[i + 2] = [
                    x - crop_zero_point[0], y - crop_zero_point[1], v]

            # 判断是否有关键点超出图像边缘
            if all(0 <= keypoints[i] < transformed_image_width and 0 <= keypoints[i+1] < transformed_image_height for i in range(0, len(keypoints), 3)):
                break  # 如果所有关键点在图像内,则退出循环

            flage += 1
            print(flage)
            if flage == 20:
                break

        if flage == 20:
            continue

        resized_image = transformed_image.resize((image_width, image_height))
        resize_ration = image_width / transformed_image_width
        # 将关键点坐标转换回绝对坐标
        for i in range(0, len(keypoints), 3):
            x, y, v = [keypoints[i], keypoints[i + 1], keypoints[i + 2]]
            keypoints[i], keypoints[i+1], keypoints[i+2] = [
                    x * resize_ration, y * resize_ration, v]

        # 更新数据
        data["keypoints"] = [keypoints]
        data["imagepath"] = os.path.splitext(data["imagepath"])[0] + "-" + flag + ".jpg"

        # 保存增强后的图像
        transformed_image_path = os.path.join(output_folder, os.path.splitext(os.path.basename(image_path))[0] + "-" + flag + ".jpg")
        resized_image.save(transformed_image_path)

        # 更新 JSON 文件
        transformed_json_file_path = os.path.join(output_folder, os.path.splitext(os.path.basename(json_file_path))[0] + "-" + flag + ".json")
        with open(transformed_json_file_path, 'w') as transformed_json_file:
            json.dump(data, transformed_json_file, indent=2)

        print(transformed_image_path)

    print("Data augmentation completed.")

input_path = ".../mmpose/data/coco/ori"
output_path = ".../mmpose/data/coco/col1"
flag = "col1"
color(flag, input_path, output_path)         # flag = col1 col2
# position(flag, input_path, output_path)    # flag = posi

一般会进行一次

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值