Argoverse轨迹预测demo

1.Argoverse数据集

Argoverse数据集官方网址:https://www.argoverse.org/index.html
Argoverse在github的源码:https://github.com/argoverse/argoverse-api/tree/master
环境安装:博客1博客2中有详细介绍
运行下述命令,可以得到完整的Argoverse源码,其中my_tests是我自己建立的文件夹:

git clone https://github.com/argoai/argoverse-api.git

文件夹分支
安装好环境后,就可以开始实验了。

2.Argoverse轨迹预测介绍

本文只介绍使用demo,因此数据下载的是样例数据集。
请添加图片描述将数据放到自己建立的文件夹,样例的数据只有5个csv文件,在Argoverse Motion Forecasting v1.1中可以找到更多数据:

在这里插入图片描述
后面代码部分将在argoverse_forecasting_tutorial.py中编写。

csv文件打开后,部分数据展示:
在这里插入图片描述

2.1数据加载

from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader
##设置为自己存放数据的目录
root_dir = 'forecasting_sample/data/'
##
afl = ArgoverseForecastingLoader(root_dir)

print('Total number of sequences:',len(afl))

ArgoverseForecastingLoader是一个类,包含有以下的方法:

方法名作用返回类型
__init__(self, root_dir)初始化方法,设置根目录和序列列表*
track_id_list(self)返回当前序列中的跟踪 ID 列表List[int]
city(self)返回当前序列的城市名称(‘PIT’ 或 ‘MIA’)str
num_tracks(self)返回当前序列中的轨迹数量int
seq_df(self)返回当前序列的 Pandas DataFramepd.DataFrame
agent_traj(self)返回当前序列中 ‘AGENT’ 类型的轨迹(二维数组)np.ndarray
__iter__(self)返回当前对象,以便可以进行迭代ArgoverseForecastingLoader
__next__(self)返回下一个序列的数据加载器对象ArgoverseForecastingLoader
__len__(self)返回数据中的序列数量int
__str__(self)返回当前序列的一些统计信息字符串str
__getitem __(self, key)返回给定索引的序列的数据加载器对象ArgoverseForecastingLoader
get(self, seq_id)返回给定序列路径的数据加载器对象ArgoverseForecastingLoader

ArgoverseForecastingLoader类的部分方法使用如下:

from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader

##set root_dir to the correct path to your dataset folder
root_dir = 'forecasting_sample/data/'

afl = ArgoverseForecastingLoader(root_dir)

print('Total number of sequences:',len(afl))

for argoverse_forecasting_data in (afl):
    print(argoverse_forecasting_data)    #输出单个csv文件的简要信息(City和轨迹数量)
    print(argoverse_forecasting_data.track_id_list)   #输出单个csv文件内所有的TRACK_ID
    print(argoverse_forecasting_data.city)    # 输出单个csv文件内的城市名
    print(argoverse_forecasting_data.num_tracks)  # 输出单个csv文件内的轨迹数量
    print(argoverse_forecasting_data.seq_df)  # 以pd.DataFrame类型输出csv文件的数据信息
    print(argoverse_forecasting_data.agent_traj)  # 以numpy.ndarray类型输出agent的轨迹

部分结果如下图所示:
在这里插入图片描述

2.2绘制可视化图片

绘制可视化图形代码如下,其中的viz_sequenceargoverse.visualization.visualize_sequences中写好的函数。

from argoverse.visualization.visualize_sequences import viz_sequence
import os
import matplotlib.pyplot as plt
# seq_path = f"{root_dir}/2645.csv"
# viz_sequence(afl.get(seq_path).seq_df, show=True)
# seq_path = f"{root_dir}/3828.csv"
# viz_sequence(afl.get(seq_path).seq_df, show=True)
# 确保保存目录存在
save_dir = "save_pics"
os.makedirs(save_dir, exist_ok=True)
# 可视化第一个序列并保存
seq_path = f"{root_dir}/2645.csv"
viz_sequence(afl.get(seq_path).seq_df, show=False)  # 不显示图形,直接保存
plt.savefig(os.path.join(save_dir, "sequence_2645.png"))  # 保存图像
plt.close()  # 关闭图形以释放内存

生成的结果图如下:
在这里插入图片描述

2.3使用map_api

获取代理轨迹的候选中心线是一个简单的函数调用。下面我们使用轨迹的前 2 秒来计算接下来 3 秒的候选中心线。map_api文件的函数有很多,**ArgoverseMap()**是里面的一个类,这里不展开讨论。

from argoverse.map_representation.map_api import ArgoverseMap

avm = ArgoverseMap()

obs_len = 20

index = 2
seq_path = afl.seq_list[index]
agent_obs_traj = afl.get(seq_path).agent_traj[:obs_len]
candidate_centerlines = avm.get_candidate_centerlines_for_traj(agent_obs_traj, afl[index].city, viz=True)

index = 3
seq_path = afl.seq_list[index]
agent_obs_traj = afl.get(seq_path).agent_traj[:obs_len]
candidate_centerlines = avm.get_candidate_centerlines_for_traj(agent_obs_traj, afl[index].city, viz=True)

生成的结果图如下:
在这里插入图片描述

在这里插入图片描述想要将图片保存下来,直接加plt.savefig(filename)貌似不行,保存下来的是空白。目前找到的方法需要修改map_api文件的get_candidate_centerlines_for_traj函数代码,目前我个人解决方案是下图。
在这里插入图片描述

2.4获取轨迹坐标的车道方向

和2.3相似,也是通过**ArgoverseMap()**类里面的函数调用得到的。

index = 2
seq_path = afl.seq_list[index]
agent_traj = afl.get(seq_path).agent_traj
lane_direction = avm.get_lane_direction(agent_traj[0], afl[index].city, visualize=True)

index = 3
seq_path = afl.seq_list[index]
agent_traj = afl.get(seq_path).agent_traj
lane_direction = avm.get_lane_direction(agent_traj[0], afl[index].city, visualize=True)

结果图如下:
在这里插入图片描述

在这里插入图片描述

3.结束语

这是一个简单的使用python对 Argoverse 预测数据集教程。如有不足,请指出。

### 关于Argoverse数据集的下载与使用 #### Argoverse 数据集概述 Argoverse 是一个专为自动驾驶研究设计的数据集,提供了丰富的传感器数据和标注信息。随着版本迭代,特别是v2版本的推出,使得研究人员能够更方便地获取并处理复杂的真实场景下的自动驾驶数据[^1]。 #### 获取 Argoverse 数据集 为了便于开发者和科研人员快速上手,官方提供了一个详细的指南来帮助用户完成数据集的下载过程: - **注册账号**:访问官方网站 (https://www.argoverse.org/) 并创建个人账户。 - **申请权限**:提交必要的个人信息以及计划用途描述给审核团队等待批准。 - **选择所需资源**:登录后进入“Data”页面挑选感兴趣的具体子集比如感知(Prediction) 或 预测(Forecasting),点击对应的链接前往GitCode仓库或其他托管平台。 - **克隆存储库**:以`av2-api`为例,可以通过命令行工具执行如下操作来进行本地复制: ```bash git clone https://gitcode.com/gh_mirrors/av2/av2-api.git ``` #### 安装依赖项及环境配置 确保安装Python解释器及其包管理工具pip之后,在终端内依次运行下列指令设置虚拟环境并加载必需软件包: ```bash cd av2-api/ python3 -m venv env source ./env/bin/activate # Windows 用户请使用 `.\env\Scripts\activate` pip install --upgrade pip setuptools wheel pip install -r requirements.txt ``` #### 加载与预览数据样本 成功部署好上述组件以后就可以调用API接口读取文件中的记录了。下面给出一段简单的代码片段用于展示如何打开指定路径下的一帧LiDAR扫描图元组,并将其可视化出来: ```python from argoverse.data_loading.synchronization_database import SynchronizationDB import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D sdb = SynchronizationDB(log_id="example_log", data_dir="./data/") lidar_data = sdb.get_lidar_at_timestamp(0) fig = plt.figure() ax = fig.add_subplot(projection='3d') ax.scatter(lidar_data[:, 0], lidar_data[:, 1], lidar_data[:, 2]) plt.show() ``` 通过这种方式不仅可以直观感受三维空间内的物体分布情况,同时也为进一步开展高级别的算法实验打下了坚实的基础[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值