数据准备
导入依赖库
import h5py
import scipy.io as io
import PIL.Image as Image
import numpy as np
import os
import glob
from matplotlib import pyplot as plt
from scipy.ndimage.filters import gaussian_filter
import scipy
import scipy.spatial
import json
from matplotlib import cm as CM
import torch
import os
from xml.dom.minidom import parse
import numpy as np
数据路径准备
# Dataset root; training and test images live in its Train_img / Test_img sub-directories.
root = '../../../veritas/'
train, test = (os.path.join(root, sub_dir) for sub_dir in ('Train_img', 'Test_img'))
由于我们的标注数据为 XML 格式，所以需要编写读取 XML 的函数
def readXML(path):
    """Read point annotations from an XML file.

    Every ``<point>`` element is expected to contain ``<x>`` and ``<y>``
    children whose text is a number; coordinates are truncated to integers.
    The number of points found is printed.

    Args:
        path: Path of the XML annotation file.

    Returns:
        ``np.ndarray`` of shape ``(N, 2)`` whose rows are ``[x, y]``; shape
        ``(0, 2)`` when the file contains no ``<point>`` elements, so callers
        can always index ``result[:, 0]`` / ``result[:, 1]``.
    """
    root_node = parse(path).documentElement

    def _coord(point, tag):
        # int(float(...)) also accepts sub-pixel values such as "12.5",
        # which a bare int() would reject with ValueError.
        return int(float(point.getElementsByTagName(tag)[0].firstChild.data))

    # One [x, y] pair per annotated person.
    coords = [[_coord(point, "x"), _coord(point, "y")]
              for point in root_node.getElementsByTagName("point")]
    print("len: ", len(coords))
    return np.array(coords, dtype=int).reshape(-1, 2)
最基础的密度图构建方法（所有标注点使用固定 σ=15 的高斯核）
# Fixed-sigma baseline: collect every .jpg image from the train and test sets.
path_sets = [train, test]
img_paths = []
for path in path_sets:
    img_paths.extend(glob.glob(os.path.join(path, '*.jpg')))
for idx, img_path in enumerate(img_paths):
    print(idx, len(img_paths), img_path)
    # gt holds the annotated [x, y] point coordinates.
    # NOTE(review): this previously used .replace('Infrared', 'GT_Infrared'),
    # presumably left over from another dataset layout and a no-op for paths
    # under Train_img/Test_img; it now reads from the same GT_img location as
    # the adaptive main loop. Confirm against the dataset layout.
    gt = readXML(img_path.replace('.jpg', '.xml').replace('img', 'GT_img'))
    img = plt.imread(img_path)
    height, width = img.shape[0], img.shape[1]
    k = np.zeros((height, width))
    # Truncate to integer pixels (as int() did) and keep only in-bounds points;
    # without the >= 0 checks a negative coordinate would wrap around to the
    # opposite border through numpy's negative indexing.
    pts = np.asarray(gt).reshape(-1, 2).astype(int)
    xs, ys = pts[:, 0], pts[:, 1]
    inside = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
    k[ys[inside], xs[inside]] = 1
    # Smooth every point with the same Gaussian (sigma = 15 pixels).
    k = gaussian_filter(k, 15)
    '''----------------------------生成.h5文件----------------------------'''
    # Adjust the output path to your own layout: .jpg -> .h5 and every
    # 'img' in the path -> 'GT_img'.
    with h5py.File(img_path.replace('.jpg', '.h5').replace('img', 'GT_img'), 'w') as hf:
        hf['density'] = k
    '''----------------------------生成.npy文件------------------------------'''
    np.save(img_path.replace('.jpg', '.npy').replace('img', 'GT_img'), k)
利用相邻标记距离进行自适应密度图构建
基于最近邻距离构建密度图的函数实现
# 利用最近邻构建密度图
def gaussian_filter_density(gt):
# 初始化密度图
density = np.zeros(gt.shape, dtype=np.float32)
# 查看地图一共有多少个人
gt_count = np.count_nonzero(gt)
if gt_count == 0: # 如果这个图里没有人,就直接返回全零密度图
return density
# 获得所有标注为人的位置
pts = np.array(list(zip(np.nonzero(gt)[1], np.nonzero(gt)[0])))
leafsize = 2048
# 构建 kdtree
tree = scipy.spatial.KDTree(pts.copy(), leafsize=leafsize)
# 查询 kdtree,参数k表示查询几个最近邻
distances, locations = tree.query(pts, k=2)
print('generate density...')
for i, pt in enumerate(pts):
pt2d = np.zeros(gt.shape, dtype=np.float32)
pt2d[pt[1], pt[0]] = 1.
if gt_count > 1: # 如果图中大于一个人,这里可以可以进行调整,这里使用k=2,所有只有一个邻居(因为第一个是自己)
# sigma = (distances[i][1] + distances[i][2] + distances[i][3]) * 0.1
# 利用距离计算sigma
sigma = (distances[i][1]) * 0.1
else: # 如果图中只有一个人,无法查询邻居,所以使用固定的sigma
sigma = np.average(np.array(gt.shape)) / 2. / 2. # case: 1 point
# 使用计算出的sigma,进行高斯滤波
density += scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant')
return density
主循环实现
'''--------------------------------------获取所有图像地址--------------------------------------'''
# Collect every .jpg image from the training and test directories.
path_sets = [train, test]
img_paths = []
for path in path_sets:
    img_paths.extend(glob.glob(os.path.join(path, '*.jpg')))
'''--------------------------------------构建密度图--------------------------------------'''
# Build a geometry-adaptive density map for each image; save it as .h5 and .npy.
for idx, img_path in enumerate(img_paths):
    print(idx, len(img_paths), img_path)
    # gt holds the annotated [x, y] point coordinates.
    gt = readXML(img_path.replace('.jpg', '.xml').replace('img', 'GT_img'))
    img = plt.imread(img_path)
    height, width = img.shape[0], img.shape[1]
    k = np.zeros((height, width))
    # Truncate to integer pixels (as int() did) and keep only in-bounds points;
    # without the >= 0 checks a negative coordinate would wrap around to the
    # opposite border through numpy's negative indexing.
    pts = np.asarray(gt).reshape(-1, 2).astype(int)
    xs, ys = pts[:, 0], pts[:, 1]
    inside = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
    k[ys[inside], xs[inside]] = 1
    # Adaptive sigma from nearest-neighbour distances; gaussian_filter(k, 15)
    # is the fixed-sigma alternative.
    k = gaussian_filter_density(k)
    '''----------------------------生成.h5文件----------------------------'''
    # NOTE: str.replace rewrites every 'img' in the path (directory and file
    # name alike); the visualisation cell relies on the same rule.
    with h5py.File(img_path.replace('.jpg', '.h5').replace('img', 'GT_img'), 'w') as hf:
        hf['density'] = k
    '''----------------------------生成.npy文件------------------------------'''
    np.save(img_path.replace('.jpg', '.npy').replace('img', 'GT_img'), k)
密度图可视化
# Visualise the first image and the density map generated for it.
if not img_paths:
    print('no images found; nothing to visualize')
else:
    first_img_path = img_paths[0]
    # Show the original image.
    with Image.open(first_img_path) as im:
        plt.imshow(im)
        plt.show()
    # Load the density map written by the generation loop (same path rule).
    gt_h5_path = first_img_path.replace('.jpg', '.h5').replace('img', 'GT_img')
    with h5py.File(gt_h5_path, 'r') as gt_file:
        groundtruth = np.asarray(gt_file['density'])
    # Save a colour-mapped copy to the working directory, named after the
    # first image. The old code used img_path[0], which is only the first
    # character of the last loop path.
    stem = os.path.splitext(os.path.basename(first_img_path))[0]
    plt.imsave("./" + stem + "_test.jpg", groundtruth, cmap=CM.jet)
    plt.imshow(groundtruth, cmap=CM.jet)
    plt.show()
    # The density map's sum estimates the head count; it usually differs
    # slightly from the true number of annotations.
    print(np.sum(groundtruth))