YOLOX 中 decode（特征点解码）过程的可视化
以下代码以特征图宽高为 20×20、batch_size = 4、num_classes = 20 为例，对解码过程进行可视化。
import numpy as np
import matplotlib.pyplot as plt
def decode_for_vision(output):
bs, hw = np.shape(output)[0], np.shape(output)[1:3]
# hw[0] * hw[1] ------- 20,20
output = np.reshape(output, [bs, hw[0] * hw[1], -1])
#print(output)
#output ------(4, 400, 23)
grid_x, grid_y = np.meshgrid(np.arange(hw[1]), np.arange(hw[0]))
#print(grid_x)
grid = np.reshape(np.stack((grid_x, grid_y), 2), (1, -1, 2))
#grid ---------(1, 400, 2)
#print(grid)
box_xy = (output[..., :2] + grid)
#box_xy.shape (4, 400, 2)
#output[..., :2] (4, 400, 2)
#grid (1, 400, 2)
box_wh = np.exp(output[..., 2:4])
#output[..., 2:4].shape (4, 400, 2)
#box_wh (4, 400, 2)
fig = plt.figure()
ax = fig.add_subplot(121)
plt.ylim(-2.22, hw[0] + 2.22)
plt.xlim(-2.22, hw[1] + 2.22)
plt.scatter(grid_x, grid_y)
plt.scatter(0, 0, c='black')
plt.scatter(1, 0, c='black')
plt.scatter(2, 0, c='black')
plt.scatter(box_xy[0, 0, 0], box_xy[0, 0, 1], c='r')
plt.scatter(box_xy[0, 1, 0], box_xy[0, 1, 1], c='g')
plt.scatter(box_xy[0, 2, 0], box_xy[