class Algorithm:
    """Double-DQN learner for a Q-network with several action heads.

    The model is called on a batch of observations and indexed by head
    (``model(obs)[head_idx]``), so it must return one Q-value tensor per head.
    A deep copy of the model serves as the target network and is refreshed
    every ``target_update_freq`` learning steps.
    """

    def __init__(self, model, optimizer, device=None, logger=None, monitor=None):
        """Store collaborators and create the target network.

        Args:
            model: Online Q-network (``torch.nn.Module``).
            optimizer: Optimizer bound to ``model``'s parameters.
            device: Torch device for the training batch; CPU when omitted.
            logger: Optional logger; when ``None`` log calls are skipped.
            monitor: Optional object exposing ``put_data(dict)``.
        """
        self.device = device or torch.device("cpu")
        self.model = model
        self.optim = optimizer
        self.logger = logger
        self.monitor = monitor
        # Hyper-parameters mirrored from Config.
        self.num_head = Config.NUMB_HEAD
        self._gamma = Config.GAMMA
        self.target_update_freq = Config.TARGET_UPDATE_FREQ
        # Target network: frozen copy of the online network.
        self.target_model = deepcopy(self.model)
        self.target_model.eval()
        # Training state.
        self.train_step = 0
        self.last_report_time = time.time()

    def _log(self, level, message):
        """Forward ``message`` to the logger, if one was supplied."""
        if self.logger is not None:
            getattr(self.logger, level)(message)

    def _as_float_tensor(self, value):
        """Convert ``value`` to a float32 tensor on the training device."""
        return torch.as_tensor(value, dtype=torch.float32, device=self.device)

    def _build_batch(self, list_sample_data):
        """Stack per-sample fields into batch tensors.

        Returns:
            Tuple ``(obs, action, rew, next_obs, not_done, legal_masks)`` where
            ``legal_masks`` holds one ``(batch, action_dim)`` tensor per head,
            or ``None`` for a head whose mask is not available for every sample.
        """
        to_tensor = self._as_float_tensor
        batch_size = len(list_sample_data)

        obs = torch.stack([to_tensor(s.obs) for s in list_sample_data])
        action = torch.stack([to_tensor(s.act) for s in list_sample_data])
        rew = torch.stack([to_tensor(s.rew) for s in list_sample_data])
        next_obs = torch.stack([to_tensor(s._obs) for s in list_sample_data])
        # reshape(-1)[:1] accepts both a scalar flag and a length-1 array.
        done = torch.stack(
            [to_tensor(s.done).reshape(-1)[:1] for s in list_sample_data]
        ).squeeze(1)
        # The bootstrap term must vanish on terminal transitions, so the
        # multiplier is (1 - done), not the raw flag.
        # NOTE(review): assumes sample.done == 1 marks a terminal transition.
        not_done = 1.0 - done

        # Accept a scalar action/reward per sample as well as one value per head.
        if action.dim() == 1:
            action = action.unsqueeze(1)
        if rew.dim() == 1:
            rew = rew.unsqueeze(1)
        if rew.size(1) == 1 and self.num_head > 1:
            rew = rew.expand(-1, self.num_head)

        # Legal-action masks are kept per head because heads differ in width.
        # NOTE(review): the masks are applied to next-state Q-values; confirm
        # that sample.legal_action describes the next observation.
        rows_per_head = [[] for _ in range(self.num_head)]
        for sample in list_sample_data:
            legal = getattr(sample, "legal_action", None)
            if legal is None:
                continue
            for head_idx in range(min(self.num_head, len(legal))):
                rows_per_head[head_idx].append(to_tensor(legal[head_idx]))

        legal_masks = []
        for head_idx, rows in enumerate(rows_per_head):
            if len(rows) == batch_size:
                legal_masks.append(torch.stack(rows))
            else:
                # A partially available mask cannot be aligned with the batch.
                self._log(
                    "warning",
                    f"head{head_idx}的legal_action缺失或数量不一致,本次不做合法动作屏蔽",
                )
                legal_masks.append(None)

        return obs, action, rew, next_obs, not_done, legal_masks

    def _compute_targets(self, next_obs, rew, not_done, legal_masks):
        """Return Double-DQN targets of shape ``(batch, num_head)``.

        The online network picks the best legal next action per head and the
        target network supplies that action's value.
        """
        self.target_model.eval()
        self.model.eval()
        with torch.no_grad():
            # One forward pass per network; every head reads from the result.
            online_next = self.model(next_obs)
            target_next = self.target_model(next_obs)
            columns = []
            for head_idx in range(self.num_head):
                q_online = online_next[head_idx]
                mask = legal_masks[head_idx]
                if mask is not None:
                    # Push illegal actions far below any real Q-value.
                    q_online = q_online + (1 - mask) * (-1e10)
                best_actions = torch.argmax(q_online, dim=1, keepdim=True)
                q_best = target_next[head_idx].gather(1, best_actions)
                columns.append(
                    rew[:, head_idx].unsqueeze(1)
                    + self._gamma * q_best * not_done.unsqueeze(1)
                )
        return torch.cat(columns, dim=1)

    def learn(self, list_sample_data):
        """Run one gradient step on a batch of samples.

        Args:
            list_sample_data: Sequence of samples exposing ``obs``, ``act``,
                ``rew``, ``_obs``, ``done`` and (optionally) ``legal_action``.

        Returns:
            Dict with ``"loss"`` and ``"grad_norm"`` (both ``0.0`` when the
            batch is empty and nothing was learned).
        """
        if len(list_sample_data) == 0:
            self._log("warning", "样本为空,跳过学习")
            return {"loss": 0.0, "grad_norm": 0.0}

        obs, action, rew, next_obs, not_done, legal_masks = self._build_batch(
            list_sample_data
        )
        q_targets = self._compute_targets(next_obs, rew, not_done, legal_masks)

        # Q-values of the actions that were actually taken.
        self.model.train()
        head_outputs = self.model(obs)
        q_values = torch.cat(
            [
                head_outputs[head_idx].gather(
                    1, action[:, head_idx].long().unsqueeze(1)
                )
                for head_idx in range(self.num_head)
            ],
            dim=1,
        )

        # Sum of the per-head mean squared errors.
        loss = sum(
            F.mse_loss(q_values[:, head_idx], q_targets[:, head_idx])
            for head_idx in range(self.num_head)
        )

        self.optim.zero_grad()
        loss.backward()
        # Clip the global gradient norm to guard against exploding gradients.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), 1.0
        ).item()
        self.optim.step()

        # Periodic hard update of the target network.
        self.train_step += 1
        if self.train_step % self.target_update_freq == 0:
            self.target_model.load_state_dict(self.model.state_dict())
            self._log("info", f"训练步数{self.train_step}:更新目标网络")

        loss_value = loss.item()

        # Report to monitor/logger at most once every 30 seconds.
        now = time.time()
        if now - self.last_report_time >= 30:
            avg_q = q_values.mean().item()
            avg_target_q = q_targets.mean().item()
            monitor_data = {
                "value_loss": loss_value,
                "avg_q_value": avg_q,
                "avg_target_q_value": avg_target_q,
                "grad_norm": grad_norm,
                "train_step": self.train_step,
            }
            if self.monitor:
                self.monitor.put_data({os.getpid(): monitor_data})
            self._log(
                "info",
                f"损失: {loss_value:.4f}, 平均Q值: {avg_q:.4f}, "
                f"平均目标Q值: {avg_target_q:.4f}, 梯度范数: {grad_norm:.4f}",
            )
            self.last_report_time = now

        return {"loss": loss_value, "grad_norm": grad_norm}
class Agent(BaseAgent):
    """Traffic-signal agent: builds observations, picks actions, delegates learning.

    Four action heads are used, in this order: phase of junction 1, green
    duration of junction 1, phase of junction 2, green duration of junction 2.
    """

    def __init__(self, agent_type="player", device=None, logger=None, monitor=None):
        """Create the model, optimizer, preprocessor and learner."""
        super().__init__(agent_type, device, logger, monitor)
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = Model(device=self.device)
        # Optimizer and epsilon-greedy exploration schedule.
        self.optim = torch.optim.RMSprop(self.model.parameters(), lr=Config.LR)
        self._eps = Config.START_EPSILON_GREEDY
        self.end_eps = Config.END_EPSILON_GREEDY
        self.eps_decay = Config.EPSILON_DECAY
        # Width of each action head, in head order.
        self.head_dim = [
            Config.DIM_OF_ACTION_PHASE_1,
            Config.DIM_OF_ACTION_DURATION_1,
            Config.DIM_OF_ACTION_PHASE_2,
            Config.DIM_OF_ACTION_DURATION_2,
        ]
        self.num_head = Config.NUMB_HEAD
        # Preprocessor holding static road data and per-frame traffic state.
        self.preprocess = FeatureProcess(logger, usr_conf={})
        self.road_info_initialized = False
        self.last_action = None
        # Learner. The former placeholder import from a module named
        # "your_algorithm" was removed; Algorithm must be resolvable in this
        # module's namespace.
        self.algorithm = Algorithm(
            self.model, self.optim, self.device, self.logger, self.monitor
        )

    def reset(self):
        """Reset per-episode state.

        ``FeatureProcess.reset`` also discards the static road data, so the
        initialised flag is cleared to make the next observation reload it.
        """
        self.preprocess.reset()
        self.road_info_initialized = False
        # NOTE(review): epsilon restarts from its initial value every reset,
        # so decay does not carry over between resets; confirm this is intended.
        self._eps = Config.START_EPSILON_GREEDY

    def init_road_info(self, extra_info):
        """Load static road configuration from ``extra_info['gameinfo']`` once."""
        if self.road_info_initialized:
            return
        game_info = extra_info.get("gameinfo", {})
        self.preprocess.init_road_info(game_info)
        self.road_info_initialized = True
        self.logger.info("Agent道路信息初始化完成")

    # ------------------------------ prediction ------------------------------
    def _collect_legal_masks(self, list_obs_data):
        """Return one ``(batch, head_dim)`` float mask array per head.

        A row falls back to all ones (no restriction) when the observation
        carries no mask for that head, the mask has the wrong length, or it
        would forbid every action.
        """
        masks = []
        for head_idx, dim in enumerate(self.head_dim):
            rows = []
            for obs in list_obs_data:
                legal = getattr(obs, "legal_action", None)
                row = None
                if legal is not None and head_idx < len(legal):
                    row = np.asarray(legal[head_idx], dtype=np.float32).reshape(-1)
                if row is None or row.shape[0] != dim or not row.any():
                    row = np.ones(dim, dtype=np.float32)
                rows.append(row)
            if rows:
                masks.append(np.stack(rows))
            else:
                masks.append(np.ones((0, dim), dtype=np.float32))
        return masks

    def __predict_detail(self, list_obs_data, exploit_flag=False):
        """Choose one action per head for every observation.

        Args:
            list_obs_data: Observations exposing ``feature`` and, optionally,
                ``legal_action`` (one mask per head).
            exploit_flag: When true, always act greedily and leave epsilon
                untouched.

        Returns:
            List of ``ActData``, one per observation.
        """
        batch_size = len(list_obs_data)
        masks = self._collect_legal_masks(list_obs_data)
        self.model.eval()

        # Only training-time predictions consume the exploration schedule.
        if not exploit_flag:
            self._eps = max(self.end_eps, self._eps * self.eps_decay)

        per_head_actions = []
        if exploit_flag or np.random.rand() >= self._eps:
            # Greedy: highest Q-value among the legal actions of each head.
            feature = torch.stack(
                [
                    torch.tensor(obs.feature, dtype=torch.float32, device=self.device)
                    for obs in list_obs_data
                ]
            )
            with torch.no_grad():
                res = self.model(feature)
            for head_idx in range(len(self.head_dim)):
                q_head = res[head_idx]
                mask = torch.as_tensor(
                    masks[head_idx], dtype=q_head.dtype, device=q_head.device
                )
                if mask.shape == q_head.shape:
                    q_head = q_head + (1 - mask) * (-1e10)
                per_head_actions.append(torch.argmax(q_head, dim=1).cpu().tolist())
        else:
            # Exploration: uniform choice among the legal actions of each head.
            for head_idx in range(len(self.head_dim)):
                per_head_actions.append(
                    [
                        int(np.random.choice(np.flatnonzero(row)))
                        for row in masks[head_idx]
                    ]
                )

        list_phase_1, list_duration_1, list_phase_2, list_duration_2 = per_head_actions
        return [
            ActData(
                phase_index_1=list_phase_1[i],
                duration_1=list_duration_1[i],
                phase_index_2=list_phase_2[i],
                duration_2=list_duration_2[i],
            )
            for i in range(batch_size)
        ]

    @predict_wrapper
    def predict(self, list_obs_data):
        """Training-time prediction with epsilon-greedy exploration."""
        return self.__predict_detail(list_obs_data, exploit_flag=False)

    @exploit_wrapper
    def exploit(self, observation):
        """Inference: turn a raw observation into an environment action."""
        if not self.road_info_initialized:
            self.init_road_info(observation["extra_info"])
        obs_data = self.observation_process(observation["obs"], observation["extra_info"])
        if not obs_data:
            self.logger.warning("观测处理失败,返回默认动作")
            # Fallback: junction 1 and junction 2, phase 0, duration 1.
            return [[1, 0, 1], [2, 0, 1]]
        act_data = self.__predict_detail([obs_data], exploit_flag=True)[0]
        return self.action_process(act_data)

    # ------------------------------ observation ------------------------------
    @staticmethod
    def _current_phase(phases, signal_id):
        """Return the ``phase_id`` of the last entry for ``signal_id`` (0 if none)."""
        current = 0
        for phase in phases:
            if phase.get("s_id") == signal_id:
                current = phase.get("phase_id", 0)
        return current

    @staticmethod
    def _one_hot(index, size):
        """Return a one-hot list of ``size``; all zeros if ``index`` is out of range."""
        encoded = [0] * size
        if 0 <= index < size:
            encoded[index] = 1
        return encoded

    def observation_process(self, obs, extra_info):
        """Build the feature vector and the legal-action masks for one frame.

        The feature is the flattened occupancy grid, the flattened normalised
        speed grid, the one-hot current phase of both signals, the weather code
        and the peak-hour flag, padded or truncated to
        ``Config.DIM_OF_OBSERVATION``.

        Returns:
            ``ObsData`` with ``feature`` set and a ``legal_action`` attribute
            holding one mask list per head.
        """
        if not self.road_info_initialized:
            self.init_road_info(extra_info)

        self.preprocess.update_traffic_info(obs, extra_info)
        frame_state = obs.get("framestate", {})
        vehicles = frame_state.get("vehicles", [])
        phases = frame_state.get("phases", [])

        # Grids indexed by (lane code, cell along the lane).
        grid_w, grid_num = Config.GRID_WIDTH, Config.GRID_NUM
        speed_dict = np.zeros((grid_w, grid_num), dtype=np.float32)
        position_dict = np.zeros((grid_w, grid_num), dtype=np.int32)

        for vehicle in vehicles:
            if not self.preprocess.on_enter_lane(vehicle):
                continue  # only vehicles on entry lanes are rasterised
            try:
                x_pos = get_lane_code(vehicle)
                if x_pos < 0 or x_pos >= grid_w:
                    continue
                pos_in_lane = vehicle.get("position_in_lane", {})
                y_coord = pos_in_lane.get("y", 0)
                y_pos = int(y_coord // Config.GRID_LENGTH)
                if y_pos < 0 or y_pos >= grid_num:
                    continue
                v_config_id = vehicle.get("v_config_id", 1)
                max_speed = self.preprocess.get_vehicle_max_speed(v_config_id)
                vehicle_speed = vehicle.get("speed", 0)
                normalized_speed = vehicle_speed / max_speed if max_speed != 0 else 0.0
                speed_dict[x_pos, y_pos] = normalized_speed
                position_dict[x_pos, y_pos] = 1
            except Exception as e:
                # Best effort: one malformed vehicle must not drop the frame.
                v_id = vehicle.get("v_id", "未知")
                self.logger.warning(f"处理车辆{v_id}栅格化失败: {str(e)}")
                continue

        # Current phase of signal 0 (junction 1) and signal 1 (junction 2).
        current_phase1 = self._current_phase(phases, 0)
        current_phase2 = self._current_phase(phases, 1)
        jct_phase_1 = self._one_hot(current_phase1, Config.DIM_OF_ACTION_PHASE_1)
        jct_phase_2 = self._one_hot(current_phase2, Config.DIM_OF_ACTION_PHASE_2)

        legal_action = self._generate_legal_action(current_phase1, current_phase2)

        feature = (
            position_dict.flatten().tolist()
            + speed_dict.flatten().tolist()
            + jct_phase_1
            + jct_phase_2
            + [self.preprocess.get_weather()]
            + [1 if self.preprocess.is_peak_hour() else 0]
        )
        # Force the exact length the model expects: pad with zeros, or
        # truncate when the configured grid produces too many values.
        target_dim = Config.DIM_OF_OBSERVATION
        if len(feature) > target_dim:
            self.logger.warning(f"特征维度{len(feature)}超过{target_dim},已截断")
            feature = feature[:target_dim]
        else:
            feature = feature + [0] * (target_dim - len(feature))

        # The masks travel next to the feature, not inside it.
        obs_data = ObsData(feature=feature)
        obs_data.legal_action = legal_action
        return obs_data

    def _generate_legal_action(self, current_phase1, current_phase2):
        """Return the four per-head masks (1 = allowed, 0 = forbidden).

        Phase heads forbid re-selecting the current phase; duration heads
        allow every value.
        """
        legal_phase1 = [1] * Config.DIM_OF_ACTION_PHASE_1
        if 0 <= current_phase1 < len(legal_phase1):
            legal_phase1[current_phase1] = 0
        legal_duration1 = [1] * Config.DIM_OF_ACTION_DURATION_1
        legal_phase2 = [1] * Config.DIM_OF_ACTION_PHASE_2
        if 0 <= current_phase2 < len(legal_phase2):
            legal_phase2[current_phase2] = 0
        legal_duration2 = [1] * Config.DIM_OF_ACTION_DURATION_2
        return [legal_phase1, legal_duration1, legal_phase2, legal_duration2]

    def _junction_id_for_signal(self, signal_id, position, default):
        """Resolve the junction id that an action head controls.

        Phases are read by signal id, so the junction is looked up by signal
        first; the sorted junction list and ``default`` are fallbacks.
        """
        junction = self.preprocess.get_junction_by_signal(signal_id)
        if junction is not None and junction.get("j_id") is not None:
            return junction["j_id"]
        junction_ids = self.preprocess.get_sorted_junction_ids()
        if len(junction_ids) > position:
            return junction_ids[position]
        return default

    def action_process(self, act_data):
        """Convert ``ActData`` to ``[[j_id, phase, duration], [j_id, phase, duration]]``."""
        # Duration actions are 0-based; the emitted value starts at 1 and is
        # clamped to Config.MAX_GREEN_DURATION.
        duration1 = max(1, min(act_data.duration_1 + 1, Config.MAX_GREEN_DURATION))
        duration2 = max(1, min(act_data.duration_2 + 1, Config.MAX_GREEN_DURATION))
        j1_id = self._junction_id_for_signal(0, 0, 1)
        j2_id = self._junction_id_for_signal(1, 1, 2)
        return [
            [j1_id, act_data.phase_index_1, duration1],
            [j2_id, act_data.phase_index_2, duration2],
        ]

    # ------------------------------ learning and checkpoints ------------------------------
    @learn_wrapper
    def learn(self, list_sample_data):
        """Delegate one learning step to the algorithm."""
        return self.algorithm.learn(list_sample_data)

    @save_model_wrapper
    def save_model(self, path=None, id="1"):
        """Save the model weights (moved to CPU) under ``path``."""
        model_path = f"{path}/model.ckpt-{id}.pkl"
        state_dict = {k: v.clone().cpu() for k, v in self.model.state_dict().items()}
        torch.save(state_dict, model_path)
        self.logger.info(f"保存模型到{model_path}")

    @load_model_wrapper
    def load_model(self, path=None, id="1"):
        """Load model weights from ``path`` and mirror them into the target network."""
        model_path = f"{path}/model.ckpt-{id}.pkl"
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # Keep the learner's target network in step with the loaded weights.
        target_model = getattr(getattr(self, "algorithm", None), "target_model", None)
        if target_model is not None:
            target_model.load_state_dict(state_dict)
        self.logger.info(f"从{model_path}加载模型")
class FeatureProcess:
    """Holds static road configuration and per-frame traffic statistics.

    ``init_road_info`` loads junctions, lanes, signals and vehicle types;
    ``update_traffic_info`` refreshes phases, lane volumes and per-vehicle
    waiting time and travelled distance from one frame.
    """

    # (bit, turn name, turn code) used for both green_mask and dir_mask.
    # NOTE(review): assumes both masks share this bit layout — confirm
    # against the protocol definition.
    _TURN_TABLE = ((1, "straight", 0), (2, "left", 1), (4, "right", 2), (8, "uturn", 3))

    def __init__(self, logger, usr_conf=None):
        """Store the logger and user configuration, then clear all state.

        A standard-library logger is used when ``logger`` is ``None`` so the
        logging calls in this class never fail on a missing logger.
        """
        if logger is None:
            import logging

            logger = logging.getLogger(__name__)
        self.logger = logger
        self.usr_conf = usr_conf or {}
        self.reset()

    def reset(self):
        """Clear every static and dynamic field, including the road data."""
        # 1. Static road information.
        self.junction_dict = {}  # j_id -> junction data
        self.edge_dict = {}
        self.lane_dict = {}
        self.vehicle_configs = {}
        self.l_id_to_index = {}
        self.lane_to_junction = {}
        self.phase_lane_mapping = {}
        # 2. Per-frame information.
        self.vehicle_status = {}
        self.lane_volume = {}  # lane id -> v_count
        self.lane_congestion = {}
        self.current_phases = {}  # s_id -> {remaining_duration, phase_id}
        self.lane_demand = {}
        # 3. Per-vehicle history.
        self.vehicle_prev_junction = {}
        self.vehicle_prev_position = {}
        self.vehicle_distance_store = {}
        self.last_waiting_moment = {}
        self.waiting_time_store = {}
        self.enter_lane_time = {}
        self.vehicle_enter_time = {}
        self.vehicle_ideal_time = {}
        self.current_vehicles = {}  # v_id -> vehicle dict of the current frame
        self.vehicle_trajectory = {}
        # 4. Scenario state.
        self.peak_hour = False
        self.weather = 0
        self.accident_lanes = set()
        self.control_lanes = set()
        self.accident_configs = []
        self.control_configs = []
        # 5. Junction-level accumulators.
        self.junction_metrics = {}  # j_id -> metric dict
        self.enter_lane_ids = set()  # union of all junctions' entry lanes
        # 6. Regions.
        self.region_dict = {}  # region_id -> [j_id, ...]
        self.region_capacity = {}  # region_id -> capacity
        self.junction_region_map = {}  # j_id -> region_id

    def init_road_info(self, start_info):
        """Load static configuration from ``start_info``.

        Lane configuration is loaded before the phase mapping because the
        phase mapping matches lanes by the turn types stored in ``lane_dict``.
        """
        junctions = start_info.get("junctions", [])
        signals = start_info.get("signals", [])
        lane_configs = start_info.get("lane_configs", [])
        vehicle_configs = start_info.get("vehicle_configs", [])

        self._init_junction_mapping(junctions)
        self._init_lane_configs(lane_configs)
        self._init_phase_mapping(signals)
        self._init_vehicle_configs(vehicle_configs)
        self._init_protocol_regions()
        self._init_scene_config()
        self.logger.info(f"道路信息初始化完成:路口{len(self.junction_dict)}个,车辆配置{len(self.vehicle_configs)}个")

    @staticmethod
    def _new_junction_metrics():
        """Return a fresh metric accumulator for one junction."""
        return {
            "total_delay": 0.0, "total_vehicles": 0, "total_waiting": 0.0,
            "total_queue": 0, "queue_count": 0,
            "counted_vehicles": set(), "completed_vehicles": set(),
        }

    def _init_junction_mapping(self, junctions):
        """Fill ``junction_dict``, ``junction_metrics`` and ``enter_lane_ids``.

        An empty list yields two placeholder junctions (``j1``/``j2`` bound to
        signals 0/1). Junctions without ``j_id`` get a generated one. Input
        dicts are shallow-copied so the caller's data is not modified.
        """
        if not junctions:
            self.logger.warning("路口列表为空,初始化默认路口(j1、j2)")
            self.junction_dict = {
                "j1": {"j_id": "j1", "signal": 0, "cached_enter_lanes": [], "enter_lanes_on_directions": []},
                "j2": {"j_id": "j2", "signal": 1, "cached_enter_lanes": [], "enter_lanes_on_directions": []},
            }
            for j_id in self.junction_dict:
                self.junction_metrics[j_id] = self._new_junction_metrics()
            return

        for idx, raw_junction in enumerate(junctions):
            junction = dict(raw_junction)
            if "j_id" not in junction:
                if "id" in junction:
                    junction["j_id"] = f"j_{junction['id']}"
                else:
                    junction["j_id"] = f"j_{idx}"
                self.logger.debug(f"为路口{idx}生成默认j_id: {junction['j_id']}")
            j_id = junction["j_id"]

            # Flatten the per-direction entry lanes into one cached list.
            all_enter_lanes = []
            for dir_info in junction.get("enter_lanes_on_directions", []):
                all_enter_lanes.extend(dir_info.get("lanes", []))
            junction["cached_enter_lanes"] = all_enter_lanes
            self.enter_lane_ids.update(all_enter_lanes)

            self.junction_metrics[j_id] = self._new_junction_metrics()
            self.junction_dict[j_id] = junction
            self.logger.debug(f"加载路口{j_id}:进口车道{len(all_enter_lanes)}条,关联信号灯{junction.get('signal', -1)}")

    def _init_phase_mapping(self, signals):
        """Build ``phase_lane_mapping``: s_id -> phase index -> entry lanes.

        Requires ``junction_dict`` and ``lane_dict`` to be populated.
        """
        self.phase_lane_mapping = {}
        signal_junction_map = {}
        for j_id, junction in self.junction_dict.items():
            s_id = junction.get("signal", -1)
            if s_id != -1:
                signal_junction_map[s_id] = j_id

        for signal in signals:
            s_id = signal.get("s_id", -1)
            if s_id == -1:
                self.logger.warning("信号灯缺失s_id,跳过")
                continue
            j_id = signal_junction_map.get(s_id)
            # Compare with None: a junction id of 0 is valid.
            if j_id is None or j_id not in self.junction_dict:
                self.logger.warning(f"信号灯{s_id}未关联有效路口,跳过")
                continue
            junction = self.junction_dict[j_id]
            enter_lanes = set(junction["cached_enter_lanes"])
            self.phase_lane_mapping[s_id] = {}
            for phase_idx, phase in enumerate(signal.get("phases", [])):
                controlled_lanes = []
                for light_cfg in phase.get("lights_on_configs", []):
                    green_mask = light_cfg.get("green_mask", 0)
                    for turn in self._mask_to_turns(green_mask):
                        controlled_lanes.extend(self._get_turn_lanes(junction, turn))
                # De-duplicate in first-seen order and keep entry lanes only.
                valid_lanes = [
                    lane_id
                    for lane_id in dict.fromkeys(controlled_lanes)
                    if lane_id in enter_lanes
                ]
                self.phase_lane_mapping[s_id][phase_idx] = valid_lanes
                self.logger.debug(f"信号灯{s_id}相位{phase_idx}:控制车道{valid_lanes}")

    def _init_lane_configs(self, lane_configs):
        """Fill ``lane_dict`` from the lane configuration list.

        ``turn_type`` keeps a single code (left, then right, then u-turn take
        precedence; 0/straight otherwise). ``turn_types`` lists every code
        whose bit is set in ``dir_mask`` so a shared lane matches each turn.
        """
        self.lane_dict = {}
        for lane in lane_configs:
            l_id = lane.get("l_id", -1)
            if l_id == -1:
                self.logger.warning("车道缺失l_id,跳过")
                continue
            dir_mask = lane.get("dir_mask", 0)
            turn_type = 0
            for bit, _name, code in self._TURN_TABLE[1:]:
                if dir_mask & bit:
                    turn_type = code
                    break
            turn_types = tuple(
                code for bit, _name, code in self._TURN_TABLE if dir_mask & bit
            ) or (0,)
            self.lane_dict[l_id] = {
                "l_id": l_id, "edge_id": lane.get("edge_id", 0),
                "length": lane.get("length", 100), "width": lane.get("width", 3),
                "turn_type": turn_type, "turn_types": turn_types,
            }

    def _init_vehicle_configs(self, vehicle_configs):
        """Fill ``vehicle_configs``, completing missing fields and ids 1-5 with defaults."""
        self.vehicle_configs = {}
        default_configs = {
            1: {"v_type": 1, "v_type_name": "CAR", "max_speed": 60, "length": 5},
            2: {"v_type": 2, "v_type_name": "BUS", "max_speed": 40, "length": 12},
            3: {"v_type": 3, "v_type_name": "TRUCK", "max_speed": 50, "length": 10},
            4: {"v_type": 4, "v_type_name": "MOTORCYCLE", "max_speed": 55, "length": 2},
            5: {"v_type": 5, "v_type_name": "BICYCLE", "max_speed": 15, "length": 1},
        }
        # Last-resort values for ids that have no built-in default.
        generic = {"v_type": 0, "v_type_name": "Unknown", "max_speed": 60, "length": 5}

        for cfg in vehicle_configs:
            cfg_id = cfg.get("v_config_id")
            if cfg_id is None:
                self.logger.warning("车辆配置缺失v_config_id,跳过")
                continue
            fallback = default_configs.get(cfg_id, {})
            self.vehicle_configs[cfg_id] = {
                key: cfg.get(key, fallback.get(key, generic[key])) for key in generic
            }

        for cfg_id, cfg in default_configs.items():
            if cfg_id not in self.vehicle_configs:
                self.vehicle_configs[cfg_id] = dict(cfg)
                self.logger.debug(f"补充默认车辆配置:v_config_id={cfg_id}")

    def _init_protocol_regions(self):
        """Create one region per junction with capacity from its entry lanes."""
        self.region_dict = {}
        self.region_capacity = {}
        self.junction_region_map = {}
        for j_id, junction in self.junction_dict.items():
            region_id = f"region_{j_id}"
            self.region_dict[region_id] = [j_id]
            self.junction_region_map[j_id] = region_id
            total_cap = sum(
                self.get_lane_capacity(lane_id)
                for lane_id in junction["cached_enter_lanes"]
            )
            # Junctions without entry lanes get a nominal capacity of 20.
            self.region_capacity[region_id] = total_cap if total_cap > 0 else 20

    def _init_scene_config(self):
        """Read weather, rush-hour flag and accident/control windows from ``usr_conf``."""
        self.weather = self.usr_conf.get("weather", 0)
        self.peak_hour = self.usr_conf.get("rush_hour", 0) == 1
        required = ("lane_index", "start_time", "end_time")
        accidents = (self.usr_conf.get("traffic_accidents") or {}).get("custom_configuration", [])
        controls = (self.usr_conf.get("traffic_control") or {}).get("custom_configuration", [])
        # Entries missing any required key are ignored.
        self.accident_configs = [cfg for cfg in accidents if all(k in cfg for k in required)]
        self.control_configs = [cfg for cfg in controls if all(k in cfg for k in required)]

    def _mask_to_turns(self, green_mask):
        """Return the turn names whose bit is set in ``green_mask``."""
        return [name for bit, name, _code in self._TURN_TABLE if green_mask & bit]

    def _get_turn_lanes(self, junction, turn):
        """Return the junction's entry lanes that allow ``turn``.

        Lanes absent from ``lane_dict`` never match; an unknown turn name
        yields an empty list.
        """
        turn_map = {name: code for _bit, name, code in self._TURN_TABLE}
        target_turn = turn_map.get(turn)
        if target_turn is None:
            return []
        matched = []
        for lane_id in junction["cached_enter_lanes"]:
            lane_info = self.lane_dict.get(lane_id, {})
            turn_types = lane_info.get("turn_types")
            if turn_types is None:
                turn_types = (lane_info.get("turn_type"),)
            if target_turn in turn_types:
                matched.append(lane_id)
        return matched

    # ------------------------------ per-frame update ------------------------------
    def update_traffic_info(self, obs, extra_info):
        """Refresh all per-frame state from ``obs['framestate']``.

        Travelled distance is tracked only by ``cal_travel_distance``; running
        a second position tracker here would overwrite its bookkeeping.
        """
        if "framestate" not in obs:
            self.logger.error("观测数据缺失framestate,跳过更新")
            return
        frame_state = obs["framestate"]
        frame_no = frame_state.get("frame_no", 0)
        frame_time = frame_state.get("frame_time", 0) / 1000.0  # ms -> s

        # 1. Signal phases.
        self.current_phases.clear()
        for phase_info in frame_state.get("phases", []):
            s_id = phase_info.get("s_id", -1)
            if s_id == -1:
                continue
            self.current_phases[s_id] = {
                "remaining_duration": phase_info.get("remaining_duration", 0),
                "phase_id": phase_info.get("phase_id", 0),
            }

        # 2. Lane volumes.
        self.lane_volume.clear()
        for lane in frame_state.get("lanes", []):
            l_id = lane.get("lane_id", -1)
            if l_id == -1:
                continue
            self.lane_volume[l_id] = lane.get("v_count", 0)

        # 3. Vehicles of this frame. Compare with None: id 0 is valid.
        vehicles = [v for v in frame_state.get("vehicles", []) if v.get("v_id") is not None]
        self.current_vehicles.clear()
        self.vehicle_status.clear()
        for vehicle in vehicles:
            v_id = vehicle["v_id"]
            self.current_vehicles[v_id] = vehicle
            self.vehicle_status[v_id] = vehicle.get("v_status", 0)

        # 4. Forget vehicles that left the frame.
        self._clean_expired_vehicle_data(set(self.current_vehicles))

        # 5. Waiting time and travelled distance.
        for vehicle in vehicles:
            if "position_in_lane" in vehicle:
                self.cal_waiting_time(frame_time, vehicle)
                self.cal_travel_distance(vehicle)

        # 6. Active accident / control lanes.
        self._update_active_invalid_lanes(frame_no)
        # 7. Lane demand.
        self.calculate_lane_demand()

    def _update_vehicle_position(self, vehicle, frame_time):
        """Add the step from the stored position to the vehicle's distance.

        Not called by ``update_traffic_info``; kept for direct callers. The
        stored record uses the same shape as ``cal_travel_distance``.
        ``frame_time`` is accepted for signature compatibility and unused.
        """
        v_id = vehicle["v_id"]
        current_pos = vehicle.get("position_in_lane") or {}
        if "x" not in current_pos or "y" not in current_pos:
            return
        prev_pos = self.vehicle_prev_position.get(v_id)
        if prev_pos is not None:
            step = math.hypot(current_pos["x"] - prev_pos["x"], current_pos["y"] - prev_pos["y"])
            self.vehicle_distance_store[v_id] = self.vehicle_distance_store.get(v_id, 0.0) + step
        self.vehicle_prev_position[v_id] = {
            "x": current_pos["x"], "y": current_pos["y"],
            "distance_to_stop": current_pos["y"],
        }

    def _clean_expired_vehicle_data(self, current_v_ids):
        """Drop history of every vehicle that is not in ``current_v_ids``."""
        stores = (
            self.vehicle_prev_position,
            self.vehicle_distance_store,
            self.waiting_time_store,
            self.last_waiting_moment,
        )
        tracked = set()
        for store in stores:
            tracked.update(store)
        for v_id in tracked - current_v_ids:
            for store in stores:
                store.pop(v_id, None)

    def _update_active_invalid_lanes(self, frame_no):
        """Recompute the lanes whose accident/control window contains ``frame_no``."""
        # NOTE(review): the windows are named *_time but compared with the
        # frame number — confirm both use the same unit.
        self.accident_lanes = {
            cfg["lane_index"]
            for cfg in self.accident_configs
            if cfg["start_time"] <= frame_no <= cfg["end_time"]
        }
        self.control_lanes = {
            cfg["lane_index"]
            for cfg in self.control_configs
            if cfg["start_time"] <= frame_no <= cfg["end_time"]
        }

    # ------------------------------ metrics ------------------------------
    def calculate_lane_demand(self):
        """Set ``lane_demand`` to v_count plus 0.5 per vehicle with y < 100."""
        # One pass over the vehicles instead of one pass per lane.
        near_stop = {}
        for vehicle in self.current_vehicles.values():
            y_pos = vehicle.get("position_in_lane", {}).get("y", 200)
            if y_pos < 100:
                lane_id = vehicle.get("lane")
                near_stop[lane_id] = near_stop.get(lane_id, 0) + 1
        self.lane_demand = {}
        for lane_id, count in self.lane_volume.items():
            extra = near_stop.get(lane_id, 0)
            self.lane_demand[lane_id] = count + 0.5 * extra if extra else count

    def calculate_junction_queue(self, j_id):
        """Count normal-status vehicles with speed <= 1.0 on the junction's valid entry lanes."""
        junction = self.junction_dict.get(j_id)
        if not junction:
            return 0
        valid_lanes = set(junction["cached_enter_lanes"]) - self.get_invalid_lanes()
        return sum(
            1
            for vehicle in self.current_vehicles.values()
            if vehicle.get("lane") in valid_lanes
            and self.vehicle_status.get(vehicle["v_id"], 0) == 0
            and vehicle.get("speed", 0) <= 1.0
        )

    def cal_waiting_time(self, frame_time, vehicle):
        """Accumulate waiting time while the vehicle is stopped near the stop line.

        A vehicle waits when it is on an entry lane, its speed is <= 0.1 and
        its lane ``y`` is < 50. Leaving the entry lanes clears its record.
        """
        v_id = vehicle["v_id"]
        if not self.on_enter_lane(vehicle):
            self.waiting_time_store.pop(v_id, None)
            self.last_waiting_moment.pop(v_id, None)
            return
        y_pos = (vehicle.get("position_in_lane") or {}).get("y")
        is_waiting = vehicle.get("speed", 0) <= 0.1 and y_pos is not None and y_pos < 50
        if not is_waiting:
            self.last_waiting_moment.pop(v_id, None)
            return
        if v_id in self.last_waiting_moment:
            duration = frame_time - self.last_waiting_moment[v_id]
            self.waiting_time_store[v_id] = self.waiting_time_store.get(v_id, 0.0) + duration
        self.last_waiting_moment[v_id] = frame_time

    def cal_travel_distance(self, vehicle):
        """Accumulate the distance a vehicle travels on entry lanes.

        Each frame adds the larger of the Euclidean step and the reduction of
        ``y`` since the previous frame, then stores the new position so the
        next step is measured from it.
        """
        v_id = vehicle["v_id"]
        if not self.on_enter_lane(vehicle):
            self.vehicle_prev_position.pop(v_id, None)
            self.vehicle_distance_store.pop(v_id, None)
            return
        current_pos = vehicle.get("position_in_lane") or {}
        if "x" not in current_pos or "y" not in current_pos:
            return

        prev_pos = self.vehicle_prev_position.get(v_id)
        if prev_pos is None:
            # First sighting: remember the position, nothing travelled yet.
            self.vehicle_prev_position[v_id] = {
                "x": current_pos["x"], "y": current_pos["y"],
                "distance_to_stop": current_pos["y"],
            }
            self.vehicle_distance_store[v_id] = 0.0
            return

        try:
            euclid_dist = math.hypot(current_pos["x"] - prev_pos["x"], current_pos["y"] - prev_pos["y"])
            stop_dist_reduce = prev_pos.get("distance_to_stop", prev_pos["y"]) - current_pos["y"]
            self.vehicle_distance_store[v_id] = (
                self.vehicle_distance_store.get(v_id, 0.0) + max(euclid_dist, stop_dist_reduce)
            )
            prev_pos["x"] = current_pos["x"]
            prev_pos["y"] = current_pos["y"]
            prev_pos["distance_to_stop"] = current_pos["y"]
        except (KeyError, TypeError) as e:
            self.logger.error(f"计算车辆{v_id}距离失败: {str(e)}")

    def on_enter_lane(self, vehicle):
        """Return True when the vehicle's ``lane`` is an entry lane of any junction."""
        lane_id = vehicle.get("lane")
        if lane_id is None:
            return False
        if self.enter_lane_ids:
            return lane_id in self.enter_lane_ids
        # Fallback for junction data that did not go through
        # _init_junction_mapping and so never filled enter_lane_ids.
        return any(
            lane_id in junction["cached_enter_lanes"]
            for junction in self.junction_dict.values()
        )

    # ------------------------------ public accessors ------------------------------
    def get_sorted_junction_ids(self):
        """Return the junction ids in sorted order."""
        return sorted(self.junction_dict.keys())

    def get_junction_by_signal(self, s_id):
        """Return the junction whose ``signal`` equals ``s_id``, or ``None``."""
        for junction in self.junction_dict.values():
            if junction.get("signal") == s_id:
                return junction
        return None

    def get_invalid_lanes(self):
        """Return the union of the active accident and control lanes."""
        return self.accident_lanes.union(self.control_lanes)

    def get_lane_capacity(self, lane_id):
        """Return 15 for straight/left lanes (and unknown lanes), 10 otherwise."""
        turn_type = self.lane_dict.get(lane_id, {}).get("turn_type", 0)
        return 15 if turn_type in [0, 1] else 10

    def get_weather(self):
        """Return the configured weather code."""
        return self.weather

    def is_peak_hour(self):
        """Return the configured rush-hour flag."""
        return self.peak_hour

    def get_vehicle_max_speed(self, v_config_id):
        """Return the ``max_speed`` of a vehicle config; 60 when it is unknown."""
        return self.vehicle_configs.get(v_config_id, {}).get("max_speed", 60)
# (page residue, not part of the program: "最新发布" / "Latest posts")