对于数据集,我这里的目标是提取并裁剪出人体图片,同时对相应的标注信息进行处理。
对于标准框体大小,我们这里对各边进行了20%的放大,并对坐标做了相应的变换,写成了Python脚本的形式,
请看代码,运行的话将会对Coco数据集进行自动处理,提取所有的人体图片,以及标注信息,生成相应的文件
运行请注意路径的修改~
代码比较长,看完后请点击下一节;对于如何将筛选出的数据划分为训练集和测试集,我们同样给出了具体的脚本:
下一节:https://blog.youkuaiyun.com/weixin_41994751
from pycocotools.coco import COCO
import pylab
import cv2
import json
from collections import defaultdict
# Default matplotlib figure size (width, height).
pylab.rcParams['figure.figsize'] = (8.0, 10.0)

# Output locations: the JSON file that receives the collected annotations and
# the directory that receives the cropped person images.
json_File = '/media/blacktea/DATA/MScoco/new_annotations/new_anno.json'
w_dir = '/media/blacktea/DATA/MScoco/new_img1/'

# Initialise the COCO API on the person-keypoints annotations of one split.
dataDir = '/media/blacktea/DATA/MScoco/'
dataType = 'train2017'
annFile = '{}/annotations/person_keypoints_{}.json'.format(dataDir, dataType)
coco_kps = COCO(annFile)

# Print the category names, then the supercategory names, found in the file.
cats = coco_kps.loadCats(coco_kps.getCatIds())
nms = [cat['name'] for cat in cats]
print('COCO categories: \n{}\n'.format(' '.join(nms)))
nms = {cat['supercategory'] for cat in cats}
print('COCO supercategories: \n{}'.format(' '.join(nms)))

# Ids of every image that contains at least one 'person' annotation.
catIds = coco_kps.getCatIds(catNms=['person'])
imgIds = coco_kps.getImgIds(catIds=catIds)
print('there are %d images containing human' % len(imgIds))

# Accumulators filled by getBndboxKeypointsGT() and dumped to json_File:
# one entry per saved crop in each list.
dnfboxs = []
imagenames = []
part = []
new_dict = defaultdict(str)
# Number of (x, y, visibility) keypoint triplets read from each annotation.
_NUM_KEYPOINTS = 17
# Fraction of the box width/height added as margin on each expandable side.
_MARGIN_RATIO = 0.2
# Crops narrower or shorter than these sizes (in pixels) are discarded.
_MIN_CROP_WIDTH = 100
_MIN_CROP_HEIGHT = 130


def _is_frontal(keypoints):
    """Return True when the flat keypoint list passes the frontal-view filter.

    The list is laid out as consecutive (x, y, visibility) triplets in the
    order nose, left_eye, right_eye, left_ear, right_ear, left_shoulder,
    right_shoulder, ...  The filter requires the nose visibility flag to be 2,
    left_shoulder.x - right_shoulder.x > 30 and left_eye.x - right_eye.x > 5.
    """
    # NOTE(review): the visibility flags of the eyes and shoulders are not
    # checked, so an unlabeled right-side point at x == 0 can pass the
    # distance tests -- confirm whether that is intended.
    return (
        keypoints[2] == 2
        and (keypoints[15] - keypoints[18]) > 30
        and (keypoints[3] - keypoints[6]) > 5
    )


def _expand_box(bbox, img_width, img_height):
    """Grow an [x, y, w, h] box by the margin wherever it stays inside the image.

    The margin is added on the top/left only if neither the left nor the top
    edge would leave the image, and on the bottom/right only if neither the
    right nor the bottom edge would.

    Returns (x0, y0, crop_w, crop_h, pad_left, pad_top): the crop origin and
    size in image coordinates, plus the margin actually added on the left/top,
    which is the offset of the original box inside the crop.
    """
    box_x, box_y, box_w, box_h = bbox[0], bbox[1], bbox[2], bbox[3]
    margin_x = box_w * _MARGIN_RATIO
    margin_y = box_h * _MARGIN_RATIO

    top_left_out = (box_x - margin_x) < 0 or (box_y - margin_y) < 0
    bottom_right_out = ((box_x + box_w + margin_x) > img_width
                        or (box_y + box_h + margin_y) > img_height)

    pad_left, pad_top = (0, 0) if top_left_out else (margin_x, margin_y)
    pad_right, pad_bottom = (0, 0) if bottom_right_out else (margin_x, margin_y)

    # The pads are summed first so the result is bit-identical to the
    # "w + 2 * margin" form when both sides are expanded.
    crop_w = box_w + (pad_left + pad_right)
    crop_h = box_h + (pad_top + pad_bottom)
    return box_x - pad_left, box_y - pad_top, crop_w, crop_h, pad_left, pad_top


def _shift_keypoints(keypoints, x0, y0):
    """Return the flat [x0', y0', x1', y1', ...] keypoints relative to (x0, y0).

    Only the x/y of the first 17 triplets are kept; visibility flags are
    dropped.
    """
    shifted = []
    for k in range(0, 3 * _NUM_KEYPOINTS, 3):
        shifted.append(keypoints[k] - x0)
        shifted.append(keypoints[k + 1] - y0)
    return shifted


def getBndboxKeypointsGT():
    """Crop every frontal-looking person and record its box and keypoints.

    For each image id in the module-level ``imgIds`` the person annotations
    are loaded from ``coco_kps``.  An annotation is kept when it passes the
    frontal filter and its (possibly margin-expanded) crop is at least
    100 x 130 pixels.  Each kept person is written to
    ``w_dir + '<n>.jpg'`` (n counts up from 100000) and appended to the
    module-level lists:

    * ``dnfboxs``    -- [lx, ly, rx, ry]: offset of the original box inside
      the crop followed by the original box width and height;
    * ``imagenames`` -- [n], the number used in the crop file name;
    * ``part``       -- 34 values, the x/y of 17 keypoints relative to the
      crop origin.

    Finally the three lists are stored in ``new_dict`` under 'bndbox',
    'imgnames' and 'part' and dumped as JSON to ``json_File``.

    Images that cannot be read are reported and skipped instead of aborting
    the whole run.
    """
    import os  # local import: only needed to build the source image path

    # NOTE(review): an earlier comment in this code described the box as
    # "x,y x,y" (two corners), but rx/ry have always been stored as the
    # original box width/height -- confirm what the consumer expects.

    # Built from the same globals used to pick the annotation file, so that
    # changing dataType switches the image directory as well.
    image_dir = os.path.join(dataDir, dataType)

    img_num = 100000   # running number used as output file name
    saved_count = 0    # progress counter printed after every saved crop

    for img_id in imgIds:
        img = coco_kps.loadImgs(img_id)[0]
        ann_ids = coco_kps.getAnnIds(imgIds=img['id'], catIds=catIds,
                                     iscrowd=None)
        anns = coco_kps.loadAnns(ann_ids)
        img_width = img['width']
        img_height = img['height']

        image_path = os.path.join(image_dir, img['file_name'])
        src = None  # decoded lazily, at most once per image

        for ann in anns:
            bndbox = ann['bbox']
            keypoints = ann['keypoints']
            print(bndbox[0], bndbox[1], bndbox[2], bndbox[3])

            # Frontal-view filter.
            if not _is_frontal(keypoints):
                continue

            x0, y0, crop_w, crop_h, pad_left, pad_top = _expand_box(
                bndbox, img_width, img_height)

            # Discard small crops.
            if crop_w < _MIN_CROP_WIDTH or crop_h < _MIN_CROP_HEIGHT:
                continue

            if src is None:
                src = cv2.imread(image_path)
                if src is None:
                    # imread signals a missing/corrupt file by returning None.
                    print('WARNING: could not read image %s, skipping'
                          % image_path)
                    break

            crop = src[int(y0):int(crop_h + y0), int(x0):int(crop_w + x0)]
            cv2.imwrite(w_dir + str(img_num) + '.jpg', crop)

            dnfboxs.append([pad_left, pad_top, bndbox[2], bndbox[3]])
            imagenames.append([img_num])
            part.append(_shift_keypoints(keypoints, x0, y0))

            print(str(saved_count))
            saved_count = saved_count + 1
            img_num = img_num + 1

    new_dict['bndbox'] = dnfboxs
    new_dict['imgnames'] = imagenames
    new_dict['part'] = part
    with open(json_File, "w") as f:
        json.dump(new_dict, f)
# Script entry point: run the full filter/crop/export pass when this file is
# executed directly (not when it is imported).
if __name__ == "__main__":
    # The stray '"' that used to end this message has been removed.
    print('Filter images and extract keypoints data to files...')
    getBndboxKeypointsGT()