import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.font_manager as fm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from torch.optim.lr_scheduler import CosineAnnealingLR
import random
import time
from sklearn.cluster import DBSCAN
# Font setup (make sure Chinese labels render correctly in matplotlib)
def set_safe_font():
try:
font_paths = fm.findSystemFonts()
system_fonts = set()
for path in font_paths:
try:
font_prop = fm.FontProperties(fname=path)
system_fonts.add(font_prop.get_name())
            except Exception:
                continue
preferred_fonts = ["SimHei", "Microsoft YaHei", "Heiti TC", "WenQuanYi Zen Hei", "Arial Unicode MS"]
available_fonts = [f for f in preferred_fonts if f in system_fonts]
if available_fonts:
plt.rcParams["font.family"] = [available_fonts[0]]
else:
plt.rcParams["font.family"] = ["sans-serif"]
plt.rcParams["axes.unicode_minus"] = False
except Exception as e:
print(f"字体设置警告: {e}")
set_safe_font()
matplotlib.use('TkAgg')  # Use 'TkAgg' when a GUI is available, 'Agg' for headless runs (ideally set before importing pyplot)
# Random seeds (for reproducible experiments)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
start_time = time.time()
# Coordinate helpers (rectangle grouping and straight-line fitting)
def group_coordinates_into_rectangles(raw_coords, group_size=4):
if len(raw_coords) % group_size != 0:
raise ValueError(f"坐标总数必须是{group_size}的倍数!当前共{len(raw_coords)}个坐标")
for i, coord in enumerate(raw_coords):
if not (isinstance(coord[0], (int, float)) and isinstance(coord[1], (int, float))):
raise TypeError(f"第{i + 1}个坐标格式错误,需为数值型(当前:{coord})")
return [np.array(raw_coords[i:i + group_size]) for i in range(0, len(raw_coords), group_size)]
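# Illustrative usage with hypothetical coordinates: 8 vertices -> 2 rectangles of 4 vertices each.
#   rects = group_coordinates_into_rectangles([[0, 0], [1, 0], [1, 1], [0, 1],
#                                              [2, 0], [3, 0], [3, 1], [2, 1]])
#   len(rects) == 2, and each entry is a (4, 2) numpy array.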
def calculate_rectangle_centers(grouped_coords):
return np.array([np.mean(rectangle, axis=0) for rectangle in grouped_coords])
def fit_line_to_rectangle_group(rectangle_group):
centers = calculate_rectangle_centers(rectangle_group)
lon = centers[:, 0]
lat = centers[:, 1]
lon_var = np.var(lon)
lat_var = np.var(lat)
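    # Regress along the axis with larger variance (lat on lon, or lon on lat) so that
    # near-vertical groups do not produce an ill-conditioned least-squares fit.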
if lon_var > lat_var:
A = np.vstack([lon, np.ones(len(lon))]).T
k, b = np.linalg.lstsq(A, lat, rcond=None)[0]
lon_min, lon_max = lon.min() - 0.0001, lon.max() + 0.0001
lat_min = k * lon_min + b
lat_max = k * lon_max + b
start_point = np.array([lon_min, lat_min])
end_point = np.array([lon_max, lat_max])
else:
A = np.vstack([lat, np.ones(len(lat))]).T
k, b = np.linalg.lstsq(A, lon, rcond=None)[0]
lat_min, lat_max = lat.min() - 0.0001, lat.max() + 0.0001
lon_min = k * lat_min + b
lon_max = k * lat_max + b
start_point = np.array([lon_min, lat_min])
end_point = np.array([lon_max, lat_max])
line_vec = end_point - start_point
line_len = np.linalg.norm(line_vec)
unit_vec = line_vec / line_len if line_len > 1e-8 else np.array([0, 0])
return {
"start": start_point, "end": end_point, "vec": unit_vec,
"length": line_len, "centers": centers
}
def assign_agent_to_line(agents_centers, fitted_lines):
agent_lines = []
for agent_center in agents_centers:
line_distances = []
        # Track the line index explicitly; fitted_lines.index(dict) would compare dicts
        # holding numpy arrays, which can raise on ambiguous array truth values.
        for line_idx, line in enumerate(fitted_lines):
            dist = point_to_line_distance(agent_center, line["start"], line["end"])
            line_distances.append((dist, line_idx, line))
        min_dist, assigned_idx, assigned_line = min(line_distances, key=lambda x: x[0])
line_vec = assigned_line["vec"]
extend_len = assigned_line["length"] * 0.2
agent_start = assigned_line["start"] - line_vec * extend_len
agent_end = assigned_line["end"] + line_vec * extend_len
agent_line_vec = agent_end - agent_start
agent_line_len = np.linalg.norm(agent_line_vec)
agent_unit_vec = agent_line_vec / agent_line_len if agent_line_len > 1e-8 else np.array([0, 0])
agent_lines.append({
"start": agent_start, "end": agent_end, "vec": agent_unit_vec,
"length": agent_line_len, "center": agent_center,
"assigned_line_idx": fitted_lines.index(assigned_line)
})
return agent_lines
# Polyline helpers (parsing custom polylines and assigning agents to them)
def process_polyline(polyline_points, agent_center=None):
polyline = np.array(polyline_points)
if len(polyline) < 2:
raise ValueError(f"折线至少需要2个坐标点(当前:{len(polyline)}个)")
if polyline.shape[1] != 2:
raise ValueError(f"折线坐标需为(x,y)格式(当前:{polyline.shape})")
start_point = polyline[0]
end_point = polyline[-1]
overall_vec = end_point - start_point
overall_len = np.linalg.norm(overall_vec)
unit_vec = overall_vec / overall_len if overall_len > 1e-8 else np.array([0, 0])
min_x, min_y = polyline.min(axis=0)
max_x, max_y = polyline.max(axis=0)
# 动态调整局部轨道阈值(根据折线长度)
local_segments = polyline
if agent_center is not None:
nearby_segments = []
# 动态阈值:折线长度的1/5(上限0.01≈1100米)
dynamic_threshold = min(0.01, overall_len / 5) if overall_len > 0 else 0.007
for i in range(len(polyline) - 1):
p1 = polyline[i]
p2 = polyline[i + 1]
dist = point_to_line_distance(agent_center, p1, p2)
if dist < dynamic_threshold:
nearby_segments.append(p1)
nearby_segments.append(p2)
# 确保局部轨道至少有2个点
if len(nearby_segments) >= 2:
local_segments = np.unique(nearby_segments, axis=0)
else:
# 无附近线段时,取智能体周围的折线片段(前后各1个点)
closest_idx = np.argmin([np.linalg.norm(agent_center - p) for p in polyline])
start_idx = max(0, closest_idx - 1)
end_idx = min(len(polyline) - 1, closest_idx + 1)
local_segments = polyline[start_idx:end_idx + 1]
return {
"original_points": polyline,
"local_points": local_segments, # 局部运动线段
"start": start_point,
"end": end_point,
"vec": unit_vec,
"length": overall_len,
"bounds": (min_x, max_x, min_y, max_y),
"local_bounds": (local_segments[:, 0].min() - 0.0001, local_segments[:, 0].max() + 0.0001,
local_segments[:, 1].min() - 0.0001, local_segments[:, 1].max() + 0.0001)
}
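# Illustrative usage with hypothetical values: a 3-point polyline and an agent near its middle.
#   seg = process_polyline([[0.0, 0.0], [0.005, 0.0], [0.010, 0.002]],
#                          agent_center=np.array([0.005, 0.0001]))
#   seg["local_points"] keeps only the segments within the dynamic threshold of the agent.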
def assign_agents_to_polylines(agents_centers, custom_polylines, safe_distance=0.0005):
processed_lines = []
for i, polyline in enumerate(custom_polylines):
try:
processed = process_polyline(polyline) # 先不传入agent_center
# 计算折线最大承载量(长度//安全距离)
processed["max_agents"] = max(1, int(np.ceil(processed["length"] / safe_distance)))
processed["assigned_agents_count"] = 0 # 已分配智能体计数
processed_lines.append(processed)
except Exception as e:
raise ValueError(f"第{i + 1}条折线处理失败:{e}")
agent_lines = []
for agent_center in agents_centers:
line_distances = []
for idx, line in enumerate(processed_lines):
# 计算智能体到该折线的最小距离
min_dist = float('inf')
for i in range(len(line["original_points"]) - 1):
p1 = np.array(line["original_points"][i])
p2 = np.array(line["original_points"][i + 1])
dist = point_to_line_distance(agent_center, p1, p2)
if dist < min_dist:
min_dist = dist
# 记录:(距离,折线,折线索引,是否有剩余容量)
line_distances.append((min_dist, line, idx, line["assigned_agents_count"] < line["max_agents"]))
# 优先选择「有剩余容量」的折线中距离最近的
available_lines = [(d, l, idx) for d, l, idx, available in line_distances if available]
if available_lines:
min_dist, assigned_line, line_idx = min(available_lines, key=lambda x: x[0])
else:
# 所有折线都满了,选择距离最近的(降级策略)
min_dist, assigned_line, line_idx, _ = min(line_distances, key=lambda x: x[0])
# 更新折线的已分配计数
processed_lines[line_idx]["assigned_agents_count"] += 1
# 生成智能体的局部轨道
local_processed = process_polyline(assigned_line["original_points"], agent_center)
agent_lines.append({
"original_points": assigned_line["original_points"],
"local_points": local_processed["local_points"],
"start": assigned_line["start"],
"end": assigned_line["end"],
"vec": assigned_line["vec"],
"length": assigned_line["length"],
"bounds": local_processed["local_bounds"],
"center": agent_center,
"assigned_line_idx": line_idx,
"max_agents": assigned_line["max_agents"]
})
print(f"已将{len(agents_centers)}个智能体分配到{len(custom_polylines)}条折线(局部轨道)")
return agent_lines
# Line-feature generation (automatic clustering or manually specified line count)
def generate_custom_agent_lines(grouped_coords, line_count=None):
agents_centers = calculate_rectangle_centers(grouped_coords)
num_agents = len(agents_centers)
if line_count is None:
distances = []
for i in range(num_agents):
for j in range(i + 1, num_agents):
distances.append(np.linalg.norm(agents_centers[i] - agents_centers[j]))
avg_dist = np.mean(distances) if distances else 0.0002
clustering = DBSCAN(eps=avg_dist * 1.2, min_samples=2).fit(agents_centers)
labels = clustering.labels_
        # Count DBSCAN clusters excluding the noise label (-1); fall back to a single line.
        n_clusters = len(set(labels) - {-1})
        line_count = n_clusters if n_clusters > 0 else 1
fitted_lines = []
if line_count == 1:
fitted_lines.append(fit_line_to_rectangle_group(grouped_coords))
else:
agents_per_line = num_agents // line_count
for i in range(line_count):
start_idx = i * agents_per_line
end_idx = start_idx + agents_per_line if i < line_count - 1 else num_agents
line_group = grouped_coords[start_idx:end_idx]
fitted_lines.append(fit_line_to_rectangle_group(line_group))
agent_lines = assign_agent_to_line(agents_centers, fitted_lines)
print(f"生成自定义线要素:{len(fitted_lines)}条主线 → 分配给{len(agent_lines)}个智能体")
return agent_lines
# Helper: point-to-segment distance
def point_to_line_distance(point, line_start, line_end):
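    # Project the point onto the segment: t = (p - a)·(b - a) / |b - a|^2, clamp t to [0, 1],
    # then return the distance to the projected point (an endpoint if the projection falls outside).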
line_vec = line_end - line_start
point_vec = point - line_start
line_len = np.linalg.norm(line_vec)
if line_len < 1e-8:
return np.linalg.norm(point_vec)
proj_coeff = np.dot(point_vec, line_vec) / (line_len ** 2)
proj_coeff_clipped = np.clip(proj_coeff, 0.0, 1.0)
proj_point = line_start + proj_coeff_clipped * line_vec
return np.linalg.norm(point - proj_point)
# Visualization of the agent-to-polyline assignment
def visualize_agent_line_assignment(agents_centers, agent_lines, custom_polylines,
save_path="agent_line_assignment.png"):
plt.figure(figsize=(10, 8))
if len(custom_polylines) > 0:
colors = plt.cm.tab10(np.linspace(0, 1, len(custom_polylines))) # 每条折线一个颜色
else:
colors = [plt.cm.tab10(0)] # 默认颜色
# 绘制所有折线
for i, polyline in enumerate(custom_polylines):
polyline_np = np.array(polyline)
plt.plot(polyline_np[:, 0], polyline_np[:, 1], '-', color=colors[i], linewidth=2, alpha=0.6,
label=f'折线{i + 1}')
# 绘制智能体及其分配的局部轨道
for agent_idx, agent_line in enumerate(agent_lines):
# 智能体初始中心
plt.scatter(agent_line["center"][0], agent_line["center"][1],
color=colors[agent_line["assigned_line_idx"]], s=100, edgecolors='black', zorder=5)
# 智能体编号
plt.text(agent_line["center"][0], agent_line["center"][1],
str(agent_idx + 1), ha='center', va='center', fontweight='bold', zorder=6)
# 局部轨道
local_poly = agent_line["local_points"]
plt.plot(local_poly[:, 0], local_poly[:, 1], '--', color=colors[agent_line["assigned_line_idx"]], linewidth=3,
alpha=0.8)
plt.xlabel("经度")
plt.ylabel("纬度")
plt.title("智能体-折线分配关系(实线=原始折线,虚线=局部轨道)")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(save_path, dpi=300)
plt.close()
print(f"智能体-折线分配关系图已保存至 {save_path}")
# Environment class (core: reworked reward/penalty scheme in step())
class GeoEnv:
def __init__(self, grouped_initial_positions, global_step=4e-9,
min_agent_distance=0.0008, # 安全距离,避免压盖
max_agent_distance=0.0007,
custom_polylines=None):
self.grouped_initial_positions = grouped_initial_positions # 保存原始形状坐标
self.initial_centers = calculate_rectangle_centers(self.grouped_initial_positions)
self.num_agents = len(self.initial_centers)
print(f"环境初始化:{self.num_agents}个智能体")
if custom_polylines is not None:
self.agent_lines = assign_agents_to_polylines(
self.initial_centers,
custom_polylines,
safe_distance=min_agent_distance # 传入安全距离用于负载计算
)
else:
self.agent_lines = generate_custom_agent_lines(self.grouped_initial_positions)
# 动态折线容忍度参数
self.initial_line_tolerance = 0.00037
self.line_tolerance = self.initial_line_tolerance
self.tolerance_increase_rate = 0.00001
self.max_line_tolerance = 0.0009
self.slide_step = global_step
# 计算环境边界
all_coords = np.vstack(self.grouped_initial_positions)
min_lon, min_lat = all_coords.min(axis=0)
max_lon, max_lat = all_coords.max(axis=0)
self.grid_left, self.grid_right = min_lon - 0.001, max_lon + 0.001
self.grid_bottom, self.grid_top = min_lat - 0.001, max_lat + 0.001
self.grid_width = self.grid_right - self.grid_left
self.grid_height = self.grid_top - self.grid_bottom
self.min_agent_dist = min_agent_distance
self.max_agent_dist = max_agent_distance
self.state_dim = (self.num_agents * 2) + 5 * self.num_agents # 状态维度
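        # state_dim = 7N: 2N normalized positions + N boundary distances + N nearest-neighbor
        # distances + N conflict flags + N boundary-violation flags + N polyline-violation flags.
        # min_agent_dist drives the conflict/safety checks; max_agent_dist is only used to
        # normalize neighbor distances in get_state().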
# 冲突记忆机制
self.recent_conflict_pairs = []
self.conflict_memory_window = 10
self.continuous_conflict_steps = 0
self.reset()
def reset(self, episode=0):
self.agents_pos = self.initial_centers.copy()
self.episode_conflicts = []
self.previous_conflict_count = None
self.previous_positions = self.agents_pos.copy()
self.boundary_violations = [0] * self.num_agents
self.line_violations = [0] * self.num_agents
self.continuous_conflict_free_steps = np.zeros(self.num_agents)
self.stay_count = [0] * self.num_agents
self.recent_conflict_pairs = []
self.continuous_boundary_steps = np.zeros(self.num_agents)
# 动态调整折线容忍度
if episode > 200:
self.line_tolerance = min(
self.initial_line_tolerance + self.tolerance_increase_rate * (episode - 200),
self.max_line_tolerance
)
else:
self.line_tolerance = self.initial_line_tolerance
return self.get_state()
def normalize_state(self, state):
norm_state = state.copy()
for i in range(0, self.num_agents * 2, 2):
norm_state[i] = (norm_state[i] - self.grid_left) / self.grid_width
norm_state[i + 1] = (norm_state[i + 1] - self.grid_bottom) / self.grid_height
return norm_state
def get_state(self):
# 状态组成:归一化位置 + 边界距离 + 最近智能体距离 + 冲突标记 + 边界违规标记 + 折线违规标记
pos_state = self.normalize_state(self.agents_pos.flatten())
# 1. 智能体到折线边界的距离(归一化)
boundary_dist = []
for i in range(self.num_agents):
min_x, max_x, min_y, max_y = self.agent_lines[i]["bounds"]
pos = self.agents_pos[i]
min_dist = min(pos[0] - min_x, max_x - pos[0], pos[1] - min_y, max_y - pos[1])
line_len = self.agent_lines[i]["length"]
boundary_dist.append(min_dist / (line_len / 2) if line_len > 0 else 0)
# 2. 每个智能体到其他智能体的最近距离(归一化)
agent_dist = []
for i in range(self.num_agents):
min_dist = float('inf')
for j in range(self.num_agents):
if i == j:
continue
dist = np.linalg.norm(self.agents_pos[i] - self.agents_pos[j])
if dist < min_dist:
min_dist = dist
agent_dist.append(min_dist / self.max_agent_dist)
# 3. 冲突标记(1=冲突,0=无冲突)
conflict_flag = [0.0] * self.num_agents
current_conflicts = self.detect_conflicts()
if current_conflicts > 0:
for i in range(self.num_agents):
for j in range(i + 1, self.num_agents):
if np.linalg.norm(self.agents_pos[i] - self.agents_pos[j]) < self.min_agent_dist:
conflict_flag[i] = 1.0
conflict_flag[j] = 1.0
# 4. 边界违规标记(1=违规,0=合规)
violation_flag = [0.0] * self.num_agents
for i in range(self.num_agents):
if self.check_boundary_violation(i, self.agents_pos[i]) > 0:
violation_flag[i] = 1.0
# 5. 折线违规标记(1=违规,0=合规)
line_violation_flag = [0.0] * self.num_agents
for i in range(self.num_agents):
polyline = self.agent_lines[i]["original_points"]
min_dist_to_line = float('inf')
for j in range(len(polyline) - 1):
p1 = polyline[j]
p2 = polyline[j + 1]
dist = point_to_line_distance(self.agents_pos[i], p1, p2)
if dist < min_dist_to_line:
min_dist_to_line = dist
if min_dist_to_line > self.line_tolerance:
line_violation_flag[i] = 1.0
return np.concatenate([pos_state, boundary_dist, agent_dist,
conflict_flag, violation_flag, line_violation_flag])
def get_agents_pos(self):
return self.agents_pos.copy()
def get_initial_positions(self):
return self.initial_centers.copy()
def get_agent_shapes(self):
"""获取智能体的原始形状"""
return self.grouped_initial_positions.copy()
def detect_conflicts(self):
conflict_count = 0
rect_diag = 0.0005 # 调整为适合原始形状的对角线长度
for i in range(self.num_agents):
for j in range(i + 1, self.num_agents):
if np.linalg.norm(self.agents_pos[i] - self.agents_pos[j]) < (self.min_agent_dist + rect_diag):
conflict_count += 1
return conflict_count
def detect_all_conflicts(self):
"""检测所有冲突对和涉及的智能体"""
conflict_count = 0
conflict_pairs = [] # 存储冲突对 (i,j),i < j 避免重复
conflict_agents = set() # 存储所有涉及冲突的智能体索引
rect_diag = 0.0005 # 调整为适合原始形状的对角线长度
safe_dist = self.min_agent_dist + rect_diag
for i in range(self.num_agents):
for j in range(i + 1, self.num_agents):
if np.linalg.norm(self.agents_pos[i] - self.agents_pos[j]) < safe_dist:
conflict_count += 1
conflict_pairs.append((i, j))
conflict_agents.add(i)
conflict_agents.add(j)
return conflict_count, conflict_pairs, conflict_agents
@staticmethod
def calculate_agent_distance(agent1_pos, agent2_pos):
return np.linalg.norm(agent1_pos - agent2_pos)
def check_boundary_violation(self, agent_idx, pos):
line = self.agent_lines[agent_idx]
min_x, max_x, min_y, max_y = line["bounds"]
if pos[0] < min_x or pos[0] > max_x or pos[1] < min_y or pos[1] > max_y:
dx = max(min_x - pos[0], 0, pos[0] - max_x)
dy = max(min_y - pos[1], 0, pos[1] - max_y)
return np.sqrt(dx ** 2 + dy ** 2)
return 0.0
def validate_position(self, agent_idx, current_pos, action):
line = self.agent_lines[agent_idx]
line_vec = line["vec"]
# 计算目标位置(沿折线方向)
if action == 0:
target_pos = current_pos + line_vec * self.slide_step
elif action == 1:
target_pos = current_pos - line_vec * self.slide_step
else:
target_pos = current_pos.copy() # 不动时保留原始位置
# 用局部边界裁剪
min_x, max_x, min_y, max_y = line["bounds"]
clamped_x = np.clip(target_pos[0], min_x, max_x)
clamped_y = np.clip(target_pos[1], min_y, max_y)
valid_pos = np.array([clamped_x, clamped_y])
# 弱化折线合规性:允许偏离折线,但不超出局部边界
is_line_valid = True
# 冲突检测与避障
is_dist_valid = True
rect_diag = 0.0005 # 调整为适合原始形状的对角线长度
safe_dist = self.min_agent_dist + rect_diag
for other_idx in range(self.num_agents):
if other_idx == agent_idx:
continue
dist = np.linalg.norm(valid_pos - self.agents_pos[other_idx])
if dist < safe_dist:
is_dist_valid = False
conflict_dir = valid_pos - self.agents_pos[other_idx]
# 优先沿折线方向避障
line_dir = line_vec
perp_dir1 = np.array([-line_vec[1], line_vec[0]])
perp_dir2 = np.array([line_vec[1], -line_vec[0]])
forward_pos = valid_pos + line_dir * (safe_dist - dist + 1e-8)
backward_pos = valid_pos - line_dir * (safe_dist - dist + 1e-8)
if np.linalg.norm(forward_pos - self.agents_pos[other_idx]) >= safe_dist:
adjust_dir = line_dir
elif np.linalg.norm(backward_pos - self.agents_pos[other_idx]) >= safe_dist:
adjust_dir = -line_dir
else:
adjust_dir = perp_dir1 if np.dot(conflict_dir, perp_dir1) > np.dot(conflict_dir,
perp_dir2) else perp_dir2
adjust_dist = safe_dist - dist + 1e-8
valid_pos += adjust_dir * adjust_dist
# 重新裁剪到局部边界
valid_pos[0] = np.clip(valid_pos[0], min_x, max_x)
valid_pos[1] = np.clip(valid_pos[1], min_y, max_y)
# 停留计数
if action == 2:
self.stay_count[agent_idx] += 1
else:
self.stay_count[agent_idx] = 0
        is_boundary_valid = (min_x <= target_pos[0] <= max_x) and (min_y <= target_pos[1] <= max_y)
        return valid_pos, is_boundary_valid, is_dist_valid, is_line_valid
def step(self, actions, step, max_steps_per_episode):
rewards = np.zeros(self.num_agents)
action_mapping = {0: 0, 1: 1, 2: 2}
scale_factor = 0.25
total_global_reward = 0.0
# 1. 统一的冲突检测
current_conflicts, conflict_pairs, conflict_agents = self.detect_all_conflicts()
repeat_conflict = sum(1 for pair in conflict_pairs if pair in self.recent_conflict_pairs)
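        # NOTE: recent_conflict_pairs is initialized/reset but never appended to, so
        # repeat_conflict is currently always 0 and does not feed into the reward below
        # (the conflict-memory hook is left for future use).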
# 2. 智能体位置更新与完整状态统计
boundary_violation_count = 0
line_violation_count = 0
boundary_violation_degree = 0.0
line_violation_degree = 0.0
for i in range(self.num_agents):
action = action_mapping[actions[i]]
valid_pos, is_boundary_valid, is_dist_valid, is_line_valid = self.validate_position(
i, self.agents_pos[i], action)
# 完整的边界违规统计
boundary_viol_dist = self.check_boundary_violation(i, valid_pos)
if boundary_viol_dist > 0:
boundary_violation_count += 1
boundary_violation_degree += boundary_viol_dist
self.boundary_violations[i] += 1
# 完整的折线违规统计
polyline = self.agent_lines[i]["original_points"]
min_dist_to_line = float('inf')
for j in range(len(polyline) - 1):
dist = point_to_line_distance(valid_pos, polyline[j], polyline[j + 1])
if dist < min_dist_to_line:
min_dist_to_line = dist
if min_dist_to_line > self.line_tolerance:
line_violation_count += 1
line_violation_degree += (min_dist_to_line - self.line_tolerance)
self.line_violations[i] += 1
self.agents_pos[i] = valid_pos
# 3. 检测新状态下的冲突
new_conflicts, new_conflict_pairs, new_conflict_agents = self.detect_all_conflicts()
conflict_diff = new_conflicts - current_conflicts
# 4. 平衡的奖惩计算
total_reward = 0.0
# 4.1 冲突奖惩(权重:40%)
if new_conflicts == 0:
global_conflict_reward = 6.0 * scale_factor
elif conflict_diff < 0:
global_conflict_reward = 4.0 * scale_factor * (-conflict_diff)
else:
global_conflict_reward = -5.0 * scale_factor * max(0, conflict_diff)
total_global_reward += global_conflict_reward * 0.40
        # 4.2 全局边界奖励
if boundary_violation_count == 0:
global_boundary_reward = 4.0 * scale_factor
else:
avg_viol_degree = boundary_violation_degree / max(1, boundary_violation_count)
global_boundary_reward = -3.0 * scale_factor * boundary_violation_count * (1 + avg_viol_degree)
total_global_reward += global_boundary_reward * 0.35
        # 4.3 全局折线奖励
if line_violation_count == 0:
global_line_reward = 2.0 * scale_factor
else:
avg_line_degree = line_violation_degree / max(1, line_violation_count)
global_line_reward = -1.5 * scale_factor * line_violation_count * (1 + avg_line_degree)
total_global_reward += global_line_reward * 0.30
        # 4.4 个体奖励调整
individual_adjustments = np.zeros(self.num_agents)
for i in range(self.num_agents):
# 冲突个体额外惩罚
if i in new_conflict_agents:
individual_adjustments[i] -= 2.0 * scale_factor
# 边界合规个体额外奖励
boundary_viol_dist = self.check_boundary_violation(i, self.agents_pos[i])
if boundary_viol_dist == 0:
individual_adjustments[i] += 1.0 * scale_factor
# 有效移动奖励
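            # note: previous_positions is only set in reset(), so move_dist measures displacement
            # from the episode start rather than from the previous step.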
move_dist = np.linalg.norm(self.agents_pos[i] - self.previous_positions[i])
if actions[i] != 2 and move_dist > self.slide_step * 0.5:
individual_adjustments[i] += 0.5 * scale_factor
# 5. 最终奖励分配
base_global_reward = total_global_reward / self.num_agents
for i in range(self.num_agents):
rewards[i] = base_global_reward + individual_adjustments[i]
# 判断是否结束
done = (step + 1) >= max_steps_per_episode
# 终局奖励
if done:
milestone_bonus = 0.0
if new_conflicts == 0:
milestone_bonus += 4.0 * scale_factor
if boundary_violation_count == 0:
milestone_bonus += 3.0 * scale_factor
if line_violation_count == 0:
milestone_bonus += 1.0 * scale_factor
# 平均分配给所有智能体
milestone_bonus /= self.num_agents
rewards += milestone_bonus
return self.get_state(), rewards, done, {
"conflicts": new_conflicts,
"boundary_violations": boundary_violation_count,
"line_violations": line_violation_count
}
# Network models
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, num_agents, hidden_sizes=(128, 64, 32)):
super().__init__()
self.num_agents = num_agents
self.action_dim = action_dim
backbone_layers = []
in_features = state_dim
for hidden_size in hidden_sizes:
backbone_layers.append(nn.Linear(in_features, hidden_size))
backbone_layers.append(nn.LayerNorm(hidden_size))
backbone_layers.append(nn.LeakyReLU(0.1))
in_features = hidden_size
self.backbone = nn.Sequential(*backbone_layers)
self.agent_head = nn.Linear(in_features, num_agents * action_dim)
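        # A single shared backbone feeds one linear head with num_agents * action_dim logits;
        # forward() reshapes them to (batch, num_agents, action_dim) and applies a per-agent softmax.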
def forward(self, x):
features = self.backbone(x)
logits = self.agent_head(features)
probs = torch.softmax(logits.view(-1, self.num_agents, self.action_dim), dim=-1)
return probs
class ValueNetwork(nn.Module):
def __init__(self, state_dim, hidden_sizes=(128, 64, 32)):
super().__init__()
layers = [
nn.Linear(state_dim, hidden_sizes[0]),
nn.LayerNorm(hidden_sizes[0]),
nn.LeakyReLU(0.1),
nn.Dropout(0.1)
]
for i in range(1, len(hidden_sizes)):
layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
layers.append(nn.LayerNorm(hidden_sizes[i]))
layers.append(nn.LeakyReLU(0.1))
layers.append(nn.Linear(hidden_sizes[-1], 1))
self.net = nn.Sequential(*layers)
nn.init.uniform_(self.net[-1].weight, -0.001, 0.001)
nn.init.constant_(self.net[-1].bias, 0.0)
def forward(self, x):
return self.net(x)
# PPO algorithm
class PPO:
def __init__(self, state_dim, action_dim, num_agents,
lr=3.5e-5, gamma=0.99, gae_lambda=0.90, epsilon=0.15,
epochs=4, batch_size=64, ent_coef=0.4):
self.policy_net = PolicyNetwork(state_dim, action_dim, num_agents)
self.value_net = ValueNetwork(state_dim)
self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr, eps=1e-8)
self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr * 0.5, eps=1e-8)
self.num_agents = num_agents
self.action_dim = action_dim
self.gamma = gamma
self.gae_lambda = gae_lambda
self.epsilon = epsilon
self.epochs = epochs
self.batch_size = batch_size
self.ent_coef = ent_coef # 熵系数(训练中衰减)
self.memory = []
def select_action(self, state):
state_tensor = torch.FloatTensor(state).unsqueeze(0)
probs = self.policy_net(state_tensor)
actions = []
log_probs_list = []
for i in range(self.num_agents):
agent_probs = probs[0, i, :]
dist = Categorical(agent_probs)
action = dist.sample()
actions.append(action.item())
log_probs_list.append(dist.log_prob(action))
log_probs_tensor = torch.stack(log_probs_list).detach()
return actions, log_probs_tensor
def store_transition(self, transition):
self.memory.append(transition)
def update(self):
n_samples = len(self.memory)
if n_samples == 0:
return 0.0, 0.0
states = torch.FloatTensor(np.array([t[0] for t in self.memory]))
actions = torch.LongTensor(np.array([t[1] for t in self.memory]))
old_log_probs = torch.stack([t[2] for t in self.memory])
rewards = torch.FloatTensor(np.array([t[3] for t in self.memory]))
next_states = torch.FloatTensor(np.array([t[4] for t in self.memory]))
dones = torch.FloatTensor(np.array([t[5] for t in self.memory])).view(-1, 1)
global_rewards = rewards.mean(dim=1, keepdim=True)
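        # GAE: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t), then
        # A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, accumulated backwards.
        # Per-agent rewards are averaged into one global reward for the centralized critic.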
with torch.no_grad():
values = self.value_net(states)
next_values = self.value_net(next_states)
deltas = global_rewards + self.gamma * next_values * (1 - dones) - values
advantages = torch.zeros_like(deltas)
advantage = 0.0
for t in reversed(range(len(deltas))):
advantage = deltas[t] + self.gamma * self.gae_lambda * advantage * (1 - dones[t])
advantages[t] = advantage
returns = advantages + values
returns = (returns[:-2] + returns[1:-1] + returns[2:]) / 3 # 3步滑动平均
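        # The smoothing shortens returns by 2 entries relative to advantages; batch indices
        # beyond len(returns) - 1 are filtered out below before each minibatch update.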
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
total_policy_loss = 0.0
total_value_loss = 0.0
update_count = 0
for _ in range(self.epochs):
permutation = torch.randperm(n_samples)
num_batches = max(1, (n_samples + self.batch_size - 1) // self.batch_size)
permutation = permutation.repeat(num_batches)[:num_batches * self.batch_size]
for start in range(0, len(permutation), self.batch_size):
end = start + self.batch_size
batch_indices = permutation[start:end]
if len(batch_indices) == 0:
continue
max_return_idx = len(returns) - 1
valid_indices = batch_indices[batch_indices <= max_return_idx]
if len(valid_indices) == 0:
continue
b_states = states[valid_indices]
b_actions = actions[valid_indices]
b_old_log_probs = old_log_probs[valid_indices]
b_advantages = advantages[valid_indices]
b_returns = returns[valid_indices]
# 策略网络更新
self.policy_optimizer.zero_grad()
b_new_probs = self.policy_net(b_states)
b_new_log_probs_list = []
b_entropies_list = []
for i in range(self.num_agents):
dist_i = Categorical(b_new_probs[:, i, :])
log_prob_i = dist_i.log_prob(b_actions[:, i])
entropy_i = dist_i.entropy()
b_new_log_probs_list.append(log_prob_i.unsqueeze(1))
b_entropies_list.append(entropy_i.unsqueeze(1))
b_new_log_probs = torch.cat(b_new_log_probs_list, dim=1)
b_total_entropy = torch.cat(b_entropies_list, dim=1).mean(dim=1, keepdim=True)
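                # PPO clipped surrogate: ratio = exp(log pi_new - log pi_old) per agent;
                # the objective takes min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A),
                # and the entropy bonus (ent_coef) discourages premature convergence.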
ratio = torch.exp(b_new_log_probs - b_old_log_probs)
surr1 = ratio * b_advantages
surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * b_advantages
ppo_loss = -torch.min(surr1, surr2).mean()
entropy_loss = -self.ent_coef * b_total_entropy.mean()
policy_loss = ppo_loss + entropy_loss
policy_loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=0.5)
self.policy_optimizer.step()
# 价值网络更新
self.value_optimizer.zero_grad()
value_preds = self.value_net(b_states)
huber_loss = nn.HuberLoss(delta=0.3)
value_loss = huber_loss(value_preds, b_returns.detach())
l2_lambda = 2e-5
l2_reg = torch.tensor(0., requires_grad=True)
for param in self.value_net.parameters():
l2_reg = l2_reg + torch.norm(param, p=2) ** 2
value_loss = value_loss + l2_lambda * l2_reg
value_loss.backward()
torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), max_norm=0.5)
self.value_optimizer.step()
total_policy_loss += policy_loss.item()
total_value_loss += value_loss.item()
update_count += 1
self.memory.clear()
avg_policy_loss = total_policy_loss / update_count if update_count > 0 else 0.0
avg_value_loss = total_value_loss / update_count if update_count > 0 else 0.0
return avg_policy_loss, avg_value_loss
# Training loop
def train_ppo(env, agent, episodes=800, max_steps_per_episode=60, update_every=6):
reward_history = []
policy_loss_history = []
value_loss_history = []
conflict_history = []
boundary_violation_history = []
line_violation_history = []
initial_positions = env.get_initial_positions()
final_positions = None
agent_shapes = env.get_agent_shapes() # 获取智能体原始形状
plt.ion()
fig, ax = plt.subplots(figsize=(8, 6))
render_interval = 10
policy_lr_scheduler = CosineAnnealingLR(agent.policy_optimizer, T_max=600, eta_min=8e-7)
value_lr_scheduler = CosineAnnealingLR(agent.value_optimizer, T_max=600, eta_min=4e-7)
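    # Assumption: T_max=600 is tuned for ~600-episode runs; with more episodes (e.g. 1200 in main()),
    # the cosine schedule will bring the learning rate back up after episode 600.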
for episode in range(episodes):
state = env.reset(episode)
done = False
total_reward = 0.0
step = 0
episode_conflicts = []
episode_boundary_violations = []
episode_line_violations = []
agent.memory.clear()
# 熵系数衰减
max_ent = 0.15
min_ent = 0.01
decay_rate = (max_ent - min_ent) / 400
agent.ent_coef = max(min_ent, max_ent - episode * decay_rate)
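        # This per-episode schedule overrides whatever ent_coef was passed to the PPO constructor.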
while not done and step < max_steps_per_episode:
actions, log_probs = agent.select_action(state)
next_state, rewards, done, info = env.step(actions, step, max_steps_per_episode)
total_reward += np.sum(rewards)
episode_conflicts.append(info['conflicts'])
episode_boundary_violations.append(info['boundary_violations'])
episode_line_violations.append(info['line_violations'])
agent.store_transition((state, actions, log_probs, rewards, next_state, float(done)))
if step % render_interval == 0 or done:
ax.clear()
agents_pos = np.array(env.get_agents_pos())
agent_lines = env.agent_lines
# 绘制折线(原始折线+局部运动线段)
for i, line in enumerate(agent_lines):
# 绘制原始折线(灰色细线条)
polyline = line["original_points"]
ax.plot(polyline[:, 0], polyline[:, 1],
'-', color="gray", linewidth=1, alpha=0.5,
label='原始折线' if i == 0 else "")
# 绘制局部运动线段(黑色粗线条)
local_poly = line["local_points"]
ax.plot(local_poly[:, 0], local_poly[:, 1],
'-', color="black", linewidth=2, alpha=0.8,
label='局部运动轨道' if i == 0 else "")
# 绘制智能体(使用原始形状)
for i, (lon, lat) in enumerate(agents_pos):
# 获取智能体原始形状的相对坐标
shape_coords = agent_shapes[i]
# 计算中心到各顶点的偏移量
offsets = shape_coords - initial_positions[i]
# 根据当前位置调整形状坐标
current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets])
# 绘制原始形状
shape_patch = plt.Polygon(
current_shape,
facecolor='none',
edgecolor='black',
linewidth=2,
alpha=0.7
)
ax.add_patch(shape_patch)
ax.text(lon, lat, str(i + 1), ha='center', va='center', fontweight='bold')
# 标记冲突(红色边框)
current_conflict = info['conflicts']
if current_conflict > 0:
conflict_pairs = []
rect_diag = 0.0013 # 适合原始形状的对角线长度
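                    # note: this visual conflict radius (0.0013) is larger than the 0.0005 used in
                    # GeoEnv.detect_all_conflicts, so the plot may flag more pairs than the env counts.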
for i in range(env.num_agents):
for j in range(i + 1, env.num_agents):
if np.linalg.norm(agents_pos[i] - agents_pos[j]) < (env.min_agent_dist + rect_diag):
conflict_pairs.append(i)
conflict_pairs.append(j)
for idx in set(conflict_pairs):
lon, lat = agents_pos[idx]
shape_coords = agent_shapes[idx]
offsets = shape_coords - initial_positions[idx]
current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets])
conflict_patch = plt.Polygon(
current_shape,
facecolor='none',
edgecolor='red',
linewidth=2,
alpha=0.7
)
ax.add_patch(conflict_patch)
# 标记线违规(橙色边框)
current_line_viol = info['line_violations']
if current_line_viol > 0:
for i in range(env.num_agents):
line = agent_lines[i]
polyline = line["original_points"]
min_dist_to_line = float('inf')
for j in range(len(polyline) - 1):
p1 = polyline[j]
p2 = polyline[j + 1]
dist = point_to_line_distance(agents_pos[i], p1, p2)
if dist < min_dist_to_line:
min_dist_to_line = dist
if min_dist_to_line > env.line_tolerance:
lon, lat = agents_pos[i]
shape_coords = agent_shapes[i]
offsets = shape_coords - initial_positions[i]
current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets])
line_viol_patch = plt.Polygon(
current_shape,
facecolor='none',
edgecolor='orange',
linewidth=2,
alpha=0.7
)
ax.add_patch(line_viol_patch)
ax.set_xlim(env.grid_left, env.grid_right)
ax.set_ylim(env.grid_bottom, env.grid_top)
ax.set_title(
f"回合: {episode + 1}, 步数: {step + 1}, 冲突: {current_conflict}, "
f"边界违规: {info['boundary_violations']}, 线违规: {current_line_viol}"
)
ax.set_xlabel("经度")
ax.set_ylabel("纬度")
ax.legend()
plt.pause(0.00001)
# 定期更新模型
if (step + 1) % update_every == 0 or done:
policy_loss, value_loss = agent.update()
policy_loss_history.append(policy_loss)
value_loss_history.append(value_loss)
state = next_state
step += 1
policy_lr_scheduler.step()
value_lr_scheduler.step()
reward_history.append(total_reward)
conflict_history.append(np.mean(episode_conflicts) if episode_conflicts else 0)
boundary_violation_history.append(np.mean(episode_boundary_violations) if episode_boundary_violations else 0)
line_violation_history.append(np.mean(episode_line_violations) if episode_line_violations else 0)
if episode % 10 == 0:
latest_policy_loss = policy_loss_history[-1] if policy_loss_history else 0.0
latest_value_loss = value_loss_history[-1] if value_loss_history else 0.0
current_lr = agent.policy_optimizer.param_groups[0]['lr']
print(f"回合 {episode:4d}, 总奖励: {total_reward:6.2f}, "
f"冲突: {np.mean(episode_conflicts):.2f}, 边界违规: {np.mean(episode_boundary_violations):.2f}, "
f"线违规: {np.mean(episode_line_violations):.2f}, "
f"策略损失: {latest_policy_loss:.4f}, 价值损失: {latest_value_loss:.4f}, 学习率: {current_lr:.1e}")
final_positions = env.get_agents_pos()
plt.ioff()
plt.close()
torch.save(agent.policy_net.state_dict(), "rect_ppo_policy_polyline_optimized.pth")
torch.save(agent.value_net.state_dict(), "rect_ppo_value_polyline_optimized.pth")
print("\n训练完成!优化后的模型已保存")
return (reward_history, policy_loss_history, value_loss_history,
conflict_history, boundary_violation_history, line_violation_history,
initial_positions, final_positions, env.agent_lines, agent_shapes)
# Initial/final position visualization (saved as separate figures)
def plot_initial_positions(env, initial_pos, agent_lines, agent_shapes):
"""绘制并保存初始位置图"""
rect_diag = 0.0013 # 适合原始形状的对角线长度
plt.figure(figsize=(10, 8))
ax = plt.gca()
for i, line in enumerate(agent_lines):
polyline = line["original_points"]
ax.plot(polyline[:, 0], polyline[:, 1], '-', color="gray", linewidth=1, alpha=0.5)
local_poly = line["local_points"]
ax.plot(local_poly[:, 0], local_poly[:, 1], '-', color="black", linewidth=2, alpha=0.8,
label='局部运动轨道' if i == 0 else "")
# 绘制智能体原始形状
for i, (lon, lat) in enumerate(initial_pos):
shape_coords = agent_shapes[i]
offsets = shape_coords - initial_pos[i]
current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets])
shape_patch = plt.Polygon(
current_shape,
facecolor='none',
edgecolor='black',
linewidth=2,
alpha=0.7
)
ax.add_patch(shape_patch)
ax.text(lon, lat, str(i + 1), ha='center', va='center', fontweight='bold', fontsize=10)
# 标记冲突智能体
initial_conflict_idx = set()
for i in range(env.num_agents):
for j in range(i + 1, env.num_agents):
if np.linalg.norm(initial_pos[i] - initial_pos[j]) < (env.min_agent_dist + rect_diag):
initial_conflict_idx.add(i)
initial_conflict_idx.add(j)
for idx in initial_conflict_idx:
lon, lat = initial_pos[idx]
shape_coords = agent_shapes[idx]
offsets = shape_coords - initial_pos[idx]
current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets])
conflict_patch = plt.Polygon(
current_shape,
facecolor='none',
edgecolor='red',
linewidth=2,
alpha=0.7
)
ax.add_patch(conflict_patch)
ax.set_xlim(env.grid_left, env.grid_right)
ax.set_ylim(env.grid_bottom, env.grid_top)
ax.set_title(f"初始位置分布(冲突智能体数:{len(initial_conflict_idx)})")
ax.set_xlabel("经度")
ax.set_ylabel("纬度")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig("initial_positions_polyline_optimized.png", dpi=300)
plt.close()
print("初始位置图已保存为 'initial_positions_polyline_optimized.png'")
def plot_final_positions(env, final_pos, agent_lines, agent_shapes):
"""绘制并保存最终位置图"""
rect_diag = 0.0013 # 适合原始形状的对角线长度
plt.figure(figsize=(10, 8))
ax = plt.gca()
for i, line in enumerate(agent_lines):
polyline = line["original_points"]
ax.plot(polyline[:, 0], polyline[:, 1], '-', color="gray", linewidth=1, alpha=0.5)
local_poly = line["local_points"]
ax.plot(local_poly[:, 0], local_poly[:, 1], '-', color="black", linewidth=2, alpha=0.8,
label='局部运动轨道' if i == 0 else "")
# 绘制智能体原始形状
for i, (lon, lat) in enumerate(final_pos):
shape_coords = agent_shapes[i]
offsets = shape_coords - env.get_initial_positions()[i] # 基于初始位置计算偏移
current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets])
shape_patch = plt.Polygon(
current_shape,
facecolor='none',
edgecolor='black',
linewidth=2,
alpha=0.7
)
ax.add_patch(shape_patch)
ax.text(lon, lat, str(i + 1), ha='center', va='center', fontweight='bold', fontsize=10)
# 标记冲突智能体
final_conflict_idx = set()
for i in range(env.num_agents):
for j in range(i + 1, env.num_agents):
if np.linalg.norm(final_pos[i] - final_pos[j]) < (env.min_agent_dist + rect_diag):
final_conflict_idx.add(i)
final_conflict_idx.add(j)
for idx in final_conflict_idx:
lon, lat = final_pos[idx]
shape_coords = agent_shapes[idx]
offsets = shape_coords - env.get_initial_positions()[idx]
current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets])
conflict_patch = plt.Polygon(
current_shape,
facecolor='none',
edgecolor='red',
linewidth=2,
alpha=0.7
)
ax.add_patch(conflict_patch)
# 标记线违规智能体
line_viol_idx = set()
for i in range(env.num_agents):
line = agent_lines[i]
polyline = line["original_points"]
min_dist_to_line = float('inf')
for j in range(len(polyline) - 1):
p1 = polyline[j]
p2 = polyline[j + 1]
dist = point_to_line_distance(final_pos[i], p1, p2)
if dist < min_dist_to_line:
min_dist_to_line = dist
if min_dist_to_line > env.line_tolerance:
line_viol_idx.add(i)
for idx in line_viol_idx:
lon, lat = final_pos[idx]
shape_coords = agent_shapes[idx]
offsets = shape_coords - env.get_initial_positions()[idx]
current_shape = np.array([[lon + dx, lat + dy] for dx, dy in offsets])
line_viol_patch = plt.Polygon(
current_shape,
facecolor='none',
edgecolor='orange',
linewidth=2,
alpha=0.7
)
ax.add_patch(line_viol_patch)
ax.set_xlim(env.grid_left, env.grid_right)
ax.set_ylim(env.grid_bottom, env.grid_top)
ax.set_title(f"最终位置分布(冲突:{len(final_conflict_idx)},线违规:{len(line_viol_idx)})")
ax.set_xlabel("经度")
ax.set_ylabel("纬度")
ax.legend()
ax.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig("final_positions_polyline_optimized.png", dpi=300)
plt.close()
print("最终位置图已保存为 'final_positions_polyline_optimized.png'")
# Result plotting
def plot_results(reward_history, policy_loss, value_loss, conflict_history, boundary_violation_history,
line_violation_history):
plt.figure(figsize=(10, 4))
plt.plot(reward_history, alpha=0.5, label="单回合奖励")
if len(reward_history) >= 10:
moving_avg = np.convolve(reward_history, np.ones(10) / 10, mode='valid')
plt.plot(range(9, len(reward_history)), moving_avg, label="10回合平均")
plt.title("奖励曲线(优化后)")
plt.xlabel("回合")
plt.ylabel("总奖励")
plt.legend()
plt.grid()
plt.savefig("reward_polyline_optimized.png", dpi=300)
plt.close()
plt.figure(figsize=(10, 4))
plt.plot(conflict_history, label="平均冲突数", color='blue')
plt.plot(boundary_violation_history, label="平均边界违规数", color='orange')
plt.plot(line_violation_history, label="平均线违规数", color='red')
plt.axhline(y=0, color='black', linestyle='--')
plt.title("冲突与违规曲线(优化后)")
plt.xlabel("回合")
plt.ylabel("数量")
plt.legend()
plt.grid()
plt.savefig("conflict_violation_polyline_optimized.png", dpi=300)
plt.close()
plt.figure(figsize=(10, 4))
if len(policy_loss) >= 10:
policy_smoothed = np.convolve(policy_loss, np.ones(10) / 10, mode='valid')
plt.plot(range(9, len(policy_loss)), policy_smoothed, label="策略损失(平滑)", color='green')
plt.plot(policy_loss, alpha=0.3, color='green')
plt.title("策略损失曲线(优化后)")
plt.xlabel("更新步数")
plt.ylabel("损失值")
plt.legend()
plt.grid()
plt.savefig("policy_loss_polyline_optimized.png", dpi=300)
plt.close()
plt.figure(figsize=(10, 4))
if len(value_loss) >= 10:
value_smoothed = np.convolve(value_loss, np.ones(10) / 10, mode='valid')
plt.plot(range(9, len(value_loss)), value_smoothed, label="价值损失(平滑)", color='purple')
plt.plot(value_loss, alpha=0.3, color='purple')
plt.title("价值损失曲线(优化后)")
plt.xlabel("更新步数")
plt.ylabel("损失值")
plt.legend()
plt.grid()
plt.savefig("value_loss_polyline_optimized.png", dpi=300)
plt.close()
plt.figure(figsize=(12, 6))
if len(reward_history) > 0 and len(conflict_history) > 0 and len(value_loss) > 0:
norm_reward = (np.array(reward_history) - np.min(reward_history)) / (
np.max(reward_history) - np.min(reward_history) + 1e-8)
norm_conflict = 1 - (np.array(conflict_history) - np.min(conflict_history)) / (
np.max(conflict_history) - np.min(conflict_history) + 1e-8)
norm_line_viol = 1 - (np.array(line_violation_history) - np.min(line_violation_history)) / (
np.max(line_violation_history) - np.min(line_violation_history) + 1e-8)
trunc_value_loss = value_loss[:len(reward_history)]
norm_value_loss = (np.array(trunc_value_loss) - np.min(trunc_value_loss)) / (
np.max(trunc_value_loss) - np.min(trunc_value_loss) + 1e-8)
plt.plot(norm_reward, label="归一化总奖励", color='blue')
plt.plot(norm_conflict, label="归一化冲突值(反向)", color='green')
plt.plot(norm_line_viol, label="归一化线违规值(反向)", color='orange')
plt.plot(norm_value_loss, label="归一化价值损失", color='red', alpha=0.7)
plt.title("总奖励-冲突-线违规-价值损失联动图(优化后)")
plt.xlabel("回合")
plt.ylabel("归一化值")
plt.legend()
plt.grid()
plt.savefig("correlation_polyline_optimized.png", dpi=300)
plt.close()
# Main entry point
def main():
# 智能体矩形坐标(每组4个顶点,定义智能体原始形状)
your_coordinates = [
[121.44042, 31.323465], [121.440783, 31.323465], [121.440783, 31.323964], [121.44042, 31.323964],
[121.439167, 31.31262], [121.440305, 31.31262], [121.440305, 31.313174], [121.439167, 31.313174],
[121.45059, 31.311141], [121.451727, 31.311141], [121.451727, 31.311694], [121.45059, 31.311694],
[121.442881, 31.31078], [121.44389, 31.31078], [121.44389, 31.311334], [121.442881, 31.311334],
[121.443881, 31.312954], [121.445019, 31.312954], [121.445019, 31.313508], [121.443881, 31.313508],
[121.446896, 31.311852], [121.448033, 31.311852], [121.448033, 31.312406], [121.446896, 31.312406],
[121.444236, 31.31119], [121.445245, 31.31119], [121.445245, 31.311744], [121.444236, 31.311744],
[121.441675, 31.316022], [121.442684, 31.316022], [121.442684, 31.316576], [121.441675, 31.316576],
[121.442911, 31.312575], [121.44392, 31.312575], [121.44392, 31.313129], [121.442911, 31.313129],
[121.44394, 31.315784], [121.444949, 31.315784], [121.444949, 31.316338], [121.44394, 31.316338],
[121.451557, 31.313325], [121.452666, 31.313325], [121.452666, 31.313868], [121.451557, 31.313868],
[121.452448, 31.315506], [121.453935, 31.315506], [121.453935, 31.316004], [121.452448, 31.316004],
[121.447553, 31.315458], [121.44904, 31.315458], [121.44904, 31.316001], [121.447553, 31.316001],
[121.450557, 31.313825], [121.451666, 31.313825], [121.451666, 31.314368], [121.450557, 31.314368],
[121.440305, 31.322271], [121.441057, 31.322271], [121.441057, 31.322825], [121.440305, 31.322825],
[121.44275, 31.313706], [121.443759, 31.313706], [121.443759, 31.314259], [121.44275, 31.314259],
[121.44653, 31.315487], [121.448056, 31.315487], [121.448056, 31.316041], [121.44653, 31.316041],
[121.45081, 31.316222], [121.451948, 31.316222], [121.451948, 31.316776], [121.45081, 31.316776],
[121.45282, 31.31282], [121.453958, 31.31282], [121.453958, 31.313374], [121.45282, 31.313374]
]
# 多坐标点折线
your_custom_polylines = [
[
[121.444092, 31.310223], [121.443857, 31.31533], [121.443778, 31.316861],
[121.443699, 31.317496], [121.443428, 31.31852], [121.443264, 31.319012],
[121.442922, 31.319814], [121.441951, 31.322021], [121.437934, 31.33081],
[121.436298, 31.334379], [121.434728, 31.337951], [121.43302, 31.342159]
],
[
[121.450669, 31.315322], [121.451167, 31.315325], [121.452653, 31.315332],
[121.453431, 31.315334], [121.453678, 31.315331], [121.45395, 31.315338],
[121.454566, 31.315341], [121.455416, 31.315344]
],
[
[121.447862, 31.307364], [121.447802, 31.310141], [121.447802, 31.312613],
[121.447862, 31.315308], [121.447814, 31.31773], [121.447802, 31.318823],
[121.447905, 31.319138], [121.448575, 31.319629], [121.451242, 31.320819]
],
[[121.450669, 31.312621], [121.450666, 31.314346], [121.450669, 31.315322]],
[
[121.432436, 31.342711], [121.43314, 31.34107], [121.434084, 31.338915],
[121.436022, 31.334623], [121.437832, 31.330747], [121.440288, 31.325478],
[121.441313, 31.323083], [121.441547, 31.322548], [121.442606, 31.320254],
[121.443153, 31.318998], [121.443348, 31.31841], [121.44357, 31.317523],
[121.44365, 31.316862], [121.443714, 31.315821], [121.443842, 31.312865],
[121.443951, 31.310217], [121.4439, 31.308523], [121.44387, 31.307221]
],
[[121.443661, 31.315258], [121.443701, 31.314217], [121.443712, 31.31393], [121.443765, 31.312525]],
[
[121.444031, 31.31526], [121.445081, 31.315273], [121.445704, 31.315281],
[121.447862, 31.315308], [121.448595, 31.315312], [121.44911, 31.315314],
[121.450669, 31.315322]
],
[[121.450669, 31.315322], [121.450918, 31.316447], [121.450993, 31.316816], [121.451116, 31.317412]],
[
[121.455355, 31.312634], [121.454421, 31.312631], [121.453857, 31.31263],
[121.453396, 31.312628], [121.452636, 31.312626], [121.450669, 31.312621]
],
[
[121.443765, 31.312525], [121.442313, 31.312481], [121.441637, 31.312458],
[121.440438, 31.312447], [121.439588, 31.312426], [121.439197, 31.312417],
[121.438451, 31.3124], [121.43793, 31.312394], [121.437464, 31.312385],
[121.436576, 31.312355]
],
[[121.450738, 31.310214], [121.450669, 31.312621]],
[[121.443765, 31.312525], [121.443853, 31.310818], [121.443904, 31.309829]],
[[121.444225, 31.312532], [121.450669, 31.312621]],
[
[121.450669, 31.312621], [121.449561, 31.312623], [121.448839, 31.312611],
[121.447802, 31.312613], [121.447433, 31.312612], [121.445848, 31.312568],
[121.444225, 31.312532]
],
[[121.444302, 31.309839], [121.444348, 31.311538], [121.444266, 31.31196], [121.444225, 31.312532]],
[[121.443661, 31.315258], [121.444031, 31.31526]],
[[121.444225, 31.312532], [121.444142, 31.31276], [121.444051, 31.314857], [121.444031, 31.31526]],
[[121.444031, 31.31526], [121.443948, 31.316862]],
[[121.443585, 31.316781], [121.443658, 31.315359], [121.443661, 31.315258]]
]
# 坐标分组
try:
grouped_coords = group_coordinates_into_rectangles(your_coordinates, group_size=4)
print(f"坐标分组完成:共{len(grouped_coords)}个矩形 → 对应{len(grouped_coords)}个智能体")
except ValueError as e:
print(f"坐标错误:{e}")
return
# 创建环境
env = GeoEnv(
grouped_initial_positions=grouped_coords,
global_step=2e-5, # 滑动步长
min_agent_distance=0.0005, # 安全距离
max_agent_distance=0.0001,
custom_polylines=your_custom_polylines
)
# 可视化智能体-折线分配关系
agent_lines = env.agent_lines
visualize_agent_line_assignment(env.initial_centers, agent_lines, your_custom_polylines)
# 初始化PPO
state_dim = env.state_dim
action_dim = 3 # 0=向前,1=向后,2=不动
agent = PPO(
state_dim, action_dim, num_agents=env.num_agents,
lr=3e-6, gamma=0.93, epochs=10, batch_size=512, ent_coef=0.15,
epsilon=0.08, gae_lambda=0.80
)
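    # With update_every=30 in the train_ppo call below, each update sees at most ~30 transitions,
    # so batch_size=512 effectively means full-batch PPO updates on that small buffer.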
# 开始训练
print("\n开始训练(优化版:保留原始位置,折线旁移动)...")
(reward_history, policy_loss, value_loss,
conflict_history, boundary_violation_history, line_violation_history,
initial_pos, final_pos, agent_lines, agent_shapes) = train_ppo(
env, agent, episodes=1200, max_steps_per_episode=100, update_every=30
)
# 绘制结果
plot_results(reward_history, policy_loss, value_loss, conflict_history, boundary_violation_history,
line_violation_history)
# 分开保存初始和最终位置图
plot_initial_positions(env, initial_pos, agent_lines, agent_shapes)
plot_final_positions(env, final_pos, agent_lines, agent_shapes)
if __name__ == "__main__":
main()
end_time = time.time()
print(f"总运行时间:{end_time - start_time:.2f}秒")