import copy
import json
import math
import os
import random
import time
from datetime import datetime

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import requests
from scipy.spatial.transform import Rotation as Rot
# ----------------------------
# Configuration
# ----------------------------
# Endpoints and credentials may be overridden through environment variables so
# secrets do not have to be edited into the source; the defaults keep the
# previously hard-coded values.
API_BASE_URL = os.environ.get(
    "ALG_API_BASE_URL", "http://62.234.16.239:18080/api/v1/alg/task"
)  # result submission API
AUTH_BASE_URL = os.environ.get(
    "ALG_AUTH_BASE_URL", "http://62.234.16.239:18080/api/v1/auth"
)  # base URL of the authentication API
ALG_NAME = "路径规划"  # algorithm identifier; only tasks with this algName are processed
# Login credentials
LOGIN_USERNAME = os.environ.get("ALG_LOGIN_USERNAME", "admin")
LOGIN_PASSWORD = os.environ.get("ALG_LOGIN_PASSWORD", "_!ucs@dev")
show_animation = False  # do not display the planning animation
# Global token storage (only the access token is kept)
ACCESS_TOKEN = ""
TOKEN_EXPIRE_TIME = 0  # token expiry as a Unix timestamp in seconds
# Delay between polling rounds (seconds)
TASK_CHECK_INTERVAL = 2
# Pagination
PAGE_SIZE = 10
BATCH_MODE = 0  # task "mode" value identifying batch tasks
class Node:
    """A vertex of the RRT search tree.

    Fields are plain public attributes: ``x``/``y`` is the position, ``cost``
    is the accumulated path cost (0.0 for a fresh node) and ``parent`` is the
    index of the parent node in the planner's node list (None for the root).
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.cost = 0.0  # accumulated cost from the tree root
        self.parent = None  # parent index in RRT.node_list; None = root
class RRT:
    """Sampling-based 2-D path planner: RRT, RRT* and Informed RRT*.

    Obstacles in ``obstacleList`` are either circles ``(ox, oy, radius)`` or
    axis-aligned rectangles ``(ox, oy, half_width, half_height)`` (any entry
    with more than three values is treated as a rectangle). Every obstacle is
    inflated by ``car_radius`` during collision checking. Random samples are
    drawn from ``[randArea[0], randArea[1]]`` on both axes.

    The search tree lives in ``self.node_list``; each node's ``parent`` is the
    integer index of its parent in that list (``None`` for the start node).
    All planners return a list of ``[x, y]`` waypoints ordered from the goal
    back to the start, or ``None`` if no path was found.
    """

    def __init__(self, obstacleList, randArea,
                 expandDis=2.0, goalSampleRate=10, maxIter=200, car_radius=0.25):
        """Configure the planner.

        Args:
            obstacleList: circle ``(ox, oy, r)`` / rectangle
                ``(ox, oy, half_w, half_h)`` obstacles.
            randArea: ``[min, max]`` sampling bounds, shared by both axes.
            expandDis: fixed step length of each tree extension.
            goalSampleRate: the goal itself is sampled whenever
                ``random.randint(0, 100) <= goalSampleRate``.
            maxIter: number of sampling iterations per planning call.
            car_radius: clearance added around every obstacle.
        """
        self.start = None
        self.goal = None
        self.min_rand = randArea[0]
        self.max_rand = randArea[1]
        self.expand_dis = expandDis
        self.goal_sample_rate = goalSampleRate
        self.max_iter = maxIter
        self.obstacle_list = obstacleList
        self.car_radius = car_radius
        self.node_list = None
        # The matplotlib figure/axes are created lazily on the first draw, so
        # planning with animation=True works without any prior setup.
        self.fig = None
        self.ax = None

    # ------------------------------------------------------------------
    # Planners
    # ------------------------------------------------------------------
    def rrt_planning(self, start, goal, animation=True):
        """Basic RRT: return the first collision-free connection to the goal.

        Args:
            start: ``[x, y]`` start position.
            goal: ``[x, y]`` goal position.
            animation: redraw the tree after every accepted node.

        Returns:
            Waypoints from goal to start, or ``None`` if ``max_iter``
            iterations pass without reaching the goal.
        """
        start_time = time.time()
        self.start = Node(start[0], start[1])
        self.goal = Node(goal[0], goal[1])
        self.node_list = [self.start]
        path = None
        for _ in range(self.max_iter):
            newNode = self._steer(self.sample())
            if newNode is None:
                continue
            self.node_list.append(newNode)
            if animation:
                self.draw_graph(newNode, path)
            if self._connects_to_goal(newNode):
                path = self.get_final_course(len(self.node_list) - 1)
                pathLen = self.get_path_len(path)
                print(f"当前路径长度: {pathLen}, 耗时: {time.time() - start_time} s")
                if animation:
                    self.draw_graph(newNode, path)
                return path
        return path

    def rrt_star_planning(self, start, goal, animation=True):
        """RRT*: keep growing for ``max_iter`` iterations, returning the shortest path seen.

        Returns:
            Waypoints from goal to start of the shortest path found, or ``None``.
        """
        start_time = time.time()
        self.start = Node(start[0], start[1])
        self.goal = Node(goal[0], goal[1])
        self.node_list = [self.start]
        path = None
        lastPathLength = float('inf')
        for _ in range(self.max_iter):
            newNode = self._steer(self.sample())
            if newNode is None:
                continue
            nearInds = self.find_near_nodes(newNode)
            newNode = self.choose_parent(newNode, nearInds)
            self.node_list.append(newNode)
            self.rewire(newNode, nearInds)
            if animation:
                self.draw_graph(newNode, path)
            if self._connects_to_goal(newNode):
                tempPath = self.get_final_course(len(self.node_list) - 1)
                tempPathLen = self.get_path_len(tempPath)
                if lastPathLength > tempPathLen:
                    path = tempPath
                    lastPathLength = tempPathLen
                    print(f"当前路径长度: {tempPathLen}, 耗时: {time.time() - start_time} s")
        return path

    def informed_rrt_star_planning(self, start, goal, animation=True):
        """Informed RRT*: once a path exists, sample only inside its ellipse.

        Returns:
            Waypoints from goal to start of the shortest path found, or ``None``.
        """
        start_time = time.time()
        self.start = Node(start[0], start[1])
        self.goal = Node(goal[0], goal[1])
        self.node_list = [self.start]
        cBest = float('inf')
        path = None
        dx = self.goal.x - self.start.x
        dy = self.goal.y - self.start.y
        cMin = math.hypot(dx, dy)
        xCenter = np.array([[(self.start.x + self.goal.x) / 2.0],
                            [(self.start.y + self.goal.y) / 2.0],
                            [0]])
        # Heading of the start->goal axis. atan2 on plain floats needs no
        # normalisation by cMin, so start == goal (cMin == 0) is also safe.
        e_theta = math.atan2(dy, dx)
        C = np.array([[math.cos(e_theta), -math.sin(e_theta), 0],
                      [math.sin(e_theta), math.cos(e_theta), 0],
                      [0, 0, 1]])
        for _ in range(self.max_iter):
            rnd = self.informed_sample(cBest, cMin, xCenter, C)
            newNode = self._steer(rnd)
            if newNode is not None:
                nearInds = self.find_near_nodes(newNode)
                newNode = self.choose_parent(newNode, nearInds)
                self.node_list.append(newNode)
                self.rewire(newNode, nearInds)
                if self._connects_to_goal(newNode):
                    tempPath = self.get_final_course(len(self.node_list) - 1)
                    tempPathLen = self.get_path_len(tempPath)
                    if tempPathLen < cBest:
                        path = tempPath
                        cBest = tempPathLen
                        print(f"当前路径长度: {tempPathLen}, 耗时: {time.time() - start_time} s")
            if animation:
                self.draw_graph_informed_RRTStar(xCenter=xCenter, cBest=cBest, cMin=cMin,
                                                 e_theta=e_theta, rnd=rnd, path=path)
        return path

    def _steer(self, rnd):
        """Step ``expand_dis`` from the nearest tree node towards ``rnd``.

        Returns the new (not yet inserted) node, or ``None`` if the step collides.
        """
        n_ind = self.get_nearest_list_index(self.node_list, rnd)
        nearestNode = self.node_list[n_ind]
        theta = math.atan2(rnd[1] - nearestNode.y, rnd[0] - nearestNode.x)
        newNode = self.get_new_node(theta, n_ind, nearestNode)
        if not self.check_segment_collision(newNode.x, newNode.y, nearestNode.x, nearestNode.y):
            return None
        return newNode

    def _connects_to_goal(self, node):
        """True when ``node`` is within one step of the goal with a free segment to it."""
        return (self.is_near_goal(node)
                and self.check_segment_collision(node.x, node.y, self.goal.x, self.goal.y))

    # ------------------------------------------------------------------
    # Sampling and tree operations
    # ------------------------------------------------------------------
    def sample(self):
        """Uniform sample in the search square, or the goal itself (goal bias)."""
        if random.randint(0, 100) > self.goal_sample_rate:
            return [random.uniform(self.min_rand, self.max_rand),
                    random.uniform(self.min_rand, self.max_rand)]
        return [self.goal.x, self.goal.y]

    def choose_parent(self, newNode, nearInds):
        """Re-parent ``newNode`` to the collision-free near node giving the lowest cost.

        Leaves ``newNode`` unchanged when ``nearInds`` is empty or every
        candidate connection collides.
        """
        if not nearInds:
            return newNode
        costs = []
        for i in nearInds:
            node = self.node_list[i]
            d = math.hypot(newNode.x - node.x, newNode.y - node.y)
            if self.check_segment_collision(node.x, node.y, newNode.x, newNode.y):
                costs.append(node.cost + d)
            else:
                costs.append(float('inf'))
        minCost = min(costs)
        if minCost == float('inf'):
            print("最小代价为无穷大")
            return newNode
        newNode.cost = minCost
        newNode.parent = nearInds[costs.index(minCost)]
        return newNode

    def find_near_nodes(self, newNode):
        """Indices of tree nodes within the shrinking RRT* radius of ``newNode``."""
        n_node = len(self.node_list)
        r = 50.0 * math.sqrt(math.log(n_node) / n_node)
        r2 = r ** 2
        # enumerate gives each node its own index; looking indices up by
        # distance value returned the first match for equidistant nodes.
        return [i for i, node in enumerate(self.node_list)
                if (node.x - newNode.x) ** 2 + (node.y - newNode.y) ** 2 <= r2]

    def informed_sample(self, cMax, cMin, xCenter, C):
        """Sample inside the ellipse of paths shorter than ``cMax`` (uniform if none yet)."""
        if cMax < float('inf'):
            # cMax may dip a hair below cMin through float rounding; clamp so
            # math.sqrt never receives a negative argument.
            minor = math.sqrt(max(cMax ** 2 - cMin ** 2, 0.0)) / 2.0
            L = np.diag([cMax / 2.0, minor, minor])
            xBall = self.sample_unit_ball()
            rnd = np.dot(np.dot(C, L), xBall) + xCenter
            return [rnd[0, 0], rnd[1, 0]]
        return self.sample()

    @staticmethod
    def sample_unit_ball():
        """Uniform sample from the unit disc, as a 3x1 column vector (z = 0)."""
        a = random.random()
        b = random.random()
        if b < a:
            a, b = b, a
        sample = (b * math.cos(2 * math.pi * a / b), b * math.sin(2 * math.pi * a / b))
        return np.array([[sample[0]], [sample[1]], [0]])

    @staticmethod
    def get_path_len(path):
        """Total Euclidean length of a ``[[x, y], ...]`` polyline."""
        return sum(math.hypot(x1 - x0, y1 - y0)
                   for (x0, y0), (x1, y1) in zip(path, path[1:]))

    @staticmethod
    def line_cost(node1, node2):
        """Euclidean distance between two nodes."""
        return math.hypot(node1.x - node2.x, node1.y - node2.y)

    @staticmethod
    def get_nearest_list_index(nodes, rnd):
        """Index of the node closest to point ``rnd`` (first one on ties)."""
        dList = [(node.x - rnd[0]) ** 2 + (node.y - rnd[1]) ** 2 for node in nodes]
        return dList.index(min(dList))

    def get_new_node(self, theta, n_ind, nearestNode):
        """New node one ``expand_dis`` step from ``nearestNode`` along heading ``theta``."""
        newNode = Node(nearestNode.x + self.expand_dis * math.cos(theta),
                       nearestNode.y + self.expand_dis * math.sin(theta))
        newNode.cost = nearestNode.cost + self.expand_dis
        newNode.parent = n_ind
        return newNode

    def is_near_goal(self, node):
        """True if ``node`` lies strictly closer than one step to the goal."""
        return self.line_cost(node, self.goal) < self.expand_dis

    def rewire(self, newNode, nearInds):
        """Re-parent near nodes through ``newNode`` when that shortens their cost.

        Assumes ``newNode`` was just appended to ``node_list``.
        NOTE(review): costs of the rewired nodes' descendants are not updated.
        """
        new_index = len(self.node_list) - 1
        for i in nearInds:
            nearNode = self.node_list[i]
            s_cost = newNode.cost + math.hypot(nearNode.x - newNode.x, nearNode.y - newNode.y)
            if nearNode.cost > s_cost and self.check_segment_collision(
                    nearNode.x, nearNode.y, newNode.x, newNode.y):
                nearNode.parent = new_index
                nearNode.cost = s_cost

    # ------------------------------------------------------------------
    # Geometry / collision checking
    # ------------------------------------------------------------------
    @staticmethod
    def distance_squared_point_to_segment(v, w, p):
        """Squared distance from point ``p`` to segment ``v``-``w`` (numpy 2-vectors)."""
        if np.array_equal(v, w):
            return (p - v).dot(p - v)
        l2 = (w - v).dot(w - v)
        t = max(0, min(1, (p - v).dot(w - v) / l2))
        projection = v + t * (w - v)
        return (p - projection).dot(p - projection)

    @staticmethod
    def cross_product(o, a, b):
        """z-component of (a - o) x (b - o); its sign gives the turn direction."""
        return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])

    def is_intersect(self, p1, p2, q1, q2):
        """The four orientation cross products for segments p1-p2 and q1-q2."""
        d1 = self.cross_product(p1, p2, q1)
        d2 = self.cross_product(p1, p2, q2)
        d3 = self.cross_product(q1, q2, p1)
        d4 = self.cross_product(q1, q2, p2)
        return (d1, d2, d3, d4)

    def _segments_intersect(self, p1, p2, q1, q2):
        """True if closed segments p1-p2 and q1-q2 share at least one point."""
        d1, d2, d3, d4 = self.is_intersect(p1, p2, q1, q2)
        if d1 == 0 and d2 == 0 and d3 == 0 and d4 == 0:
            # Collinear: they only touch if their extents overlap on both axes.
            return (min(p1[0], p2[0]) <= max(q1[0], q2[0])
                    and min(q1[0], q2[0]) <= max(p1[0], p2[0])
                    and min(p1[1], p2[1]) <= max(q1[1], q2[1])
                    and min(q1[1], q2[1]) <= max(p1[1], p2[1]))
        return d1 * d2 <= 0 and d3 * d4 <= 0

    def _segment_hits_box(self, p, q, x_min, y_min, x_max, y_max):
        """True if segment p-q touches the closed axis-aligned box."""
        # An endpoint inside the box is a hit even if no edge is crossed.
        for x, y in (p, q):
            if x_min <= x <= x_max and y_min <= y <= y_max:
                return True
        corners = [(x_min, y_min), (x_min, y_max), (x_max, y_max), (x_max, y_min)]
        return any(self._segments_intersect(p, q, corners[k], corners[(k + 1) % 4])
                   for k in range(4))

    def check_segment_collision(self, x1, y1, x2, y2):
        """Return True if segment (x1, y1)-(x2, y2) is collision-free, else False."""
        for a in self.obstacle_list:
            if len(a) > 3:
                (ox, oy, size_x, size_y) = a
                half_w = size_x + self.car_radius
                half_h = size_y + self.car_radius
                if self._segment_hits_box((x1, y1), (x2, y2),
                                          ox - half_w, oy - half_h, ox + half_w, oy + half_h):
                    return False
            else:
                (ox, oy, size) = a
                dd = self.distance_squared_point_to_segment(
                    np.array([x1, y1]), np.array([x2, y2]), np.array([ox, oy]))
                if dd <= (size + self.car_radius) ** 2:
                    return False
        return True

    def check_collision(self, nearNode, theta, d):
        """Collision-free test for the segment of length ``d`` leaving ``nearNode`` at ``theta``."""
        end_x = nearNode.x + math.cos(theta) * d
        end_y = nearNode.y + math.sin(theta) * d
        return self.check_segment_collision(nearNode.x, nearNode.y, end_x, end_y)

    def get_final_course(self, lastIndex):
        """Walk parent links from node ``lastIndex``: ``[goal, ..., start]`` waypoints."""
        path = [[self.goal.x, self.goal.y]]
        while self.node_list[lastIndex].parent is not None:
            node = self.node_list[lastIndex]
            path.append([node.x, node.y])
            lastIndex = node.parent
        path.append([self.start.x, self.start.y])
        return path

    # ------------------------------------------------------------------
    # Visualisation
    # ------------------------------------------------------------------
    @staticmethod
    def _on_key_release(event):
        """Esc terminates the program (equivalent to exit(0))."""
        if event.key == 'escape':
            raise SystemExit(0)

    def _ensure_axes(self):
        """Create the figure/axes on first use and register the Esc handler once."""
        if self.ax is None:
            self.fig, self.ax = plt.subplots()
            self.fig.canvas.mpl_connect('key_release_event', self._on_key_release)
        return self.ax

    def _draw_scene(self, ax, path):
        """Draw tree edges, obstacles, start/goal and ``path``, then refresh."""
        for node in self.node_list:
            if node.parent is not None:
                parent = self.node_list[node.parent]
                ax.plot([node.x, parent.x], [node.y, parent.y], "-g")
        for a in self.obstacle_list:
            if len(a) > 3:
                (ox, oy, size_x, size_y) = a
                ax.add_artist(patches.Rectangle((ox - size_x, oy - size_y), 2 * size_x, 2 * size_y,
                                                edgecolor='black', facecolor='black'))
            else:
                (ox, oy, size) = a
                ax.add_artist(plt.Circle((ox, oy), size, color='black', fill=True))
        ax.plot(self.start.x, self.start.y, "xr")
        ax.plot(self.goal.x, self.goal.y, "xr")
        if path is not None:
            ax.plot([x for (x, _) in path], [y for (_, y) in path], '-r')
        ax.set_aspect('equal')
        ax.axis([self.min_rand, self.max_rand, self.min_rand, self.max_rand])
        ax.grid(True)
        plt.pause(0.01)

    def draw_graph_informed_RRTStar(self, xCenter=None, cBest=None, cMin=None, e_theta=None, rnd=None, path=None):
        """Redraw the Informed RRT* state: last sample, sampling ellipse, tree and path."""
        ax = self._ensure_axes()
        ax.cla()  # full redraw each frame instead of stacking artists
        if rnd is not None:
            ax.plot(rnd[0], rnd[1], "^k")
        if cBest is not None and cBest != float('inf'):
            self.plot_ellipse(xCenter, cBest, cMin, e_theta)
        self._draw_scene(ax, path)

    @staticmethod
    def plot_ellipse(xCenter, cBest, cMin, e_theta):
        """Plot (with pyplot) the ellipse of paths no longer than ``cBest``, centred at ``xCenter``."""
        a = math.sqrt(max(cBest ** 2 - cMin ** 2, 0.0)) / 2.0
        b = cBest / 2.0
        angle = math.pi / 2.0 - e_theta
        cx = xCenter[0]
        cy = xCenter[1]
        t = np.arange(0, 2 * math.pi + 0.1, 0.1)
        x = a * np.cos(t)
        y = b * np.sin(t)
        rot = Rot.from_euler('z', -angle).as_matrix()[0:2, 0:2]
        fx = rot @ np.array([x, y])
        px = np.array(fx[0, :] + cx).flatten()
        py = np.array(fx[1, :] + cy).flatten()
        plt.plot(cx, cy, "xc")
        plt.plot(px, py, "--c")

    def draw_graph(self, rnd=None, path=None):
        """Redraw the RRT/RRT* state; ``rnd`` (a node) is marked with a triangle."""
        ax = self._ensure_axes()
        ax.cla()  # full redraw each frame instead of stacking artists
        if rnd is not None:
            ax.plot(rnd.x, rnd.y, "^k")
        self._draw_scene(ax, path)
# ----------------------------
# Token管理逻辑
# ----------------------------
def login():
    """Sign in with LOGIN_USERNAME/LOGIN_PASSWORD and cache the access token.

    On success the module globals ACCESS_TOKEN and TOKEN_EXPIRE_TIME are
    updated. The expiry is fixed at one hour from now rather than parsed from
    the response. The globals are left untouched on failure.

    Returns:
        True if a non-empty token string was obtained, False otherwise.
    """
    global ACCESS_TOKEN, TOKEN_EXPIRE_TIME
    url = f"{AUTH_BASE_URL}/signin"
    data = {"username": LOGIN_USERNAME, "password": LOGIN_PASSWORD}
    try:
        response = requests.post(url, json=data, timeout=10)
        if response.status_code != 200:
            print(f"❌ 登录失败,状态码: {response.status_code}, 响应: {response.text}")
            return False
        result = response.json()
    except (requests.RequestException, ValueError) as e:
        # Network errors and non-JSON bodies (JSONDecodeError is a ValueError).
        print(f"❌ 登录请求异常: {str(e)}")
        return False

    # Walk result["data"]["accessToken"]["tokenValue"], tolerating missing or
    # null levels instead of raising AttributeError on None.
    token = None
    if isinstance(result, dict):
        data_content = result.get("data")
        if isinstance(data_content, dict):
            access_token_data = data_content.get("accessToken")
            if isinstance(access_token_data, dict):
                token = access_token_data.get("tokenValue")
    if not token or not isinstance(token, str):
        print(f"❌ 登录失败,响应中缺少accessToken: {response.text}")
        return False

    ACCESS_TOKEN = token
    print(f"获取到的ACCESS_TOKEN: {ACCESS_TOKEN[:20]}...")
    # Validity is forced to one hour; expiresAt is deliberately not parsed.
    TOKEN_EXPIRE_TIME = int(time.time()) + 3600
    print("✅ 登录成功,Token有效期强制设为1小时")
    return True
def ensure_valid_token():
    """Make sure a usable access token is cached, logging in again if not.

    The cached token counts as usable only while it is non-empty and more
    than 60 seconds remain before TOKEN_EXPIRE_TIME. Otherwise login() is
    called and its result is returned.
    """
    remaining_time = TOKEN_EXPIRE_TIME - int(time.time())  # seconds left
    token_usable = bool(ACCESS_TOKEN) and remaining_time > 60
    if not token_usable:
        print(f"🔄 Token已过期(剩余{remaining_time}秒)或无效,重新登录...")
        return login()
    print(f"Token剩余有效期: {remaining_time}秒")
    return True
# ----------------------------
# 任务获取与处理逻辑(核心修改)
# ----------------------------
def fetch_tasks(page=0):
    """Fetch one page of tasks from ``{API_BASE_URL}/all``.

    Args:
        page: page index sent as the ``page`` query parameter; PAGE_SIZE
            tasks are requested per page.

    Returns:
        The response body's ``data`` field when the HTTP status is 200 and the
        body's ``code`` is 200 (``[]`` if ``data`` is missing or null), and
        ``[]`` in every other case. A 401 triggers at most one re-login and
        one retry.
    """
    if not ensure_valid_token():
        print("❌ Token无效,无法获取任务列表")
        return []
    url = f"{API_BASE_URL}/all"
    params = {
        "page": page,
        "pageSize": PAGE_SIZE,
    }
    # At most two attempts: the retry only follows a 401 plus a successful
    # re-login, so a server that keeps rejecting tokens cannot loop forever.
    for attempt in range(2):
        headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
        try:
            response = requests.get(url, headers=headers, params=params, timeout=10)
            if response.status_code == 401 and attempt == 0:
                print("🔄 Token过期,尝试重新登录")
                if not login():
                    return []
                continue
            if response.status_code != 200:
                print(f"❌ 获取任务列表失败,状态码: {response.status_code}")
                return []
            result = response.json()
        except (requests.RequestException, ValueError) as e:
            # Network errors and non-JSON bodies (JSONDecodeError is a ValueError).
            print(f"❌ 获取任务列表异常: {str(e)}")
            return []
        if not isinstance(result, dict) or result.get("code") != 200:
            msg = result.get("msg") if isinstance(result, dict) else None
            print(f"❌ 获取任务列表失败: {msg}")
            return []
        data = result.get("data")
        return data if data is not None else []
    return []
def process_task(task):
    """Run path planning for one batch task and report the outcome.

    Lifecycle rules:
      * Only tasks with mode == BATCH_MODE, state == 0 (pending) and both an
        ``id`` and a ``projectId`` are handled. Anything else is skipped.
      * If ``submitAt + timeout`` (milliseconds) is already in the past, the
        task is reported as timed out (state 3) without being computed.
      * Otherwise RRT runs on the task's JSON ``input`` (keys ``start``,
        ``goal``, ``obstacleList``). The task is reported as finished
        (state 1) with the path, or as failed (state 2) when no path exists.
      * Any unexpected exception is reported as failed (state 2).
    """
    # Read before the try so the except handler can always reference it.
    task_id = task.get("id")
    try:
        project_id = task.get("projectId")
        mode = task.get("mode", 0)  # 0 = batch
        current_state = task.get("state", 0)  # 0 = pending
        # "or 0" also covers keys that are present but null in the JSON.
        submit_at = task.get("submitAt") or 0  # milliseconds
        timeout = task.get("timeout") or 0  # milliseconds
        current_time = int(time.time() * 1000)

        if mode != BATCH_MODE:
            print(f"任务 {task_id} 不是batch模式(mode={mode}),跳过")
            return
        if current_state != 0:
            print(f"任务 {task_id} 非pending状态(state={current_state}),跳过")
            return
        if not task_id or not project_id:
            print(f"任务 {task_id} 缺少id或projectId,跳过")
            return

        if timeout > 0 and (submit_at + timeout < current_time):
            print(f"任务 {task_id} 已超时({submit_at + timeout} < {current_time}),标注为超时")
            # Timed-out tasks are reported as state 3 and never computed.
            submit_result(task_data=task, output={"error": "任务超时"}, state=3)
            return

        print(f"任务 {task_id} 未超时,开始计算...")
        raw_input = task.get("input", "")
        # The input normally arrives as a JSON string; accept an already-decoded object too.
        if isinstance(raw_input, dict):
            input_params = raw_input
        else:
            input_params = json.loads(raw_input) if raw_input else {}
        start = input_params.get("start", [0, 0])
        goal = input_params.get("goal", [10, 10])
        obstacle_list = input_params.get("obstacleList", [])

        rrt = RRT(randArea=[-2, 18], obstacleList=obstacle_list, maxIter=100)
        path = rrt.rrt_planning(start=start, goal=goal, animation=show_animation)

        # rrt_planning returns None when no path is found, so check before
        # reversing; slicing None would raise and mask this outcome.
        if not path:
            print(f"任务 {task_id} 计算失败")
            submit_result(task_data=task, output={"error": "未找到有效路径"}, state=2)  # fail
            return

        path = path[::-1]  # planner yields goal->start; report start->goal
        print(f"任务 {task_id} 计算成功")
        # NOTE(review): "length" is the number of waypoints, not the geometric
        # path length; kept unchanged for existing consumers.
        output = {"path": path, "length": len(path)}
        submit_result(task_data=task, output=output, state=1)  # finished
    except Exception as e:
        print(f"任务 {task_id} 处理异常: {str(e)}")
        # Report unexpected failures as state 2 so the task leaves pending.
        submit_result(task_data=task, output={"error": str(e)}, state=2)
def submit_result(task_data, output, state):
    """Report a task's final state and output to ``{API_BASE_URL}/update``.

    The payload mirrors the task's own fields. ``machId`` is always null and
    ``output`` is sent as a JSON string. ``finishAt`` is ``submitAt + timeout``
    for timed-out tasks (state 3 with a positive timeout) and the current time
    otherwise. A 401 response triggers one re-login and one resend.

    Args:
        task_data: the task dict as returned by the task list API.
        output: JSON-serialisable result object.
        state: target state (1 finished, 2 failed, 3 timed out, 4 cancelled).

    Returns:
        True only if the final update request returned HTTP 200.
    """
    if not ensure_valid_token():
        print("❌ Token无效,提交失败")
        return False

    # Only a task that is still pending (state 0) may move to a final state.
    current_state = task_data.get("state", 0)
    if current_state != 0 and state in (1, 2, 3, 4):
        print(f"任务 {task_data.get('id')} 已非pending状态(当前state={current_state}),无法更新为state={state}")
        return False

    url = f"{API_BASE_URL}/update"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {ACCESS_TOKEN}",
    }

    # Serialise the output and round-trip it so the server is sent parseable JSON.
    try:
        output_str = json.dumps(output, ensure_ascii=False)
        json.loads(output_str)
    except Exception as e:
        print(f"❌ output格式错误: {str(e)}")
        output_str = json.dumps({"error": "输出格式错误"})

    now_ms = int(time.time() * 1000)
    timeout = task_data.get("timeout", 0)  # milliseconds
    if state == 3 and timeout > 0:
        finish_at = task_data.get("submitAt", 0) + timeout  # deadline of a timed-out task
    else:
        finish_at = now_ms

    payload = {
        "id": task_data.get("id", ""),
        "submitAt": task_data.get("submitAt", int(time.time() * 1000)),
        "finishAt": finish_at,
        "algName": task_data.get("algName", ALG_NAME),
        "mode": task_data.get("mode", 0),
        "timeout": timeout,
        "projectId": task_data.get("projectId", ""),
        "machId": None,
        "input": task_data.get("input", ""),
        "output": output_str,
        "state": state,
    }

    try:
        print(f"提交参数: {json.dumps(payload, ensure_ascii=False, indent=2)}")
        response = requests.post(url, json=payload, headers=headers, timeout=10)
        status = response.status_code
        print(f"任务 {task_data.get('id')} 提交响应: {status}")
        print(f"响应内容: {response.text}")
        if status == 400:
            print("⚠️ 提交参数错误,检查字段格式")
        elif status == 401:
            print("🔄 Token过期,重新登录后重试")
            if login():
                headers["Authorization"] = f"Bearer {ACCESS_TOKEN}"
                response = requests.post(url, json=payload, headers=headers, timeout=10)
                print(f"重试响应: {response.status_code}")
        return response.status_code == 200
    except Exception as e:
        print(f"提交失败: {str(e)}")
        return False
def main():
    """Log in once, then poll the task API until interrupted with Ctrl+C.

    Each round walks pages 0-9, stopping at the first empty page. It runs
    process_task() on every task whose mode is BATCH_MODE, whose state is 0
    and whose algName equals ALG_NAME. Rounds are TASK_CHECK_INTERVAL
    seconds apart.
    """
    if not login():
        print("❌ 登录失败,无法启动")
        return
    print("🚀 启动batch模式任务处理器...")

    def is_pending_batch_task(task):
        # Pending (state 0) batch-mode tasks addressed to this algorithm.
        return (task.get("mode") == BATCH_MODE
                and task.get("state") == 0
                and task.get("algName") == ALG_NAME)

    try:
        while True:
            for page in range(10):
                print(f"获取第 {page} 页任务...")
                tasks = fetch_tasks(page)
                if not tasks:
                    break
                for task in filter(is_pending_batch_task, tasks):
                    process_task(task)
            print(f"等待 {TASK_CHECK_INTERVAL} 秒后检查任务...")
            time.sleep(TASK_CHECK_INTERVAL)
    except KeyboardInterrupt:
        print("\n🛑 服务停止")
if __name__ == "__main__":
    # Start the polling service only when run as a script, not on import.
    main()