Bus Stop(贪心)

博客围绕泰国乡村一条道路上公交站设置问题展开。政府要让每个房子距公交站不超10公里,需确定最少公交站数量。给出了题目描述、输入输出要求及样例输入等信息,属于算法相关问题。

题目链接 

[提交] [状态] [命题人:admin]

题目描述

In a rural village in Thailand, there is a long, straight, road with houses scattered along it.  
(We can picture the road as a long line segment, with an eastern endpoint and a western endpoint.)  The Thai government is planning to set up bus stops on this road in such a way that every house is within ten kilometers of one of the stops.  What is the minimum number of stops the government need to set up to satisfy the requirement? 
 

 

输入

The first line contains a single integer m representing the number of the test cases.  Eachtest case consists of the following two lines:   
The first line contains a single integer n representing the number of houses on the road, where 0 ≤ n ≤ 2,000,000. 
The second line contains n integers h1  h2  … hn  representing the locations of the houses in kilometers as measured from the start of the road, where 0 ≤ hi  ≤ 90,000,000.  That is, the first house is h1  kilometers away from the start of the road, the second house is h2  kilometers 
away from the start of the road, and so on.  You may assume that hi ’s are sorted in ascending order, e.g., h1  ≤ h2  ≤ h3  ≤ … ≤ hn . 

 

输出

For each test case, print out in a single line the minimum number of bus stops that satisfy the requirement. 

 

样例输入

复制样例数据

2
5
1 2 3 200 210
4
10 30 80 200

样例输出

2
3
#include<iostream>
#include<map>
using namespace std;
typedef long long ll;
map<ll,ll>mapp;
ll a[2000010];
int main()
{
    ll t;
    cin>>t;
    while(t--)
    {
        ll n;
        cin>>n;
        if(n==0)
            cout<<"0"<<endl;
        else
        {
            ll p;
            ll s;
            cin>>s;
            ll sum=1;
            for(ll i=1; i<n; i++)
            {
                cin>>p;
                if(p>s+20)
                {
                    sum++;
                    s=p;
                }
                else
                {
                    continue;
                }
            }
            cout<<sum<<endl;
        }
    }
 
 
    return 0;
}

 

class Algorithm: def __init__(self, model, optimizer, device=None, logger=None, monitor=None): self.device = device or torch.device("cpu") self.model = model self.optim = optimizer self.logger = logger self.monitor = monitor # 算法参数(与Config一致) self.num_head = Config.NUMB_HEAD # 4 self._gamma = Config.GAMMA # 0.99 self.target_update_freq = Config.TARGET_UPDATE_FREQ # 500步 # 目标网络 self.target_model = deepcopy(self.model) self.target_model.eval() # 训练状态 self.train_step = 0 self.last_report_time = time.time() def learn(self, list_sample_data): """学习逻辑:按head分别处理legal_action,避免堆叠""" if len(list_sample_data) == 0: self.logger.warning("样本为空,跳过学习") return {"loss": 0.0} # ------------------------------ 1. 提取样本数据(按head处理legal_action) ------------------------------ # 提取非legal_action数据(可直接stack) obs_list = [] action_list = [] rew_list = [] next_obs_list = [] not_done_list = [] # 提取legal_action(按head存储:head_idx -> [样本1的掩码, 样本2的掩码, ...]) legal_action_per_head = [[] for _ in range(self.num_head)] for sample in list_sample_data: # 基础数据 obs_list.append(torch.tensor(sample.obs, dtype=torch.float32, device=self.device)) action_list.append(torch.tensor(sample.act, dtype=torch.float32, device=self.device)) rew_list.append(torch.tensor(sample.rew, dtype=torch.float32, device=self.device)) next_obs_list.append(torch.tensor(sample._obs, dtype=torch.float32, device=self.device)) not_done_list.append(torch.tensor([sample.done], dtype=torch.float32, device=self.device)) # 处理legal_action(样本的legal_action是4个head的列表) legal = sample.legal_action for head_idx in range(self.num_head): if head_idx < len(legal): # 转换为张量并添加到对应head的列表 legal_tensor = torch.tensor(legal[head_idx], dtype=torch.float32, device=self.device) legal_action_per_head[head_idx].append(legal_tensor) # 堆叠基础数据 obs = torch.stack(obs_list) # (batch, obs_dim) action = torch.stack(action_list) # (batch, num_head) rew = torch.stack(rew_list) # (batch, num_head) next_obs = torch.stack(next_obs_list) # (batch, obs_dim) not_done = torch.stack(not_done_list).squeeze(1) # (batch,) # 堆叠legal_action(每个head单独堆叠:(batch, action_dim_per_head)) for head_idx in range(self.num_head): legal_action_per_head[head_idx] = torch.stack(legal_action_per_head[head_idx]) # ------------------------------ 2. 计算目标Q值(DDQN逻辑) ------------------------------ self.target_model.eval() self.model.eval() q_targets = [] with torch.no_grad(): for head_idx in range(self.num_head): # 2.1 当前网络选择下一状态的最佳动作(考虑合法动作) current_q_next = self.model(next_obs)[head_idx] # (batch, action_dim_per_head) legal_mask = legal_action_per_head[head_idx] # (batch, action_dim_per_head) # 非法动作Q值设为-1e10(不被选中) current_q_next_masked = current_q_next + (1 - legal_mask) * (-1e10) best_actions = torch.argmax(current_q_next_masked, dim=1, keepdim=True) # (batch, 1) # 2.2 目标网络评估最佳动作的价值 target_q_next = self.target_model(next_obs)[head_idx] # (batch, action_dim_per_head) target_q_best = target_q_next.gather(1, best_actions) # (batch, 1) # 2.3 计算目标Q值:rew + gamma * target_q_best * not_done rew_head = rew[:, head_idx].unsqueeze(1) # (batch, 1) q_target_head = rew_head + self._gamma * target_q_best * not_done.unsqueeze(1) q_targets.append(q_target_head) # 拼接所有head的目标Q值:(batch, num_head) q_targets = torch.cat(q_targets, dim=1) # ------------------------------ 3. 计算当前Q值与损失 ------------------------------ self.model.train() q_values = [] for head_idx in range(self.num_head): # 提取当前动作对应的Q值 current_q = self.model(obs)[head_idx] # (batch, action_dim_per_head) # 动作索引:(batch, 1)(确保是long类型) action_idx = action[:, head_idx].long().unsqueeze(1) # 提取当前动作的Q值 q_value_head = current_q.gather(1, action_idx) # (batch, 1) q_values.append(q_value_head) # 拼接当前Q值:(batch, num_head) q_values = torch.cat(q_values, dim=1) # 计算MSE损失(按head累加) loss = 0.0 for head_idx in range(self.num_head): loss += F.mse_loss(q_values[:, head_idx], q_targets[:, head_idx]) # ------------------------------ 4. 优化与目标网络更新 ------------------------------ self.optim.zero_grad() loss.backward() # 梯度裁剪(防止梯度爆炸) grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0).item() self.optim.step() # 定期更新目标网络 self.train_step += 1 if self.train_step % self.target_update_freq == 0: self.target_model.load_state_dict(self.model.state_dict()) self.logger.info(f"训练步数{self.train_step}:更新目标网络") # ------------------------------ 5. 日志与监控 ------------------------------ # 定期上报监控数据(每30秒) now = time.time() if now - self.last_report_time >= 30: avg_q = q_values.mean().item() avg_target_q = q_targets.mean().item() monitor_data = { "value_loss": loss.item(), "avg_q_value": avg_q, "avg_target_q_value": avg_target_q, "grad_norm": grad_norm, "train_step": self.train_step } if self.monitor: self.monitor.put_data({os.getpid(): monitor_data}) self.logger.info( f"损失: {loss.item():.4f}, 平均Q值: {avg_q:.4f}, " f"平均目标Q值: {avg_target_q:.4f}, 梯度范数: {grad_norm:.4f}" ) self.last_report_time = now return {"loss": loss.item(), "grad_norm": grad_norm} class Agent(BaseAgent): def __init__(self, agent_type="player", device=None, logger=None, monitor=None): super().__init__(agent_type, device, logger, monitor) self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = Model(device=self.device) # 优化器与探索参数 self.optim = torch.optim.RMSprop(self.model.parameters(), lr=Config.LR) self._eps = Config.START_EPSILON_GREEDY self.end_eps = Config.END_EPSILON_GREEDY self.eps_decay = Config.EPSILON_DECAY # 动作维度(与Config一致) self.head_dim = [ Config.DIM_OF_ACTION_PHASE_1, Config.DIM_OF_ACTION_DURATION_1, Config.DIM_OF_ACTION_PHASE_2, Config.DIM_OF_ACTION_DURATION_2, ] self.num_head = Config.NUMB_HEAD # 4 # 预处理模块(核心依赖) self.preprocess = FeatureProcess(logger, usr_conf={}) self.road_info_initialized = False # 标记道路信息是否初始化 self.last_action = None # 算法模块 from your_algorithm import Algorithm # 替换为实际Algorithm导入 self.algorithm = Algorithm( self.model, self.optim, self.device, self.logger, self.monitor ) def reset(self): self.preprocess.reset() self._eps = Config.START_EPSILON_GREEDY def init_road_info(self, extra_info): """初始化道路信息(从extra_info加载静态配置)""" if self.road_info_initialized: return game_info = extra_info.get("gameinfo", {}) self.preprocess.init_road_info(game_info) self.road_info_initialized = True self.logger.info("Agent道路信息初始化完成") # ------------------------------ 预测与动作处理 ------------------------------ def __predict_detail(self, list_obs_data, exploit_flag=False): """内部预测逻辑:生成4个head的动作""" feature = [torch.tensor(obs.feature, dtype=torch.float32, device=self.device) for obs in list_obs_data] feature = torch.stack(feature) self.model.eval() self._eps = max(self.end_eps, self._eps * self.eps_decay) with torch.no_grad(): if np.random.rand() >= self._eps or exploit_flag: # 贪心选择:取Q值最大的动作 res = self.model(feature) # 假设model输出为 tuple([4个head的Q值]) list_phase_1 = torch.argmax(res[0], dim=1).cpu().tolist() list_duration_1 = torch.argmax(res[1], dim=1).cpu().tolist() list_phase_2 = torch.argmax(res[2], dim=1).cpu().tolist() list_duration_2 = torch.argmax(res[3], dim=1).cpu().tolist() else: # 随机探索:按动作维度随机选择 list_phase_1 = np.random.choice(self.head_dim[0], len(list_obs_data)).tolist() list_duration_1 = np.random.choice(self.head_dim[1], len(list_obs_data)).tolist() list_phase_2 = np.random.choice(self.head_dim[2], len(list_obs_data)).tolist() list_duration_2 = np.random.choice(self.head_dim[3], len(list_obs_data)).tolist() # 构建ActData列表 return [ ActData( phase_index_1=list_phase_1[i], duration_1=list_duration_1[i], phase_index_2=list_phase_2[i], duration_2=list_duration_2[i], ) for i in range(len(list_obs_data)) ] @predict_wrapper def predict(self, list_obs_data): return self.__predict_detail(list_obs_data, exploit_flag=False) @exploit_wrapper def exploit(self, observation): """推理模式:生成环境可执行的动作""" # 初始化道路信息(首次调用时) if not self.road_info_initialized: self.init_road_info(observation["extra_info"]) # 处理观测 obs_data = self.observation_process(observation["obs"], observation["extra_info"]) if not obs_data: self.logger.warning("观测处理失败,返回默认动作") return [[1, 0, 1], [2, 0, 1]] # 默认动作:路口1相位0时长1,路口2相位0时长1 # 预测动作 act_data = self.__predict_detail([obs_data], exploit_flag=True)[0] return self.action_process(act_data) # ------------------------------ 观测处理(核心修正) ------------------------------ def observation_process(self, obs, extra_info): """处理观测:生成feature和合法动作掩码""" # 1. 初始化道路信息(首次调用) if not self.road_info_initialized: self.init_road_info(extra_info) # 2. 更新预处理模块的动态数据 self.preprocess.update_traffic_info(obs, extra_info) frame_state = obs.get("framestate", {}) vehicles = frame_state.get("vehicles", []) phases = frame_state.get("phases", []) # 3. 栅格化处理(位置+速度) grid_w, grid_num = Config.GRID_WIDTH, Config.GRID_NUM speed_dict = np.zeros((grid_w, grid_num), dtype=np.float32) position_dict = np.zeros((grid_w, grid_num), dtype=np.int32) # 4. 遍历车辆填充栅格(容错处理) for vehicle in vehicles: if not self.preprocess.on_enter_lane(vehicle): continue # 仅处理进口道车辆 # 计算x_pos(车道编码) try: x_pos = get_lane_code(vehicle) # 假设该函数返回0-25的编码 if x_pos < 0 or x_pos >= grid_w: continue # 计算y_pos(栅格索引) pos_in_lane = vehicle.get("position_in_lane", {}) y_coord = pos_in_lane.get("y", 0) y_pos = int(y_coord // Config.GRID_LENGTH) if y_pos < 0 or y_pos >= grid_num: continue # 计算归一化速度(调用FeatureProcess的容错接口) v_config_id = vehicle.get("v_config_id", 1) max_speed = self.preprocess.get_vehicle_max_speed(v_config_id) vehicle_speed = vehicle.get("speed", 0) normalized_speed = vehicle_speed / max_speed if max_speed != 0 else 0.0 # 填充栅格 speed_dict[x_pos, y_pos] = normalized_speed position_dict[x_pos, y_pos] = 1 except Exception as e: v_id = vehicle.get("v_id", "未知") self.logger.warning(f"处理车辆{v_id}栅格化失败: {str(e)}") continue # 5. 解析当前相位(按信号灯ID匹配,非硬编码) # 5.1 路口1相位(对应Config.DIM_OF_ACTION_PHASE_1=4) j1 = self.preprocess.get_junction_by_signal(0) # 信号灯0对应路口1 current_phase1 = 0 for phase in phases: if phase.get("s_id") == 0: current_phase1 = phase.get("phase_id", 0) jct_phase_1 = [0] * Config.DIM_OF_ACTION_PHASE_1 if 0 <= current_phase1 < Config.DIM_OF_ACTION_PHASE_1: jct_phase_1[current_phase1] = 1 # 5.2 路口2相位(对应Config.DIM_OF_ACTION_PHASE_2=3) j2 = self.preprocess.get_junction_by_signal(1) # 信号灯1对应路口2 current_phase2 = 0 for phase in phases: if phase.get("s_id") == 1: current_phase2 = phase.get("phase_id", 0) jct_phase_2 = [0] * Config.DIM_OF_ACTION_PHASE_2 if 0 <= current_phase2 < Config.DIM_OF_ACTION_PHASE_2: jct_phase_2[current_phase2] = 1 # 6. 生成合法动作掩码(4个head,列表形式,不转张量) legal_action = self._generate_legal_action(current_phase1, current_phase2) # 7. 构建feature(确保维度与Config.DIM_OF_OBSERVATION=1056一致) position_flat = position_dict.flatten().tolist() # 26*20=520 speed_flat = speed_dict.flatten().tolist() # 26*20=520 weather_feature = [self.preprocess.get_weather()] # 1(天气编码) peak_feature = [1 if self.preprocess.is_peak_hour() else 0] # 1(高峰期标记) # 拼接:520+520+4+3+1+1=1049 → 补充7个0凑够1056(可替换为其他特征) padding = [0] * (Config.DIM_OF_OBSERVATION - len(position_flat + speed_flat + jct_phase_1 + jct_phase_2 + weather_feature + peak_feature)) feature = position_flat + speed_flat + jct_phase_1 + jct_phase_2 + weather_feature + peak_feature + padding # 8. 构建ObsData(携带legal_action,不加入feature) obs_data = ObsData(feature=feature) obs_data.legal_action = legal_action # 关键:传递4个head的合法掩码列表 return obs_data def _generate_legal_action(self, current_phase1, current_phase2): """生成4个head的合法动作掩码(列表形式,长度分别为4、30、3、30)""" legal = [] # Head0:路口1相位(禁止当前相位) phase1_dim = Config.DIM_OF_ACTION_PHASE_1 legal_phase1 = [1] * phase1_dim if 0 <= current_phase1 < phase1_dim: legal_phase1[current_phase1] = 0 # 禁止重复选择当前相位 legal.append(legal_phase1) # Head1:路口1时长(1-30s,全部合法) duration1_dim = Config.DIM_OF_ACTION_DURATION_1 legal_duration1 = [1] * duration1_dim # 可选:限制最大时长(如≤20s) # for i in range(20, duration1_dim): # legal_duration1[i] = 0 legal.append(legal_duration1) # Head2:路口2相位(禁止当前相位) phase2_dim = Config.DIM_OF_ACTION_PHASE_2 legal_phase2 = [1] * phase2_dim if 0 <= current_phase2 < phase2_dim: legal_phase2[current_phase2] = 0 legal.append(legal_phase2) # Head3:路口2时长(全部合法) duration2_dim = Config.DIM_OF_ACTION_DURATION_2 legal_duration2 = [1] * duration2_dim legal.append(legal_duration2) return legal def action_process(self, act_data): """将ActData转换为环境可执行的动作格式:[[j_id, phase, duration], ...]""" # 时长+1(动作从0开始,环境需要1-30s) duration1 = act_data.duration_1 + 1 duration2 = act_data.duration_2 + 1 # 限制时长在合理范围(1-Config.MAX_GREEN_DURATION) duration1 = max(1, min(duration1, Config.MAX_GREEN_DURATION)) duration2 = max(1, min(duration2, Config.MAX_GREEN_DURATION)) # 路口ID:从FeatureProcess获取(或默认j1→1,j2→2) junction_ids = self.preprocess.get_sorted_junction_ids() j1_id = junction_ids[0] if len(junction_ids) >=1 else 1 j2_id = junction_ids[1] if len(junction_ids) >=2 else 2 return [[j1_id, act_data.phase_index_1, duration1], [j2_id, act_data.phase_index_2, duration2]] # ------------------------------ 学习与模型保存 ------------------------------ @learn_wrapper def learn(self, list_sample_data): """调用算法学习""" return self.algorithm.learn(list_sample_data) @save_model_wrapper def save_model(self, path=None, id="1"): """保存模型(确保CPU兼容)""" model_path = f"{path}/model.ckpt-{id}.pkl" state_dict = {k: v.clone().cpu() for k, v in self.model.state_dict().items()} torch.save(state_dict, model_path) self.logger.info(f"保存模型到{model_path}") @load_model_wrapper def load_model(self, path=None, id="1"): """加载模型""" model_path = f"{path}/model.ckpt-{id}.pkl" state_dict = torch.load(model_path, map_location=self.device) self.model.load_state_dict(state_dict) self.logger.info(f"从{model_path}加载模型") class FeatureProcess: def __init__(self, logger, usr_conf=None): self.logger = logger self.usr_conf = usr_conf or {} self.reset() def reset(self): # 1. 静态道路信息(统一用 self.junction_dict 存储路口) self.junction_dict = {} # 核心:存储路口信息(j_id -> 路口数据) self.edge_dict = {} self.lane_dict = {} self.vehicle_configs = {} self.l_id_to_index = {} self.lane_to_junction = {} self.phase_lane_mapping = {} # 2. 动态帧信息 self.vehicle_status = {} self.lane_volume = {} # 车道ID -> 车流量(整数,协议v_count) self.lane_congestion = {} self.current_phases = {} # s_id -> {remaining_duration, phase_id} self.lane_demand = {} # 3. 车辆历史轨迹数据 self.vehicle_prev_junction = {} self.vehicle_prev_position = {} self.vehicle_distance_store = {} self.last_waiting_moment = {} self.waiting_time_store = {} self.enter_lane_time = {} self.vehicle_enter_time = {} self.vehicle_ideal_time = {} self.current_vehicles = {} # 车辆ID -> 完整信息 self.vehicle_trajectory = {} # 4. 场景状态 self.peak_hour = False self.weather = 0 # 0=晴,1=雨,2=雪,3=雾 self.accident_lanes = set() self.control_lanes = set() self.accident_configs = [] self.control_configs = [] # 5. 路口级指标累计 self.junction_metrics = {} # j_id -> 指标(delay、waiting等) self.enter_lane_ids = set() # 区域信息 self.region_dict = {} # region_id -> [j_id1, j_id2...] self.region_capacity = {} # region_id -> 总容量 self.junction_region_map = {} # j_id -> region_id def init_road_info(self, start_info): """完全兼容协议的初始化方法:确保所有数据存储到正确字段""" junctions = start_info.get("junctions", []) signals = start_info.get("signals", []) edges = start_info.get("edges", []) lane_configs = start_info.get("lane_configs", []) vehicle_configs = start_info.get("vehicle_configs", []) # 1. 初始化路口映射(关键:给 self.junction_dict 赋值) self._init_junction_mapping(junctions) # 2. 初始化相位-车道映射(依赖 self.junction_dict) self._init_phase_mapping(signals) # 3. 处理车道配置 self._init_lane_configs(lane_configs) # 4. 处理车辆配置(确保包含默认配置) self._init_vehicle_configs(vehicle_configs) # 5. 初始化区域(依赖 self.junction_dict) self._init_protocol_regions() # 6. 初始化场景配置 self._init_scene_config() self.logger.info(f"道路信息初始化完成:路口{len(self.junction_dict)}个,车辆配置{len(self.vehicle_configs)}个") def _init_junction_mapping(self, junctions): """修正:将路口数据存储到 self.junction_dict,处理缺失j_id""" if not junctions: self.logger.warning("路口列表为空,初始化默认路口(j1、j2)") # 兜底:添加默认路口(避免后续逻辑空指针) self.junction_dict = { "j1": {"j_id": "j1", "signal": 0, "cached_enter_lanes": [], "enter_lanes_on_directions": []}, "j2": {"j_id": "j2", "signal": 1, "cached_enter_lanes": [], "enter_lanes_on_directions": []} } return for idx, junction in enumerate(junctions): # 处理缺失j_id:生成默认j_id(j_原始id 或 j_idx) if "j_id" not in junction: if "id" in junction: junction["j_id"] = f"j_{junction['id']}" else: junction["j_id"] = f"j_{idx}" self.logger.debug(f"为路口{idx}生成默认j_id: {junction['j_id']}") j_id = junction["j_id"] # 提取并缓存进口车道(协议字段:enter_lanes_on_directions) all_enter_lanes = [] for dir_info in junction.get("enter_lanes_on_directions", []): all_enter_lanes.extend(dir_info.get("lanes", [])) junction["cached_enter_lanes"] = all_enter_lanes # 缓存进口车道列表 # 初始化路口指标 self.junction_metrics[j_id] = { "total_delay": 0.0, "total_vehicles": 0, "total_waiting": 0.0, "total_queue": 0, "queue_count": 0, "counted_vehicles": set(), "completed_vehicles": set() } # 存储到核心字典(供其他方法调用) self.junction_dict[j_id] = junction self.logger.debug(f"加载路口{j_id}:进口车道{len(all_enter_lanes)}条,关联信号灯{junction.get('signal', -1)}") def _init_phase_mapping(self, signals): """修正:基于 self.junction_dict 构建信号灯-路口映射,解析相位-车道""" self.phase_lane_mapping = {} # 从已初始化的路口字典中构建:信号灯ID -> 路口ID(signal_junction_map) signal_junction_map = {} for j_id, junction in self.junction_dict.items(): s_id = junction.get("signal", -1) if s_id != -1: signal_junction_map[s_id] = j_id for signal in signals: s_id = signal.get("s_id", -1) if s_id == -1: self.logger.warning("信号灯缺失s_id,跳过") continue # 关联路口(从signal_junction_map获取,而非原始junctions) j_id = signal_junction_map.get(s_id) if not j_id or j_id not in self.junction_dict: self.logger.warning(f"信号灯{s_id}未关联有效路口,跳过") continue junction = self.junction_dict[j_id] all_enter_lanes = junction["cached_enter_lanes"] self.phase_lane_mapping[s_id] = {} # s_id -> phase_idx -> [lane_id1...] for phase_idx, phase in enumerate(signal.get("phases", [])): controlled_lanes = [] for light_cfg in phase.get("lights_on_configs", []): green_mask = light_cfg.get("green_mask", 0) turns = self._mask_to_turns(green_mask) # 解析转向(直/左/右/掉头) for turn in turns: # 按转向匹配车道 controlled_lanes.extend(self._get_turn_lanes(junction, turn)) # 过滤:仅保留进口车道 valid_lanes = list(set(controlled_lanes) & set(all_enter_lanes)) self.phase_lane_mapping[s_id][phase_idx] = valid_lanes self.logger.debug(f"信号灯{s_id}相位{phase_idx}:控制车道{valid_lanes}") def _init_lane_configs(self, lane_configs): """基于协议DirectionMask解析车道转向类型""" self.lane_dict = {} for lane in lane_configs: l_id = lane.get("l_id", -1) if l_id == -1: self.logger.warning("车道缺失l_id,跳过") continue # 解析转向类型(协议:dir_mask) dir_mask = lane.get("dir_mask", 0) turn_type = 0 # 0=直行,1=左转,2=右转,3=掉头 if dir_mask & 2: turn_type = 1 elif dir_mask & 4: turn_type = 2 elif dir_mask & 8: turn_type = 3 # 存储车道信息 self.lane_dict[l_id] = { "l_id": l_id, "edge_id": lane.get("edge_id", 0), "length": lane.get("length", 100), "width": lane.get("width", 3), "turn_type": turn_type } def _init_vehicle_configs(self, vehicle_configs): """初始化车辆配置,确保包含默认配置(避免KeyError)""" self.vehicle_configs = {} # 协议VehicleType默认配置(覆盖常见v_config_id) default_configs = { 1: {"v_type": 1, "v_type_name": "CAR", "max_speed": 60, "length": 5}, 2: {"v_type": 2, "v_type_name": "BUS", "max_speed": 40, "length": 12}, 3: {"v_type": 3, "v_type_name": "TRUCK", "max_speed": 50, "length": 10}, 4: {"v_type": 4, "v_type_name": "MOTORCYCLE", "max_speed": 55, "length": 2}, 5: {"v_type": 5, "v_type_name": "BICYCLE", "max_speed": 15, "length": 1} } # 加载传入的配置,覆盖默认值 for cfg in vehicle_configs: cfg_id = cfg.get("v_config_id") if cfg_id is None: self.logger.warning("车辆配置缺失v_config_id,跳过") continue self.vehicle_configs[cfg_id] = { "v_type": cfg.get("v_type", default_configs.get(cfg_id, {}).get("v_type", 0)), "v_type_name": cfg.get("v_type_name", default_configs.get(cfg_id, {}).get("v_type_name", "Unknown")), "max_speed": cfg.get("max_speed", default_configs.get(cfg_id, {}).get("max_speed", 60)), "length": cfg.get("length", default_configs.get(cfg_id, {}).get("length", 5)) } # 补充未覆盖的默认配置(确保关键ID存在) for cfg_id, cfg in default_configs.items(): if cfg_id not in self.vehicle_configs: self.vehicle_configs[cfg_id] = cfg self.logger.debug(f"补充默认车辆配置:v_config_id={cfg_id}") def _init_protocol_regions(self): """基于 self.junction_dict 初始化区域(每个路口一个区域)""" self.region_dict = {} self.region_capacity = {} self.junction_region_map = {} for j_id, junction in self.junction_dict.items(): region_id = f"region_{j_id}" self.region_dict[region_id] = [j_id] self.junction_region_map[j_id] = region_id # 计算区域容量(进口车道总容量) total_cap = sum(self.get_lane_capacity(lane_id) for lane_id in junction["cached_enter_lanes"]) self.region_capacity[region_id] = total_cap if total_cap > 0 else 20 def _init_scene_config(self): """初始化场景参数(天气、高峰期、事故/管制)""" self.weather = self.usr_conf.get("weather", 0) self.peak_hour = self.usr_conf.get("rush_hour", 0) == 1 # 处理事故配置 self.accident_configs = [ cfg for cfg in self.usr_conf.get("traffic_accidents", {}).get("custom_configuration", []) if all(k in cfg for k in ["lane_index", "start_time", "end_time"]) ] # 处理管制配置 self.control_configs = [ cfg for cfg in self.usr_conf.get("traffic_control", {}).get("custom_configuration", []) if all(k in cfg for k in ["lane_index", "start_time", "end_time"]) ] def _mask_to_turns(self, green_mask): """解析green_mask为转向类型""" turns = [] if green_mask & 1: turns.append("straight") if green_mask & 2: turns.append("left") if green_mask & 4: turns.append("right") if green_mask & 8: turns.append("uturn") return turns def _get_turn_lanes(self, junction, turn): """根据转向类型获取车道(依赖车道的turn_type)""" turn_map = {"straight": 0, "left": 1, "right": 2, "uturn": 3} target_turn = turn_map.get(turn) if target_turn is None: return [] # 从进口车道中匹配转向 return [ lane_id for lane_id in junction["cached_enter_lanes"] if self.lane_dict.get(lane_id, {}).get("turn_type") == target_turn ] # ------------------------------ 动态数据更新 ------------------------------ def update_traffic_info(self, obs, extra_info): if "framestate" not in obs: self.logger.error("观测数据缺失framestate,跳过更新") return frame_state = obs["framestate"] frame_no = frame_state.get("frame_no", 0) frame_time = frame_state.get("frame_time", 0) / 1000.0 # 转秒 # 1. 更新相位信息(s_id -> 相位状态) self.current_phases.clear() for phase_info in frame_state.get("phases", []): s_id = phase_info.get("s_id", -1) if s_id == -1: continue self.current_phases[s_id] = { "remaining_duration": phase_info.get("remaining_duration", 0), "phase_id": phase_info.get("phase_id", 0) } # 2. 更新车道车流量(协议v_count) self.lane_volume.clear() for lane in frame_state.get("lanes", []): l_id = lane.get("lane_id", -1) if l_id == -1: continue self.lane_volume[l_id] = lane.get("v_count", 0) # 3. 更新车辆信息 vehicles = frame_state.get("vehicles", []) current_v_ids = set() self.current_vehicles.clear() self.vehicle_status.clear() for vehicle in vehicles: v_id = vehicle.get("v_id") if not v_id: continue current_v_ids.add(v_id) self.current_vehicles[v_id] = vehicle self.vehicle_status[v_id] = vehicle.get("v_status", 0) # 0=正常,1=事故,2=无规则 # 4. 清理过期车辆数据 self._clean_expired_vehicle_data(current_v_ids) # 5. 更新车辆位置、等待时间、行驶距离 for vehicle in vehicles: v_id = vehicle.get("v_id") if not v_id: continue if "position_in_lane" in vehicle: self._update_vehicle_position(vehicle, frame_time) self.cal_waiting_time(frame_time, vehicle) self.cal_travel_distance(vehicle) # 6. 更新事故/管制车道 self._update_active_invalid_lanes(frame_no) # 7. 计算车道需求 self.calculate_lane_demand() def _update_vehicle_position(self, vehicle, frame_time): v_id = vehicle["v_id"] current_pos = vehicle["position_in_lane"] if v_id in self.vehicle_prev_position: # 计算行驶距离 prev_pos = self.vehicle_prev_position[v_id] dx = current_pos["x"] - prev_pos["x"] dy = current_pos["y"] - prev_pos["y"] self.vehicle_distance_store[v_id] = self.vehicle_distance_store.get(v_id, 0.0) + math.hypot(dx, dy) self.vehicle_prev_position[v_id] = current_pos def _clean_expired_vehicle_data(self, current_v_ids): """清理不在当前帧的车辆数据""" expired_v_ids = set(self.vehicle_prev_position.keys()) - current_v_ids for v_id in expired_v_ids: self.vehicle_prev_position.pop(v_id, None) self.vehicle_distance_store.pop(v_id, None) self.waiting_time_store.pop(v_id, None) self.last_waiting_moment.pop(v_id, None) def _update_active_invalid_lanes(self, frame_no): """更新当前生效的事故/管制车道""" self.accident_lanes = set() for cfg in self.accident_configs: if cfg["start_time"] <= frame_no <= cfg["end_time"]: self.accident_lanes.add(cfg["lane_index"]) self.control_lanes = set() for cfg in self.control_configs: if cfg["start_time"] <= frame_no <= cfg["end_time"]: self.control_lanes.add(cfg["lane_index"]) # ------------------------------ 指标计算 ------------------------------ def calculate_lane_demand(self): """计算车道需求(车流量 + 接近停止线的车辆)""" self.lane_demand = {} for lane_id, count in self.lane_volume.items(): demand = count # 叠加接近停止线的车辆(y<100) for vehicle in self.current_vehicles.values(): if vehicle.get("lane") == lane_id: y_pos = vehicle.get("position_in_lane", {}).get("y", 200) if y_pos < 100: demand += 0.5 self.lane_demand[lane_id] = demand def calculate_junction_queue(self, j_id): """计算路口排队长度(速度≤1m/s的进口道车辆)""" junction = self.junction_dict.get(j_id) if not junction: return 0 valid_lanes = set(junction["cached_enter_lanes"]) - self.get_invalid_lanes() queue = 0 for vehicle in self.current_vehicles.values(): if (vehicle.get("lane") in valid_lanes and self.vehicle_status.get(vehicle["v_id"], 0) == 0 and vehicle.get("speed", 0) <= 1.0): queue += 1 return queue def cal_waiting_time(self, frame_time, vehicle): """计算车辆等待时间(进口道内、速度≤0.1m/s、接近停止线)""" v_id = vehicle["v_id"] if not self.on_enter_lane(vehicle): self.waiting_time_store.pop(v_id, None) self.last_waiting_moment.pop(v_id, None) return # 满足等待条件:速度低 + 接近停止线 if vehicle.get("speed", 0) <= 0.1 and vehicle["position_in_lane"]["y"] < 50: if v_id not in self.last_waiting_moment: self.last_waiting_moment[v_id] = frame_time else: duration = frame_time - self.last_waiting_moment[v_id] self.waiting_time_store[v_id] = self.waiting_time_store.get(v_id, 0.0) + duration self.last_waiting_moment[v_id] = frame_time else: self.last_waiting_moment.pop(v_id, None) def cal_travel_distance(self, vehicle): """计算车辆行驶距离""" v_id = vehicle["v_id"] if not self.on_enter_lane(vehicle): self.vehicle_prev_position.pop(v_id, None) self.vehicle_distance_store.pop(v_id, None) return # 初始化历史位置 if v_id not in self.vehicle_prev_position: current_pos = vehicle.get("position_in_lane", {}) if "x" in current_pos and "y" in current_pos: self.vehicle_prev_position[v_id] = { "x": current_pos["x"], "y": current_pos["y"], "distance_to_stop": current_pos["y"] } self.vehicle_distance_store[v_id] = 0.0 return # 计算距离 prev_pos = self.vehicle_prev_position[v_id] current_pos = vehicle["position_in_lane"] try: dx = current_pos["x"] - prev_pos["x"] dy = current_pos["y"] - prev_pos["y"] euclid_dist = math.hypot(dx, dy) stop_dist_reduce = prev_pos["distance_to_stop"] - current_pos["y"] self.vehicle_distance_store[v_id] += max(euclid_dist, stop_dist_reduce) self.vehicle_prev_position[v_id]["distance_to_stop"] = current_pos["y"] except Exception as e: self.logger.error(f"计算车辆{v_id}距离失败: {str(e)}") def on_enter_lane(self, vehicle): """判断车辆是否在进口道(依赖路口的cached_enter_lanes)""" lane_id = vehicle.get("lane") if lane_id is None: return False # 遍历所有路口的进口道,判断车道是否属于进口道 for junction in self.junction_dict.values(): if lane_id in junction["cached_enter_lanes"]: return True return False # ------------------------------ 外部调用接口 ------------------------------ def get_sorted_junction_ids(self): """获取排序后的路口ID列表(供Agent使用)""" return sorted(self.junction_dict.keys()) def get_junction_by_signal(self, s_id): """根据信号灯ID获取路口(供Agent解析相位)""" for j_id, junction in self.junction_dict.items(): if junction.get("signal") == s_id: return junction return None def get_invalid_lanes(self): """获取当前无效车道(事故+管制)""" return self.accident_lanes.union(self.control_lanes) def get_lane_capacity(self, lane_id): """获取车道容量(根据转向类型)""" turn_type = self.lane_dict.get(lane_id, {}).get("turn_type", 0) return 15 if turn_type in [0, 1] else 10 # 直行/左转15辆,右转/掉头10辆 def get_weather(self): return self.weather def is_peak_hour(self): return self.peak_hour def get_vehicle_max_speed(self, v_config_id): """获取车辆最大速度(容错接口)""" return self.vehicle_configs.get(v_config_id, {}).get("max_speed", 60)检查这些代码哪里不匹配,有没有地方不合理。修正一下
最新发布
08-29
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值