def reward_shaping(_obs, _extra_info, act, agent):
"""
奖励塑造函数:结合等待时间变化、队列长度、旅行效率和通行效率计算奖励
适配FeatureProcess类的数据结构,减少重复计算,提升逻辑准确性
"""
junction_ids = list(agent.preprocess.junction_dict.keys())
fp = agent.preprocess # 简化引用
frame_state = _obs["framestate"]
frame_time = frame_state["frame_time"]
vehicles = frame_state["vehicles"]
# 初始化奖励字典
rewards = {j_id: {'total': 0, 'components': {}} for j_id in junction_ids}
# 先更新交通信息(确保FeatureProcess中的动态数据是最新的)
fp.update_traffic_info(_obs, _extra_info)
# 预计算所有路口的平均等待时间(复用FeatureProcess的计算结果)
all_junction_waiting = fp.get_all_junction_waiting_time(vehicles)
# 为每个路口计算奖励分量
for j_id in junction_ids:
# 1. 等待时间变化奖励(基于平均等待时间的变化量)
current_waiting = all_junction_waiting.get(j_id, 0.0)
prev_waiting = fp.old_waiting_time[j_id]
waiting_delta = prev_waiting - current_waiting # 等待时间减少则奖励为正
# 归一化:将变化量映射到[0, 2](避免过大波动,0表示无改善,2表示最大改善)
waiting_reward = 1.0 + np.tanh(waiting_delta / 5.0) # 5.0为缩放因子,可调整
fp.old_waiting_time[j_id] = current_waiting # 更新历史等待时间
rewards[j_id]['components']['waiting'] = waiting_reward
# 2. 队列长度奖励(惩罚过长队列)
queue_length = calculate_queue_length(j_id, fp)
# 归一化:队列越长惩罚越大,映射到[-1, 0](tanh确保极端情况不会无限惩罚)
queue_reward = -np.tanh(queue_length / 15.0) # 15.0为车道容量参考值,可调整
rewards[j_id]['components']['queue'] = queue_reward
# 3. 旅行效率奖励(奖励车辆在进口道的有效行驶距离)
travel_reward = calculate_travel_reward(j_id, fp, vehicles)
rewards[j_id]['components']['travel'] = travel_reward
# 4. 通行效率奖励(单位时间通过路口的车辆数)
throughput_reward = calculate_throughput(j_id, frame_time, fp)
rewards[j_id]['components']['throughput'] = throughput_reward
# 组合奖励(权重可根据场景动态调整)
weights = {
'waiting': 0.4, # 优先优化等待时间
'queue': 0.3, # 次优先优化队列长度
'travel': 0.15, # 辅助优化行驶效率
'throughput': 0.15 # 辅助优化通行量
}
rewards[j_id]['total'] = sum(
weights[k] * rewards[j_id]['components'][k]
for k in weights
)
# 按路口ID排序,为每个路口返回2个奖励值(对应相位和时长控制)
sorted_junctions = sorted(junction_ids)
return tuple(
rewards[j_id]['total']
for j_id in sorted_junctions
for _ in range(2) # 每个路口输出2个相同奖励(适配多维度动作)
)
def calculate_queue_length(junction_id, fp: FeatureProcess):
"""
计算路口进口道的总排队车辆数(仅统计处于等待状态的车辆)
复用FeatureProcess维护的lane_volume和waiting_time_store
"""
junction = fp.junction_dict[junction_id]
total_queue = 0
# 遍历路口所有进口车道
for direction in junction["enter_lanes_on_directions"]:
for lane_id in direction["lanes"]:
# 仅统计车道上处于等待状态的车辆(速度≤0.1m/s)
# 从lane_volume中筛选在waiting_time_store中的车辆(等待状态)
waiting_vehicles = [
v_id for v_id in fp.lane_volume.get(lane_id, [])
if v_id in fp.waiting_time_store
]
total_queue += len(waiting_vehicles)
return total_queue
def calculate_travel_reward(junction_id, fp: FeatureProcess, vehicles):
"""
计算路口进口道车辆的平均旅行效率(有效行驶距离)
结合车辆当前位置和历史行驶距离,确保统计的是当前路口的车辆
"""
total_distance = 0.0
valid_count = 0
# 遍历当前路口进口道上的车辆(通过vehicle_prev_junction和当前junction判断)
for vehicle in vehicles:
v_id = vehicle["v_id"]
# 条件1:车辆当前在进口车道上
# 条件2:车辆的目标路口是当前路口(确保属于该路口管控范围)
if on_enter_lane(vehicle) and vehicle["target_junction"] == junction_id:
# 累加该车辆在当前进口道的行驶距离
if v_id in fp.vehicle_distance_store:
total_distance += fp.vehicle_distance_store[v_id]
valid_count += 1
if valid_count == 0:
return 0.0 # 无车辆时奖励为0
avg_distance = total_distance / valid_count
# 归一化:距离越长(行驶顺畅)奖励越高,映射到[0, 1]
return np.tanh(avg_distance / 100.0) # 100.0为参考距离,可根据道路长度调整
def calculate_throughput(junction_id, current_time, fp: FeatureProcess):
"""
计算路口单位时间的通行效率(最近30秒通过的车辆数)
基于FeatureProcess的enter_lane_time和waiting_time_store判断车辆是否通过
"""
time_window = 30.0 # 统计最近30秒
vehicles_passed = 0
# 遍历所有曾进入当前路口进口道的车辆
for v_id, enter_time in fp.enter_lane_time.items():
# 条件1:车辆的目标路口是当前路口
# 条件2:进入时间在统计窗口内
# 条件3:车辆已离开进口道(不在等待状态,即不在waiting_time_store中)
if (fp.vehicle_prev_junction.get(v_id, -1) == junction_id
and (current_time - enter_time) <= time_window
and v_id not in fp.waiting_time_store):
vehicles_passed += 1
# 归一化:通行量越高奖励越高,映射到[0, 1]
# 转换为"车辆/分钟"并缩放(假设单车道最大通行量约为120辆/分钟)
throughput = vehicles_passed * 2 # 30秒→60秒的转换
return np.tanh(throughput / 120.0) # 120为最大参考值,可调整
# 辅助函数:判断车辆是否在进口车道上(复用原逻辑,确保与FeatureProcess一致)
def on_enter_lane(vehicle):
"""判断车辆是否处于进口车道(参考FeatureProcess中的逻辑)"""
return vehicle["lane"] != "" and vehicle["in_junction"] is False
检查此代码是否和以下特征处理匹配?
class FeatureProcess:
"""
Update traffic information and perform feature processing
"""
"""
更新交通信息并进行特征处理
"""
def __init__(self, logger):
self.logger = logger
self.reset()
def reset(self):
# Store road structure and other relatively fixed dictionary-type variables
# 存储道路结构等相对固定的字典型变量
self.junction_dict = {}
self.edge_dict = {}
self.lane_dict = {}
self.l_id_to_index = {}
self.vehicle_configs_dict = {}
# Store dictionary-type variables for dynamic traffic information in traffic scenarios,
# and update data after each communication
# 存储交通场景中动态交通信息的字典型变量, 在每次通信后进行数据更新
self.vehicle_prev_junction = {}
self.vehicle_prev_position = {}
self.vehicle_distance_store = {}
self.last_waiting_moment = {}
self.waiting_time_store = {}
self.enter_lane_time = {}
# Stores variables that can be used to calculate rewards
# 存储可用于计算奖励的变量
self.lane_volume = {}
# User-defined variable
# 用户自定义变量
self.old_waiting_time = {}
self.enter_lane_count = 26
def init_road_info(self, start_info):
"""
Obtain fixed variables such as road structure
"""
"""
获取道路结构等信息固定的变量
"""
junctions, signals, edges = (
start_info["junctions"],
start_info["signals"],
start_info["edges"],
)
lane_configs, vehicle_configs = (
start_info["lane_configs"],
start_info["vehicle_configs"],
)
# Store road structure information in various variables
# 将道路结构信息存储到各个变量
for junction in junctions:
self.junction_dict[junction["j_id"]] = junction
self.l_id_to_index[junction["j_id"]] = {}
self.old_waiting_time[junction["j_id"]] = 0
index = 0
for approaching_edges in junction["enter_lanes_on_directions"]:
for lane in approaching_edges["lanes"]:
self.l_id_to_index[junction["j_id"]][lane] = index
index += 1
for edge in edges:
self.edge_dict[edge["e_id"]] = edge
for lane in lane_configs:
self.lane_dict[lane["l_id"]] = lane
for vehicle_config in vehicle_configs:
self.vehicle_configs_dict[vehicle_config["v_config_id"]] = vehicle_config
for lane in lane_configs:
self.lane_volume[lane["l_id"]] = []
def update_traffic_info(self, obs, extra_info):
"""
Update vehicle history information and calculate various dynamic traffic variables
"""
"""
更新车辆历史信息, 计算各项动态交通变量
"""
frame_state = obs["framestate"]
frame_no = frame_state["frame_no"]
frame_time, vehicles = frame_state["frame_time"], frame_state["vehicles"]
if frame_no <= 1:
# Initial frame loads road structure information
# 初始帧载入道路结构信息
game_info = extra_info["gameinfo"]
self.init_road_info(game_info)
for vehicle in vehicles:
# If the vehicle appears for the first time, initialize the vehicle's historical intersection information
# 如果车辆第一次出现,则初始化车辆的历史交叉口信息
if vehicle["v_id"] not in self.vehicle_prev_junction:
self.vehicle_prev_junction[vehicle["v_id"]] = vehicle["junction"]
# For vehicles that appear for the first time, if they are on the lane, record their appearance time
# 对于首次出现的车辆, 若在车道上则记录其出现时间
if (
self.vehicle_prev_junction[vehicle["v_id"]] == -1
and on_enter_lane(vehicle)
and vehicle["v_id"] not in self.enter_lane_time
):
self.enter_lane_time[vehicle["v_id"]] = frame_time
# When a vehicle enters another entrance lane from the intersection, recalculate its appearance time
# 当车辆从交叉口驶入另一进口车道时, 重新统计其出现时间
elif self.vehicle_prev_junction[vehicle["v_id"]] != vehicle["junction"]:
if self.vehicle_prev_junction[vehicle["v_id"]] != -1 and on_enter_lane(vehicle):
self.enter_lane_time[vehicle["v_id"]] = frame_time
self.cal_waiting_time(frame_time, vehicle)
self.cal_travel_distance(vehicle)
self.cal_v_num_in_lane(vehicle)
def cal_waiting_time(self, frame_time, vehicle):
"""
Calculate the waiting time of the vehicle. When the vehicle is on the enter lane,
count the accumulated time when its speed is <= 0.1m/s as its waiting time when driving at the intersection
"""
"""
计算车辆等待时间, 当车辆处于进口车道上时, 统计其车速<=0.1m/s的累计时长作为其在该交叉口行驶时的等待时间
"""
waiting_time = 0
# Determine whether the vehicle in the lane approaching the intersection is in a waiting state,
# and calculate the waiting time
# 对处于车道驶向交叉口的车辆判断是否处于等待状态, 计算等待时间
if on_enter_lane(vehicle):
# Determine whether the vehicle is in a waiting state.
# The determination condition is that the vehicle speed is <= 0.1m/s
# 判断车辆是否处于等待状态, 判定条件为车辆速度<=0.1m/s
if vehicle["speed"] <= 0.1:
if vehicle["v_id"] not in self.last_waiting_moment:
# Record the starting moment of each time the vehicle enters the waiting state
# 记录车辆在每次进入等待状态的起始时刻
self.last_waiting_moment[vehicle["v_id"]] = frame_time
# When the vehicle is in the waiting state for the first time,
# initialize its accumulated waiting time
# 车辆首次处于等待状态则初始化车辆累计等待时间
if vehicle["v_id"] not in self.waiting_time_store:
self.waiting_time_store[vehicle["v_id"]] = 0
else:
# When a vehicle enters the waiting state on a lane,
# waiting_time records the duration of the current waiting state
# 车辆在一条道路上进入等待状态, waiting_time记录本次等待状态已持续的时间
waiting_time = frame_time - self.last_waiting_moment[vehicle["v_id"]]
self.waiting_time_store[vehicle["v_id"]] += waiting_time
self.last_waiting_moment[vehicle["v_id"]] = frame_time
else:
if vehicle["v_id"] in self.last_waiting_moment:
del self.last_waiting_moment[vehicle["v_id"]]
else:
# Prevent repeated del when the vehicle is generated for the first time or at an intersection,
# v_id is not stored in self.waiting_time_store
# 防止车辆首次生成或位于交叉口时反复del, v_id未储存在self.waiting_time_store内
if vehicle["v_id"] in self.waiting_time_store:
del self.waiting_time_store[vehicle["v_id"]]
if vehicle["v_id"] in self.last_waiting_moment:
del self.last_waiting_moment[vehicle["v_id"]]
def cal_travel_distance(self, vehicle):
"""
Calculate the travel distance. When the vehicle is on the enter lane,
count the total distance it travels at the intersection
"""
"""
计算旅行路程, 当车辆处于进口车道上时, 统计其在该交叉口行驶时的总路程
"""
# When the vehicle is on the lane, calculate the cumulative distance
# 当车辆处于车道上时, 计算累计路程
if on_enter_lane(vehicle):
# When the vehicle enters the lane from inside the intersection for the second or subsequent time,
# clear the cumulative distance and prepare to calculate the distance of this entry into the inlane
# 车辆非首次从交叉口内部驶入车道时, 清空累计路程, 准备计算此次进入进口车道的路程
if self.vehicle_prev_junction[vehicle["v_id"]] != -1 and vehicle["v_id"] in self.vehicle_distance_store:
del self.vehicle_distance_store[vehicle["v_id"]]
if vehicle["v_id"] not in self.vehicle_distance_store:
self.vehicle_distance_store[vehicle["v_id"]] = 0
self.vehicle_prev_position[vehicle["v_id"]] = [
vehicle["position_in_lane"]["x"],
vehicle["position_in_lane"]["y"],
]
else:
if vehicle["v_id"] in self.vehicle_distance_store and vehicle["v_id"] in self.vehicle_prev_position:
try:
# Calculate Euclidean distance
# 计算欧氏距离
self.vehicle_distance_store[vehicle["v_id"]] += math.sqrt(
math.pow(
vehicle["position_in_lane"]["x"] - self.vehicle_prev_position[vehicle["v_id"]][0],
2,
)
+ math.pow(
vehicle["position_in_lane"]["y"] - self.vehicle_prev_position[vehicle["v_id"]][1],
2,
)
)
except Exception:
raise ValueError
# Update the vehicle's historical position after each distance calculation
# 每次计算距离后更新车辆历史位置
self.vehicle_prev_position[vehicle["v_id"]] = [
vehicle["position_in_lane"]["x"],
vehicle["position_in_lane"]["y"],
]
else:
# When the vehicle enters the intersection,
# delete the historical location information to avoid calculating the driving distance
# based on the last departure position when entering the lane next time
# 当车辆驶入交叉口, 删除历史位置信息, 避免下次进入车道时按上一次离开路口位置计算行驶距离
if vehicle["v_id"] in self.vehicle_prev_position:
del self.vehicle_prev_position[vehicle["v_id"]]
def cal_v_num_in_lane(self, vehicle):
"""
Calculate the number of vehicles on the lane.
When a vehicle is in the import lane, the number of vehicles on the enter lane increases
"""
"""
计算车道上的车辆数, 当车辆处于进口车道上时, 则该进口车道上车辆数增加
"""
# Update the number of vehicles on each lane
# 更新每条车道上的车辆数量
if on_enter_lane(vehicle):
if vehicle["v_id"] not in self.lane_volume[vehicle["lane"]]:
self.lane_volume[vehicle["lane"]].append(vehicle["v_id"])
# Update the vehicle's historical intersection information
# 更新车辆的历史所在交叉口信息
self.vehicle_prev_junction[vehicle["v_id"]] = vehicle["junction"]
def get_all_junction_waiting_time(self, vehicles: list):
"""
This function obtain a dict of waiting_time by junction
Args:
- vehicles (list): input list of Vehicle
- vehicle_waiting_time (list): input key = v_id (uint32), value = vehicle_waiting_time (list)
Returns:
- dict: key = vehicle.junction (uint32), value = junction waiting time
"""
"""
此函数获取交叉口车辆等待时间的字典
参数:
- vehicles (list): 车辆的输入列表
- vehicle_waiting_time (list): input key = v_id (uint32), value = vehicle_waiting_time (list)
返回:
- dict: key = vehicle.junction (uint32), value = junction waiting time
"""
res = {}
v_num = {}
for junction_id in self.junction_dict:
res[junction_id] = 0
v_num[junction_id] = 0
for vehicle in vehicles:
if vehicle["junction"] != -1 or vehicle["target_junction"] == -1:
continue
if vehicle["v_id"] in self.waiting_time_store:
t = self.waiting_time_store[vehicle["v_id"]]
else:
t = 0
res[vehicle["target_junction"]] += t
v_num[vehicle["target_junction"]] += 1
# Calculate the average waiting time of all vehicles in the scene
# 计算场景内所有车辆的平均等待时间
for junction_id in self.junction_dict:
if v_num[junction_id] != 0:
res[junction_id] /= v_num[junction_id]
return res