之前发布了一篇关于复现了论文《Social-STGCNN:A Social Spatio-Temporal Graph Convolutional Neural Network for Human Trajectory Prediction》的博客
后来有小伙伴问到了可视化,最近这两天刚好在回顾这个模型,因此顺便把它可视化出来,具体操作如下:
对test.py文件进行以下修改:
import os
import math
import sys
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pickle
import argparse
import glob
import torch.distributions.multivariate_normal as torchdist
from utils import *
from metrics import *
from model import social_stgcnn
import copy
import matplotlib.pyplot as plt
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
def vis_result(ypos, ppos):
ypos = ypos.data.cpu().numpy()
ppos = ppos.data.cpu().numpy()
print(ypos.shape, ppos.shape)
n_pred, node_num, _ = ypos.shape
total = n_pred * 2
ncol = 3
nrow = total // ncol + (total % ncol)
plt.figure(figsize=(16,12))
n_fig = 0
# print(ypos,ppos)
for i in range(n_pred):
yx = ypos[i, :, 0]
yy = ypos[i,:, 1]
ax = plt.subplot(nrow, ncol, n_fig+i+1)
plt.title(str(i+1)+'_gt')
plt.scatter(yx, yy)
n_fig += n_pred
for i in range(n_pred):
px = ppos[i,:, 0]
py = ppos[i, :,1]
ax = plt.subplot(nrow, ncol, n_fig+i+1)
plt.title(str(i+1)+'_pred')
plt.scatter(px, py)
# plt.savefig('./outputs'+'/result'+ str(?) +'.png')
plt.show()
def test</

本文展示了如何对Social-STGCNN模型进行轨迹预测的可视化,并详细介绍了test.py中的关键代码,包括数据预处理、模型评估和样本生成。通过图例解释了真实轨迹和预测轨迹的比较,便于理解模型性能。
最低0.47元/天 解锁文章
7409

被折叠的 条评论
为什么被折叠?



