对Social-STGCNN模型进行可视化

本文展示了如何对Social-STGCNN模型进行轨迹预测的可视化,并详细介绍了test.py中的关键代码,包括数据预处理、模型评估和样本生成。通过图例解释了真实轨迹和预测轨迹的比较,便于理解模型性能。

之前发布了一篇关于复现了论文《Social-STGCNN:A Social Spatio-Temporal Graph Convolutional Neural Network for Human Trajectory Prediction》的博客
后来有小伙伴问到了可视化,最近这两天刚好在回顾这个模型,因此顺便把它可视化出来,具体操作如下:

对test.py文件进行以下修改:

import os
import math
import sys
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pickle
import argparse
import glob
import torch.distributions.multivariate_normal as torchdist
from utils import * 
from metrics import * 
from model import social_stgcnn
import copy
import matplotlib.pyplot as plt


device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

def vis_result(ypos, ppos):

    ypos = ypos.data.cpu().numpy()
    ppos = ppos.data.cpu().numpy()

    print(ypos.shape, ppos.shape)

    n_pred, node_num, _ = ypos.shape
    total = n_pred * 2
    ncol = 3
    nrow = total // ncol + (total % ncol)

    plt.figure(figsize=(16,12))
    n_fig = 0
    # print(ypos,ppos)
    for i in range(n_pred):
        yx = ypos[i, :, 0]
        yy = ypos[i,:, 1]
        ax = plt.subplot(nrow, ncol, n_fig+i+1)
        plt.title(str(i+1)+'_gt')
        plt.scatter(yx, yy)
    n_fig += n_pred
    for i in range(n_pred):
        px = ppos[i,:, 0]
        py = ppos[i, :,1]
        ax = plt.subplot(nrow, ncol, n_fig+i+1)
        plt.title(str(i+1)+'_pred')
        plt.scatter(px, py)
    # plt.savefig('./outputs'+'/result'+ str(?) +'.png')
    plt.show()
def test</
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值