团队博客: 汽车电子社区
一、模块概述
Prediction模块是Apollo自动驾驶系统的"预言家",负责预测其他交通参与者(车辆、行人、自行车等)的未来行为和轨迹。该模块基于感知结果、地图信息、交通规则等多维信息,为Planning模块提供准确的预测轨迹,是安全决策的重要基础。
二、模块架构
2.1 目录结构
modules/prediction/
├── common/ # 预测通用组件
├── prediction_component/ # 预测主组件
├── container/ # 容器和数据结构
├── evaluator/ # 评估器
├── predictor/ # 预测器(车道序列、自由运动等)
├── jumper/ # 跳变检测
├── planner/ # 预测规划器
├── proto/ # 消息定义
└── scenario/ # 场景分析
2.2 核心组件
1. PredictionComponent - 预测主组件
2. ContainerManager - 容器管理器
3. PredictorManager - 预测器管理器(管理各类障碍物预测器)
4. EvaluatorManager - 评估器管理器
5. Jumper - 跳变检测器
6. ScenarioManager - 场景管理器
三、接口调用流程图
3.1 整体预测流程图
3.2 车辆预测流程图
3.3 行人预测流程图
3.4 评估器流程图
四、核心源码分析
4.1 核心类分析
4.1.1 PredictionComponent 预测主组件
类定义与继承关系:
/// Cyber RT component of the prediction module, templated on
/// perception::PerceptionObstacles as its input message type.
/// Proc() dispatches each perception frame either to the container-submodule
/// path (when FLAGS_use_lego is set) or to the end-to-end prediction path.
class PredictionComponent : public cyber::Component<perception::PerceptionObstacles> {
public:
~PredictionComponent();
std::string Name() const;
/// Creates managers, loads PredictionConf, and creates readers/writers.
bool Init() override;
/// Per-frame entry point; returns the selected pipeline's result.
bool Proc(const std::shared_ptr<perception::PerceptionObstacles>&) override;
/// Offline processing of a features proto file (implementation not shown here).
void OfflineProcessFeatureProtoFile(const std::string& features_proto_file);
private:
bool ContainerSubmoduleProcess(const std::shared_ptr<perception::PerceptionObstacles>&);
bool PredictionEndToEndProc(const std::shared_ptr<perception::PerceptionObstacles>&);
// Timing bookkeeping, in seconds as returned by Clock::NowInSeconds().
double component_start_time_ = 0.0;
double frame_start_time_ = 0.0;
// Input readers (created without callbacks and polled per frame).
std::shared_ptr<cyber::Reader<planning::ADCTrajectory>> planning_reader_;
std::shared_ptr<cyber::Reader<localization::LocalizationEstimate>> localization_reader_;
std::shared_ptr<cyber::Reader<storytelling::Stories>> storytelling_reader_;
// Output writers. Only prediction_writer_ is used by the end-to-end path;
// the others are presumably used by the container-submodule path — confirm.
std::shared_ptr<cyber::Writer<PredictionObstacles>> prediction_writer_;
std::shared_ptr<cyber::Writer<SubmoduleOutput>> container_writer_;
std::shared_ptr<cyber::Writer<ADCTrajectoryContainer>> adc_container_writer_;
std::shared_ptr<cyber::Writer<perception::PerceptionObstacles>> perception_obstacles_writer_;
// Manager instances. container_manager_ is shared (also handed out as a
// shared_ptr); the rest are exclusively owned by this component.
std::shared_ptr<ContainerManager> container_manager_;
std::unique_ptr<EvaluatorManager> evaluator_manager_;
std::unique_ptr<PredictorManager> predictor_manager_;
std::unique_ptr<ScenarioManager> scenario_manager_;
};
初始化过程:
bool PredictionComponent::Init() {
component_start_time_ = Clock::NowInSeconds();
// 创建管理器实例
container_manager_ = std::make_shared<ContainerManager>();
evaluator_manager_.reset(new EvaluatorManager());
predictor_manager_.reset(new PredictorManager());
scenario_manager_.reset(new ScenarioManager());
// 加载配置文件
PredictionConf prediction_conf;
if (!ComponentBase::GetProtoConfig(&prediction_conf)) {
AERROR << "Unable to load prediction conf file: "
<< ComponentBase::ConfigFilePath();
return false;
}
// 初始化消息处理器
if (!MessageProcess::Init(container_manager_.get(), evaluator_manager_.get(),
predictor_manager_.get(), prediction_conf)) {
return false;
}
// 创建订阅器
planning_reader_ = node_->CreateReader<ADCTrajectory>(
prediction_conf.topic_conf().planning_trajectory_topic(), nullptr);
localization_reader_ = node_->CreateReader<localization::LocalizationEstimate>(
prediction_conf.topic_conf().localization_topic(), nullptr);
storytelling_reader_ = node_->CreateReader<storytelling::Stories>(
prediction_conf.topic_conf().storytelling_topic(), nullptr);
// 创建发布器
prediction_writer_ = node_->CreateWriter<PredictionObstacles>(
prediction_conf.topic_conf().prediction_topic());
container_writer_ = node_->CreateWriter<SubmoduleOutput>(
prediction_conf.topic_conf().container_topic_name());
adc_container_writer_ = node_->CreateWriter<ADCTrajectoryContainer>(
prediction_conf.topic_conf().adccontainer_topic_name());
perception_obstacles_writer_ = node_->CreateWriter<PerceptionObstacles>(
prediction_conf.topic_conf().perception_obstacles_topic_name());
return true;
}
主处理流程:
/// Per-frame entry point.
///
/// Routes the perception frame to the container-submodule pipeline when
/// FLAGS_use_lego is set, and to the end-to-end prediction pipeline otherwise.
/// @return the result of whichever pipeline handled the frame.
bool PredictionComponent::Proc(
    const std::shared_ptr<PerceptionObstacles>& perception_obstacles) {
  return FLAGS_use_lego ? ContainerSubmoduleProcess(perception_obstacles)
                        : PredictionEndToEndProc(perception_obstacles);
}
端到端处理模式:
bool PredictionComponent::PredictionEndToEndProc(
const std::shared_ptr<PerceptionObstacles>& perception_obstacles) {
frame_start_time_ = Clock::NowInSeconds();
auto end_time1 = std::chrono::system_clock::now();
// 1. 更新定位容器
localization_reader_->Observe();
auto ptr_localization_msg = localization_reader_->GetLatestObserved();
if (ptr_localization_msg == nullptr) {
AERROR << "Prediction: cannot receive any localization message.";
return false;
}
MessageProcess::OnLocalization(container_manager_.get(), *ptr_localization_msg);
// 2. 更新故事容器
storytelling_reader_->Observe();
auto ptr_storytelling_msg = storytelling_reader_->GetLatestObserved();
if (ptr_storytelling_msg != nullptr) {
MessageProcess::OnStoryTelling(container_manager_.get(), *ptr_storytelling_msg);
}
// 3. 更新规划轨迹容器
planning_reader_->Observe();
auto ptr_trajectory_msg = planning_reader_->GetLatestObserved();
if (ptr_trajectory_msg != nullptr) {
MessageProcess::OnPlanning(container_manager_.get(), *ptr_trajectory_msg);
}
// 4. 处理感知消息并生成预测
auto perception_msg = *perception_obstacles;
PredictionObstacles prediction_obstacles;
MessageProcess::OnPerception(
perception_msg, container_manager_, evaluator_manager_.get(),
predictor_manager_.get(), scenario_manager_.get(), &prediction_obstacles);
// 5. 后处理预测结果
prediction_obstacles.set_start_timestamp(frame_start_time_);
prediction_obstacles.set_end_timestamp(Clock::NowInSeconds());
prediction_obstacles.mutable_header()->set_lidar_timestamp(
perception_msg.header().lidar_timestamp());
prediction_obstacles.mutable_header()->set_camera_timestamp(
perception_msg.header().camera_timestamp());
prediction_obstacles.mutable_header()->set_radar_timestamp(
perception_msg.header().radar_timestamp());
// 6. 发布结果
common::util::FillHeader(node_->Name(), &prediction_obstacles);
prediction_writer_->Write(prediction_obstacles);
return true;
}
4.1.2 PredictorManager 预测器管理器
类定义:
/// Holds one Predictor per ObstacleConf::PredictorType (see predictors_) and
/// runs prediction over the obstacles of a frame.
class PredictorManager {
public:
PredictorManager();
virtual ~PredictorManager() = default;
/// Configures the manager from the prediction config.
void Init(const PredictionConf& config);
/// Returns a non-owning pointer to the predictor registered for `type`
/// (presumably nullptr if none is registered — confirm).
Predictor* GetPredictor(const ObstacleConf::PredictorType& type);
/// Runs prediction for the obstacles of one perception frame, using the ego
/// (ADC) trajectory container and the obstacles container.
void Run(const apollo::perception::PerceptionObstacles& perception_obstacles,
const ADCTrajectoryContainer* adc_trajectory_container,
ObstaclesContainer* obstacles_container);
/// Predicts a single obstacle and writes the result into prediction_obstacle.
void PredictObstacle(const ADCTrajectoryContainer* adc_trajectory_container,
Obstacle* obstacle,
ObstaclesContainer* obstacles_container,
PredictionObstacle* prediction_obstacle);
private:
// Owned predictors keyed by predictor type.
std::map<ObstacleConf::PredictorType, std::unique_ptr<Predictor>> predictors_;
};
预测器类型:
- LANE_SEQUENCE_PREDICTOR - 车道序列预测器
- MOVE_SEQUENCE_PREDICTOR - 运动序列预测器
- FREE_MOVE_PREDICTOR - 自由运动预测器
- REGIONAL_PREDICTOR - 区域预测器
- JUNCTION_PREDICTOR - 路口预测器
- EXTRAPOLATION_PREDICTOR - 外推预测器
4.1.3 EvaluatorManager 评估器管理器
评估器类型:
- COST_EVALUATOR - 成本评估器
- MLP_EVALUATOR - 多层感知器评估器
- RNN_EVALUATOR - 循环神经网络评估器
- PHYSICS_EVALUATOR - 物理约束评估器
4.2 关键算法源码分析
4.2.1 车道序列预测算法
核心逻辑:
void LaneSequencePredictor::Predict(
const ADCTrajectoryContainer* adc_trajectory_container,
Obstacle* obstacle,
ObstaclesContainer* obstacles_container,
PredictionObstacle* prediction_obstacle) {
// 获取障碍物当前状态
const Feature& feature = obstacle->latest_feature();
if (!feature.has_lane()) {
AERROR << "Obstacle [" << obstacle->id() << " has no lane feature.";
return;
}
// 生成候选车道序列
std::vector<LaneSequence> lane_sequences = GenerateLaneSequences(
obstacle, obstacles_container);
// 为每个车道序列生成轨迹
for (auto& lane_sequence : lane_sequences) {
std::vector<TrajectoryPoint> trajectory_points;
// 横向和纵向运动模型
for (double rel_time = 0.0; rel_time < prediction_time_horizon_;
rel_time += prediction_time_resolution_) {
// 横向位置计算(考虑换道意图)
double lateral_position = ComputeLateralPosition(
lane_sequence, obstacle, rel_time);
// 纵向位置计算(考虑速度模型)
double longitudinal_position = ComputeLongitudinalPosition(
lane_sequence, obstacle, rel_time);
// 转换为全局坐标
TrajectoryPoint point = ConvertToGlobalPoint(
lateral_position, longitudinal_position, lane_sequence);
trajectory_points.push_back(point);
}
// 创建轨迹并添加到预测结果
Trajectory* trajectory = prediction_obstacle->add_trajectory();
for (const auto& point : trajectory_points) {
*trajectory->add_trajectory_point() = point;
}
// 计算轨迹概率
trajectory->set_probability(ComputeTrajectoryProbability(lane_sequence, obstacle));
}
}
4.2.2 自由运动预测算法
/// Predicts free-motion trajectories for one obstacle.
///
/// Fits a motion model to the collected trajectory history. It then appends
/// one Trajectory per (acceleration, yaw-rate) hypothesis, 5 x 5 = 25 in
/// total. Each is sampled at t = 0, dt_, 2*dt_, ... strictly below
/// prediction_horizon_. Its probability is the raw
/// EvaluateTrajectoryReasonableness() score.
/// NOTE(review): the scores are not normalized across the 25 hypotheses; if
/// consumers expect probabilities summing to <= 1, normalize here — confirm.
///
/// @param obstacle             obstacle to predict.
/// @param prediction_obstacle  output; trajectories are appended.
void FreeMovePredictor::Predict(
    Obstacle* obstacle,
    PredictionObstacle* prediction_obstacle) {
  // Candidate hypotheses; units presumably m/s^2 and rad/s — TODO confirm.
  static constexpr double kAccelerationSamples[] = {0.0, 1.0, -1.0, 2.0, -2.0};
  static constexpr double kYawRateSamples[] = {0.0, 0.1, -0.1, 0.2, -0.2};

  // A non-positive step would never advance the sampling loop below.
  if (!(dt_ > 0.0)) {
    return;
  }

  const Feature& feature = obstacle->latest_feature();

  // Collect trajectory history.
  // NOTE(review): this iterates sub-features of the *latest* feature; the
  // obstacle's history is more likely reachable through the Obstacle itself
  // (e.g. a per-frame feature accessor) — confirm.
  std::vector<apollo::common::TrajectoryPoint> history_points;
  for (int i = 0; i < feature.feature_size(); ++i) {
    if (feature.feature(i).has_trajectory()) {
      for (const auto& point : feature.feature(i).trajectory().trajectory_point()) {
        history_points.push_back(point);
      }
    }
  }

  const MotionModel motion_model = FitMotionModel(history_points);

  for (const double acc : kAccelerationSamples) {
    for (const double yaw_rate : kYawRateSamples) {
      // Build directly inside the output message instead of copying a
      // local Trajectory into it afterwards.
      Trajectory* trajectory = prediction_obstacle->add_trajectory();
      // Integer step index avoids floating-point drift in the sample times.
      for (int step = 0; step * dt_ < prediction_horizon_; ++step) {
        *trajectory->add_trajectory_point() =
            GenerateTrajectoryPoint(motion_model, acc, yaw_rate, step * dt_);
      }
      trajectory->set_probability(
          EvaluateTrajectoryReasonableness(*trajectory, motion_model));
    }
  }
}
五、消息接口定义
5.1 输入消息
5.1.1 PerceptionObstacles(感知障碍物)
// All obstacles reported by perception for one frame.
message PerceptionObstacles {
  // Message header.
  optional apollo.common.Header header = 1;
  // Obstacles detected in this frame.
  repeated PerceptionObstacle perception_obstacle = 2;
  // Perception error code. Fully qualified, as in
  // PredictionObstacles.perception_error_code.
  optional apollo.common.ErrorCode error_code = 3;
  // Perception "waste" entries (semantics not shown here — TODO confirm).
  repeated PerceptionWaste perception_waste = 4;
}
message PerceptionObstacle {
// Obstacle ID.
optional int32 id = 1;
// Obstacle position (presumably meters in the world frame — confirm).
optional apollo.common.Point3D position = 2;
// Obstacle velocity.
optional apollo.common.Point3D velocity = 3;
// Obstacle acceleration.
optional apollo.common.Point3D acceleration = 4;
// Obstacle heading angle (presumably radians — confirm).
optional double theta = 5;
// Obstacle dimensions: length, width, height.
optional double length = 6;
optional double width = 7;
optional double height = 8;
// Obstacle type (e.g. VEHICLE); the Type enum is defined elsewhere.
optional Type type = 9;
// Timestamp.
optional double timestamp = 10;
// History of trajectory points.
repeated apollo.common.TrajectoryPoint trajectory_point = 11;
// Polygon contour points.
repeated apollo.common.Point2D polygon_point = 12;
}
5.2 输出消息
5.2.1 PredictionObstacles(预测障碍物)
// Prediction output for one frame, published by the prediction component.
// Fields carry explicit proto2 labels, matching PerceptionObstacle.
message PredictionObstacles {
  // Message header (also carries the lidar/camera/radar timestamps copied
  // from the perception frame).
  optional apollo.common.Header header = 1;
  // Time at which prediction for this frame started.
  optional double start_timestamp = 2;
  // Time at which prediction for this frame finished.
  optional double end_timestamp = 3;
  // Predicted obstacles.
  repeated PredictionObstacle prediction_obstacle = 4;
  // Error code propagated from perception.
  optional apollo.common.ErrorCode perception_error_code = 5;
  // Perception "waste" entries propagated from perception.
  repeated apollo.perception.PerceptionWaste perception_waste = 6;
}
// Prediction result for a single obstacle.
message PredictionObstacle {
  // The perceived obstacle this prediction refers to.
  optional apollo.perception.PerceptionObstacle perception_obstacle = 1;
  // Timestamp.
  optional double timestamp = 2;
  // Candidate future trajectories (possibly several, one per hypothesis).
  repeated Trajectory trajectory = 3;
  // Priority of this obstacle.
  optional Priority priority = 4;
  // Predicted intent of this obstacle (not of the ego vehicle).
  optional Intent intent = 5;
}
// One predicted future trajectory of an obstacle.
message Trajectory {
  // Likelihood of this trajectory; expected to lie in [0, 1].
  optional double probability = 1;
  // Trajectory points ordered by time.
  repeated apollo.common.TrajectoryPoint trajectory_point = 2;
  // Trajectory type.
  optional TrajectoryType trajectory_type = 3;
}
六、配置文件说明
6.1 预测配置文件(prediction_conf.pb.txt)
# Topics read and written by PredictionComponent::Init().
topic_conf {
planning_trajectory_topic: "/apollo/planning"
localization_topic: "/apollo/localization/pose"
storytelling_topic: "/apollo/storytelling"
prediction_topic: "/apollo/prediction"
container_topic_name: "/apollo/prediction/container"
adccontainer_topic_name: "/apollo/prediction/adc_container"
perception_obstacles_topic_name: "/apollo/prediction/perception_obstacles"
}
# Predictor types to enable.
predictor_conf {
predictor_type: LANE_SEQUENCE_PREDICTOR
predictor_type: MOVE_SEQUENCE_PREDICTOR
predictor_type: FREE_MOVE_PREDICTOR
predictor_type: REGIONAL_PREDICTOR
predictor_type: JUNCTION_PREDICTOR
predictor_type: EXTRAPOLATION_PREDICTOR
}
# Evaluator types to enable.
# NOTE(review): PHYSICS_EVALUATOR is listed among the evaluator types but is
# not enabled here — confirm whether that is intentional.
evaluator_conf {
evaluator_type: COST_EVALUATOR
evaluator_type: MLP_EVALUATOR
evaluator_type: RNN_EVALUATOR
}
# Scenario settings.
# NOTE(review): URBAN_ROAD and HIGHWAY read as road types rather than
# junction types — confirm against the scenario_conf schema.
scenario_conf {
junction_type: URBAN_ROAD
junction_type: HIGHWAY
junction_type: INTERSECTION
}
6.2 障碍物配置(obstacle_conf.pb.txt)
# Per-obstacle-type prediction settings: one obstacle_conf message per type.
# Protobuf text format has no `field: VALUE { ... }` form, so obstacle_type is
# an ordinary field inside each message.
# NOTE(review): assumes obstacle_conf is a repeated message field — confirm
# against the schema.
obstacle_conf {
  obstacle_type: VEHICLE
  default_predictor_type: LANE_SEQUENCE_PREDICTOR
  default_evaluator_type: COST_EVALUATOR
  max_prediction_num: 3
  prediction_horizon: 8.0          # presumably seconds — confirm
  prediction_time_resolution: 0.1  # presumably seconds — confirm
}
obstacle_conf {
  obstacle_type: PEDESTRIAN
  default_predictor_type: FREE_MOVE_PREDICTOR
  default_evaluator_type: MLP_EVALUATOR
  max_prediction_num: 5
  prediction_horizon: 5.0
  prediction_time_resolution: 0.1
}
obstacle_conf {
  obstacle_type: BICYCLE
  default_predictor_type: MOVE_SEQUENCE_PREDICTOR
  default_evaluator_type: COST_EVALUATOR
  max_prediction_num: 3
  prediction_horizon: 6.0
  prediction_time_resolution: 0.1
}
七、性能优化策略
7.1 计算优化
7.1.1 障碍物过滤机制
// Filters obstacles by distance to the ego vehicle and by importance.
//
// @param obstacles_container  source of candidate obstacles.
// @return the obstacles within max_prediction_distance_ of the ego pose whose
//         importance score is at least importance_threshold_. They are
//         sorted by that score in descending order; ties keep container order.
std::vector<Obstacle*> FilterImportantObstacles(
    ObstaclesContainer* obstacles_container) {
  // Keep each obstacle together with the score computed in this call, so the
  // ordering uses the same value as the filter. The obstacle's own
  // importance_score() is never updated here and may be stale.
  struct ScoredObstacle {
    Obstacle* obstacle;
    double score;
  };
  std::vector<ScoredObstacle> scored;

  const auto& ego_pose = GetEgoPose();
  for (const auto& obstacle : obstacles_container->obstacles()) {
    // Distance filter.
    const double distance = ComputeDistance(obstacle->position(), ego_pose);
    if (distance > max_prediction_distance_) {
      continue;
    }
    // Importance filter.
    const double importance_score = ComputeImportanceScore(obstacle, distance);
    if (importance_score < importance_threshold_) {
      continue;
    }
    scored.push_back({obstacle, importance_score});
  }

  // Most important first; stable so equal scores give a deterministic order.
  std::stable_sort(scored.begin(), scored.end(),
                   [](const ScoredObstacle& a, const ScoredObstacle& b) {
                     return a.score > b.score;
                   });

  std::vector<Obstacle*> important_obstacles;
  important_obstacles.reserve(scored.size());
  for (const auto& entry : scored) {
    important_obstacles.push_back(entry.obstacle);
  }
  return important_obstacles;
}
7.1.2 多线程预测
class ParallelPredictor {
public:
void PredictParallel(
const std::vector<Obstacle*>& obstacles,
ObstaclesContainer* obstacles_container) {
// 使用线程池并行预测
std::vector<std::future<void>> futures;
for (Obstacle* obstacle : obstacles) {
futures.push_back(thread_pool_->submit([this, obstacle, obstacles_container]() {
this->PredictSingleObstacle(obstacle, obstacles_container);
}));
}
// 等待所有预测完成
for (auto& future : futures) {
future.wait();
}
}
private:
std::unique_ptr<ThreadPool> thread_pool_;
};
7.2 内存优化
7.2.1 对象池模式
/// Mutex-guarded pool of reusable heap-allocated objects.
///
/// T must be default-constructible and provide a Reset() member that
/// restores the object to a reusable state.
/// NOTE(review): the pool is unbounded; every Release()d object is retained.
template <typename T>
class ObjectPool {
 public:
  /// Returns a pooled object if one is available, otherwise a newly
  /// default-constructed T. Never returns nullptr.
  std::shared_ptr<T> Acquire() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!pool_.empty()) {
        // Move out of the pool to avoid an extra refcount increment.
        std::shared_ptr<T> obj = std::move(pool_.back());
        pool_.pop_back();
        return obj;
      }
    }
    // Allocate outside the lock so other threads are not blocked on it.
    return std::make_shared<T>();
  }

  /// Resets `obj` and returns it to the pool. A null pointer is ignored.
  void Release(std::shared_ptr<T> obj) {
    if (obj == nullptr) {
      return;
    }
    // The object is not yet in the pool, so no other Acquire() can observe
    // it. Reset it outside the critical section.
    obj->Reset();
    std::lock_guard<std::mutex> lock(mutex_);
    pool_.push_back(std::move(obj));
  }

 private:
  std::vector<std::shared_ptr<T>> pool_;
  std::mutex mutex_;
};
// Object pools for prediction results.
// NOTE(review): ObjectPool<T>::Release() calls obj->Reset(). If
// PredictionObstacle and Trajectory are protobuf-generated messages, they
// expose Clear() rather than Reset(), so Release() would not compile for
// them — confirm and adapt.
// NOTE(review): the trailing underscores suggest these were meant to be class
// members rather than namespace-scope variables.
ObjectPool<PredictionObstacle> prediction_obstacle_pool_;
ObjectPool<Trajectory> trajectory_pool_;
7.2.2 内存缓存策略
class TrajectoryCache {
public:
std::shared_ptr<Trajectory> GetCachedTrajectory(
const Obstacle* obstacle,
const LaneSequence& lane_sequence) {
std::string cache_key = GenerateCacheKey(obstacle, lane_sequence);
auto it = cache_.find(cache_key);
if (it != cache_.end() && !IsExpired(it->second)) {
return it->second.trajectory;
}
return nullptr;
}
void CacheTrajectory(
const Obstacle* obstacle,
const LaneSequence& lane_sequence,
std::shared_ptr<Trajectory> trajectory) {
std::string cache_key = GenerateCacheKey(obstacle, lane_sequence);
CacheEntry entry{trajectory, Clock::NowInSeconds()};
cache_[cache_key] = entry;
// 清理过期缓存
CleanExpiredEntries();
}
private:
struct CacheEntry {
std::shared_ptr<Trajectory> trajectory;
double timestamp;
};
std::unordered_map<std::string, CacheEntry> cache_;
static constexpr double CACHE_EXPIRE_TIME = 0.1; // 100ms
};
八、测试与验证
8.1 单元测试
8.1.1 预测器测试
// Unit test: LaneSequencePredictor produces at least one non-empty trajectory
// with a probability in [0, 1] for a moving vehicle.
// NOTE(review): the feature built below never gets a lane. LaneSequencePredictor::Predict
// returns early when `!feature.has_lane()`, which would leave trajectory_size() at 0
// and fail the first EXPECT_GT — confirm, and populate a lane feature if needed.
// NOTE(review): Predict() forwards the nullptr obstacles_container passed here
// to GenerateLaneSequences(); confirm that helper tolerates null.
TEST(LaneSequencePredictorTest, PredictVehicle) {
// Build the test obstacle.
Obstacle obstacle;
obstacle.set_id(1);
obstacle.set_type(apollo::perception::PerceptionObstacle::VEHICLE);
// Initial state: position (10, 5), velocity (5, 0).
Feature* feature = obstacle.add_feature();
feature->mutable_position()->set_x(10.0);
feature->mutable_position()->set_y(5.0);
feature->mutable_velocity()->set_x(5.0);
feature->mutable_velocity()->set_y(0.0);
// Create the predictor.
LaneSequencePredictor predictor;
predictor.Init();
// Run prediction (no ADC trajectory container, no obstacles container).
PredictionObstacle prediction_obstacle;
predictor.Predict(nullptr, &obstacle, nullptr, &prediction_obstacle);
// Verify the results.
EXPECT_GT(prediction_obstacle.trajectory_size(), 0);
for (const auto& trajectory : prediction_obstacle.trajectory()) {
EXPECT_GT(trajectory.trajectory_point_size(), 0);
EXPECT_GE(trajectory.probability(), 0.0);
EXPECT_LE(trajectory.probability(), 1.0);
}
}
8.2 集成测试
8.2.1 端到端预测测试
// Integration smoke test: PredictionComponent::Init() followed by Proc() on a
// single-vehicle perception frame.
// NOTE(review): Init() reads its config via ComponentBase::GetProtoConfig()
// and creates readers/writers through node_. A default-constructed
// component presumably has neither unless the test harness provides them —
// confirm.
// NOTE(review): with FLAGS_use_lego unset, Proc() runs the end-to-end path,
// which returns false when no localization message has been observed. This
// test publishes none, so EXPECT_TRUE(result) is expected to fail unless a
// fixture publishes localization first.
TEST(PredictionComponentTest, EndToEndPrediction) {
  PredictionComponent component;
  // ASSERT, not EXPECT: after a failed Init() the readers/writers are never
  // created, and Proc() would dereference them.
  ASSERT_TRUE(component.Init());

  // Build a perception frame with one vehicle at (10, 5).
  auto perception_obstacles = std::make_shared<PerceptionObstacles>();
  auto* obstacle = perception_obstacles->add_perception_obstacle();
  obstacle->set_id(1);
  obstacle->set_type(PerceptionObstacle::VEHICLE);
  obstacle->mutable_position()->set_x(10.0);
  obstacle->mutable_position()->set_y(5.0);

  // Run prediction.
  const bool result = component.Proc(perception_obstacles);
  EXPECT_TRUE(result);
}
九、总结
Prediction模块是Apollo自动驾驶系统的核心预测组件,通过多模态轨迹预测、智能意图识别和高效计算优化,为规划模块提供了准确的未来交通参与者行为预测。
9.1 技术特点
1. 多模态预测 - 为每个障碍物生成多条可能轨迹,覆盖不同行为意图
2. 场景化处理 - 针对不同道路场景采用专门的预测策略
3. 机器学习融合 - 结合传统物理模型和深度学习方法
4. 实时性能优化 - 通过并行计算和缓存策略保证实时性
5. 可扩展架构 - 支持新预测器和评估器的插件式集成
9.2 关键性能指标
- 预测时延: < 50ms
- 预测精度: 车辆轨迹误差 < 2.0m @ 5s
- 支持规模: 最大100个障碍物同时预测
- 更新频率: 10Hz
- 内存占用: < 500MB
Prediction模块的高精度预测为自动驾驶系统的安全决策提供了重要保障,是整个系统不可或缺的关键环节。

1438

被折叠的 条评论
为什么被折叠?



