Code source:
https://github.com/yuanyuanli85/Stacked_Hourglass_Network_Keras/blob/master/src/data_gen/data_process.py

Code
import numpy as np
import scipy.misc
def get_transform(center, scale, res, rot=0):
    """
    General image processing functions
    """
    # Generate transformation matrix
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_mat = np.zeros((3, 3))
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0, 2] = -res[1] / 2
        t_mat[1, 2] = -res[0] / 2
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
    return t
def transform(pt, center, scale, res, invert=0, rot=0):
    # Transform pixel location to different reference
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2].astype(int) + 1
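
# Sanity-check sketch (hypothetical numbers): with center = (320, 240) and
# scale = 1.28 (so h = 200 * 1.28 = 256), the person center should map to the
# middle of a (64, 64) output. In the 1-based coordinates this code uses:
#   transform([321, 241], center=[320, 240], scale=1.28, res=(64, 64))
#   -> array([33, 33])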
def crop(img, center, scale, res, rot=0):
    # Preprocessing for efficient cropping
    ht, wd = img.shape[0], img.shape[1]
    sf = scale * 200.0 / res[0]
    if sf < 2:
        sf = 1
    else:
        new_size = int(np.math.floor(max(ht, wd) / sf))
        new_ht = int(np.math.floor(ht / sf))
        new_wd = int(np.math.floor(wd / sf))
        img = scipy.misc.imresize(img, [new_ht, new_wd])
        center = center * 1.0 / sf
        scale = scale / sf

    # Upper left point
    ul = np.array(transform([0, 0], center, scale, res, invert=1))
    # Bottom right point
    br = np.array(transform(res, center, scale, res, invert=1))

    # Padding so that when rotated proper amount of context is included
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if not rot == 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]

    if not rot == 0:
        # Remove padding
        new_img = scipy.misc.imrotate(new_img, rot)
        new_img = new_img[pad:-pad, pad:-pad]

    new_img = scipy.misc.imresize(new_img, res)
    return new_img
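
# Example (values taken from the sample annotation quoted further below):
# crop a 256x256 patch around a person at objpos (594, 257) with scale 3.021:
#   patch = crop(img, center=np.array([594.0, 257.0]), scale=3.021,
#                res=(256, 256), rot=0)
#   patch.shape -> (256, 256, 3)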
def normalize(imgdata, color_mean):
    '''
    :param imgdata: image with values in 0 ~ 255
    :return: image scaled to 0.0 ~ 1.0, then mean-subtracted per channel
    '''
    imgdata = imgdata / 255.0
    for i in range(imgdata.shape[-1]):
        imgdata[:, :, i] -= color_mean[i]
    return imgdata
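
# Example: a uint8 pixel of 255 in a channel whose mean is 0.44 becomes
# 255 / 255.0 - 0.44 = 0.56, so the output roughly spans [-0.44, 0.56].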
def draw_labelmap(img, pt, sigma, type='Gaussian'):
    # Draw a 2D gaussian
    # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py

    # Check that any part of the gaussian is in-bounds
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
            br[0] < 0 or br[1] < 0):
        # If not, just return the image as is
        return img

    # Generate gaussian
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # The gaussian is not normalized, we want the center value to equal 1
    if type == 'Gaussian':
        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    elif type == 'Cauchy':
        g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)

    # Usable gaussian range
    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
    # Image range
    img_x = max(0, ul[0]), min(br[0], img.shape[1])
    img_y = max(0, ul[1]), min(br[1], img.shape[0])

    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return img
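
# Quick check (assumed values): on an all-zero 64x64 map, drawing at
# pt = (32, 32) with sigma = 1 writes an unnormalized Gaussian whose peak
# is exactly 1.0:
#   hm = draw_labelmap(np.zeros((64, 64)), (32, 32), sigma=1)
#   hm[32, 32] -> 1.0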
def transform_kp(joints, center, scale, res, rot):
    newjoints = np.copy(joints)
    for i in range(joints.shape[0]):
        if joints[i, 0] > 0 and joints[i, 1] > 0:
            _x = transform(newjoints[i, 0:2] + 1, center=center, scale=scale, res=res, invert=0, rot=rot)
            newjoints[i, 0:2] = _x
    return newjoints


def generate_gtmap(joints, sigma, outres):
    npart = joints.shape[0]
    gtmap = np.zeros(shape=(outres[0], outres[1], npart), dtype=float)
    for i in range(npart):
        visibility = joints[i, 2]
        if visibility > 0:
            gtmap[:, :, i] = draw_labelmap(gtmap[:, :, i], joints[i, :], sigma)
    return gtmap
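
Taken together, here is a minimal end-to-end sketch of the label pipeline (the joint values are made up; it assumes the functions above are in scope):

import numpy as np

# three fake joints in 64x64 output coordinates, as (x, y, visibility)
joints = np.array([[32.0, 32.0, 1.0],
                   [10.0, 50.0, 1.0],
                   [20.0, 20.0, 0.0]])  # visibility 0 -> channel stays empty

gtmap = generate_gtmap(joints, sigma=1, outres=(64, 64))
print(gtmap.shape)           # (64, 64, 3)
print(gtmap[32, 32, 0])      # 1.0, the unnormalized Gaussian peak
print(gtmap[:, :, 2].max())  # 0.0, the invisible joint was skipped

The generator that consumes these helpers comes from mpii_datagen.py in the same repository: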
import os
import numpy as np
from random import shuffle
import scipy.misc
import json
import data_process
import random
class MPIIDataGen(object):

    def __init__(self, jsonfile, imgpath, inres, outres, is_train):
        '''
        :param jsonfile: path to the annotation file
        :param imgpath: path to the image directory
        :param inres: input image resolution
        :param outres: final output heatmap resolution
        :param is_train: whether this generator serves train or val
        '''
        self.jsonfile = jsonfile
        self.imgpath = imgpath
        self.inres = inres
        self.outres = outres
        self.is_train = is_train
        # MPII annotates 16 keypoints per person
        self.nparts = 16
        self.anno = self._load_image_annotation()
    def _load_image_annotation(self):
        # load train or val annotation
        with open(self.jsonfile) as anno_file:
            anno = json.load(anno_file)

        # split the annotations into train and val sets by their isValidation flag
        val_anno, train_anno = [], []
        for idx, val in enumerate(anno):
            if val['isValidation'] == True:
                val_anno.append(anno[idx])
            else:
                train_anno.append(anno[idx])

        if self.is_train:
            # if this MPIIDataGen serves training, return the train annotations
            return train_anno
        else:
            return val_anno
    def get_dataset_size(self):
        return len(self.anno)

    def get_color_mean(self):
        '''
        Per-channel means of the RGB channels, used later for
        channel-wise normalization.
        '''
        mean = np.array([0.4404, 0.4440, 0.4327], dtype=float)
        return mean

    def get_annotations(self):
        return self.anno
def generator(self, batch_size, num_hgstack, sigma=1, with_meta=False, is_shuffle=False,
rot_flag=False, scale_flag=False, flip_flag=False):
'''
Input: batch_size * inres * Channel (3)
Output: batch_size * oures * nparts
'''
train_input = np.zeros(shape=(batch_size, self.inres[0], self.inres[1], 3), dtype=np.float)
gt_heatmap = np.zeros(shape=(batch_size, self.outres[0], self.outres[1], self.nparts), dtype=np.float)
meta_info = list()
if not self.is_train:
'''当进入到 val 的阶段的时候,一定不能对标签进行 shuffle,也一定不能开 rot_flag'''
assert (is_shuffle == False), 'shuffle must be off in val model'
assert (rot_flag == False), 'rot_flag must be off in val model'
while True:
'''只有当 train 的时候,可以选择是否进行 shuffle,val 的时候一定不能 shuffle,train 的时候可以选择不 shuffle'''
if is_shuffle:
shuffle(self.anno)
for i, kpanno in enumerate(self.anno):
_imageaug, _gthtmap, _meta = self.process_image(i, kpanno, sigma, rot_flag, scale_flag, flip_flag)
_index = i % batch_size
train_input[_index, :, :, :] = _imageaug
gt_heatmap[_index, :, :, :] = _gthtmap
meta_info.append(_meta)
if i % batch_size == (batch_size - 1):
out_hmaps = []
for m in range(num_hgstack):
out_hmaps.append(gt_heatmap)
if with_meta:
yield train_input, out_hmaps, meta_info
meta_info = []
else:
yield train_input, out_hmaps
    def process_image(self, sample_index, kpanno, sigma, rot_flag, scale_flag, flip_flag):
        '''
        :param sample_index:
        :param kpanno: the annotation of one image; all annotations live in self.anno
        :param sigma:
        :param rot_flag:
        :param scale_flag:
        :param flip_flag:
        :return:
        '''
        # resolve the image path from the annotation
        imagefile = kpanno['img_paths']
        # read the image
        image = scipy.misc.imread(os.path.join(self.imgpath, imagefile))

        # get center: the position of the person in the image
        # (objpos = object position), e.g. 'objpos': [594.0, 257.0]
        center = np.array(kpanno['objpos'])

        # all keypoints of this person; the trailing 1 or 0 marks visibility,
        # e.g. 'joint_self': [[620.0, 394.0, 1.0],
        #                     [616.0, 269.0, 1.0],
        #                     [573.0, 185.0, 1.0],
        #                     [647.0, 188.0, 0.0],
        #                     [661.0, 221.0, 1.0],
        #                     ...]
        joints = np.array(kpanno['joint_self'])

        # the person's scale, e.g. 'scale_provided': 3.021
        scale = kpanno['scale_provided']

        # Adjust center/scale slightly to avoid cropping limbs
        if center[0] != -1:
            center[1] = center[1] + 15 * scale
            scale = scale * 1.25

        # flip
        if flip_flag and random.choice([0, 1]):
            image, joints, center = self.flip(image, joints, center)

        # scale
        if scale_flag:
            scale = scale * np.random.uniform(0.8, 1.2)

        # rotate image
        if rot_flag and random.choice([0, 1]):
            rot = np.random.randint(-1 * 30, 30)
        else:
            rot = 0

        cropimg = data_process.crop(image, center, scale, self.inres, rot)
        cropimg = data_process.normalize(cropimg, self.get_color_mean())

        # transform keypoints
        transformedKps = data_process.transform_kp(joints, center, scale, self.outres, rot)
        gtmap = data_process.generate_gtmap(transformedKps, sigma, self.outres)

        # meta info
        metainfo = {'sample_index': sample_index, 'center': center, 'scale': scale,
                    'pts': joints, 'tpts': transformedKps, 'name': imagefile}

        return cropimg, gtmap, metainfo
    @classmethod
    def get_kp_keys(cls):
        keys = ['r_ankle', 'r_knee', 'r_hip',
                'l_hip', 'l_knee', 'l_ankle',
                'plevis', 'thorax', 'upper_neck', 'head_top',
                'r_wrist', 'r_elbow', 'r_shoulder',
                'l_shoulder', 'l_elbow', 'l_wrist']
        return keys

    def flip(self, image, joints, center):
        import cv2

        joints = np.copy(joints)

        matchedParts = (
            [0, 5],    # ankle
            [1, 4],    # knee
            [2, 3],    # hip
            [10, 15],  # wrist
            [11, 14],  # elbow
            [12, 13]   # shoulder
        )

        org_height, org_width, channels = image.shape

        # flip the image horizontally
        flipimage = cv2.flip(image, flipCode=1)

        # mirror the x coordinate of every joint
        joints[:, 0] = org_width - joints[:, 0]

        # swap the left/right joint pairs
        for i, j in matchedParts:
            temp = np.copy(joints[i, :])
            joints[i, :] = joints[j, :]
            joints[j, :] = temp

        # mirror the center as well
        flip_center = center
        flip_center[0] = org_width - center[0]

        return flipimage, joints, flip_center
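
Before wiring up real data, the train/val split logic can be sanity-checked in isolation. A small sketch (the annotation entries below are fake, made-up values):

import json
import tempfile

fake_anno = [{'isValidation': 0.0, 'img_paths': 'a.jpg'},
             {'isValidation': 1.0, 'img_paths': 'b.jpg'},
             {'isValidation': 0.0, 'img_paths': 'c.jpg'}]
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(fake_anno, f)

train_gen = MPIIDataGen(f.name, imgpath='.', inres=(256, 256), outres=(64, 64), is_train=True)
print(train_gen.get_dataset_size())  # 2, only the isValidation == 0 entries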
How the generator produces training and validation data

- Split the dataset into training and validation annotations according to the annotation file.
- Run each sample through process_image for augmentation: scaling, flipping, cropping, and rotation, as configured.
- Use crop to cut out the region around the main person in the image, then normalize the resulting image.
  - Note that the corresponding keypoint labels must undergo the same transformation.
- Generate Gaussian label maps from the transformed keypoints (each channel of gtmap is one keypoint's heatmap), then return them.
- Group the samples into batches, and duplicate each batch's gtmaps num_hgstack times: during training every stack is supervised, so each one needs its own copy of gtmaps for the loss computation (see the sketch after this list).
- Finally, yield turns the function into a generator that emits one batch of train inputs and out_hmaps labels at a time. With batch = 64, each train batch has shape (64, 256, 256, 3), while out_hmaps is num_hgstack copies of a (64, 64, 64, 16) heatmap tensor.
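
A minimal sketch of the batching and stack-duplication step described above (shapes assume batch_size = 64, outres = (64, 64), nparts = 16, num_hgstack = 2):

import numpy as np

batch_size, num_hgstack = 64, 2
gt_heatmap = np.zeros((batch_size, 64, 64, 16))

# every hourglass stack is supervised with the same ground-truth heatmaps,
# so the label is simply repeated num_hgstack times
out_hmaps = [gt_heatmap for _ in range(num_hgstack)]
print(len(out_hmaps), out_hmaps[0].shape)  # 2 (64, 64, 64, 16)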
Verifying the dimensions

- First, download the GitHub code linked at the beginning of this article.
- Download the MPII dataset and its annotations, and place them in the data folder following the directory structure documented on GitHub:
  - images holds all the MPII image data
  - the annotations sit alongside images
- Then pass the appropriate paths to MPIIDataGen:
train_gen = MPIIDataGen("../../data/mpii/mpii_annotations.json", "../../data/mpii/images/",
                        (256, 256), (64, 64), True)
gen = train_gen.generator(batch_size=8, num_hgstack=2)
train_input, out_hmaps = next(gen)
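
A quick shape check (expected values given batch_size = 8, inres = (256, 256), outres = (64, 64), num_hgstack = 2, and 16 keypoints):

print(train_input.shape)                   # (8, 256, 256, 3)
print(len(out_hmaps), out_hmaps[0].shape)  # 2 (8, 64, 64, 16)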