OpenPose -tensorflow代码解析(5)—— 预测代码解析 predict.py

前言

该openpose-tensorflow的工程是自己实现的,所以有些地方写的会比较简单,但阅读性强、方便使用。

论文翻译 || openpose – Realtime Multi-Person 2D Pose Estimation using Part Affinity Fields
工程实现 || 基于opencv使用openpose完成人体姿态估计

OpenPose -tensorflow代码解析(1)——工程概述&&训练前的准备
OpenPose -tensorflow代码解析(2)—— 数据增强和处理 dataset.py
OpenPose -tensorflow代码解析(3)—— 网络结构的搭建 Net.py
OpenPose -tensorflow代码解析(4)—— 训练脚本 train.py
OpenPose -tensorflow代码解析(5)—— 预测代码解析 predict.py

1 预测脚本

预测脚本会比较简单,主要包括了:

  • def __init__(self) 初始化:创建会话、模型加载
  • def run(self, batch_image) 预测图片,得到网络输出的heatmap并解析
  • def predict(self) 使用验证集或测试集中的数据进行测试,对比标签和预测的结果并显示
  • def test(self, image_path) 预测实际某张图片,并显示预测结果
from dataset import *
from eval import *
from NET import *
from opt import *

os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def show_image(img, predict, label=""):
   c1 = []
   c2 = []
   print("predict size: ", len(predict))
   for ii in range(len(predict)):
       c1.append((int(label[ii][0]), int(label[ii][1]))) if label!="" else None
       c2.append((int(predict[ii, 0 ] +4), int(predict[ii, 1])))

   for cc in range(len(predict)):
       cv2.circle(img, c1[cc], 2, (255, 0, 0), thickness=1) if label!="" else None
       cv2.circle(img, c2[cc], 2, (0, 0, 255), thickness=1)

   cv2.namedWindow("image check", cv2.WINDOW_NORMAL)
   cv2.imshow("image check", img)
   cv2.waitKey(0)


class OpenPoseTrain(object):
   def __init__(self):

       self.strides = cfg.OP.strides
       self.sess = tf.Session()

       self.input_node = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='image')
       self.net = OpenPose(self.input_node, True)
       self.net_out = self.net.CPM[-1]

       self.saver = tf.train.Saver()
       self.saver.restore(self.sess, 'model/checkpoint1/model-19')


   def run(self, batch_image):
       feed_dict_data = dict()
       feed_dict_data[self.input_node] = batch_image
       heatmap = self.sess.run(self.net_out, feed_dict=feed_dict_data)

       result = get_preds(heatmap)
       result = np.squeeze(result)[0:-1,:] * self.strides
       return result


   def predict(self):
       cfg.TEST.batch_size = 1
       humandata = Dataset('test')
       humandata.Prepare()

       for i in range(len(humandata.annotations)):
           print(humandata.annotations[i][0])

           img, label_np = humandata.load_data(humandata.annotations[i][0], humandata.annotations[i][1])
           img1 = img.astype(np.float32)
           img1 = (img1 - np.mean(img1, axis=(0, 1))) / (np.std(img1, axis=(0, 1)) + 1e-8)
           humandata.batch_image[0, :, :, :] = img1

           result = self.run(humandata.batch_image)
           show_image(img, result, label_np)

   def test(self, image_path):

       image = np.array(cv2.imread(image_path))
       img = image_preporcess(image, [512, 512])
       img = img.astype(np.float32)
       img = (img - np.mean(img, axis=(0, 1))) / (np.std(img, axis=(0, 1)) + 1e-8)
       img = np.expand_dims(img, 0)


       result = self.run(img)
       show_image(image, result)


op = OpenPoseTrain()
# op.predict()
op.test("../data/dataset2/031_l320.png")

2 预测结果展示

  • op.predict()

    在这里插入图片描述
  • op.test("../data/dataset2/031_l320.png")

    在这里插入图片描述
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值