238. Product of Array Except Self

本文介绍了一种解决数组连乘问题的方法,通过两次遍历数组实现:第一次正向遍历记录前缀连乘结果;第二次反向遍历,利用之前的结果计算每个元素排除自身外的连乘值。

自己想的就差一点点,好可惜!!

正着走一遍,每一格里面存的是从第0位到这一位之前(不含这一位)所有数的连乘结果;第0格左边没有数,所以存1。

再反着走一遍,用格子里已有的值(左边所有数的连乘结果),乘以从后往前累计的、它右边所有数的连乘结果,然后把当前这个数乘进累计值,更新连乘的结果。

 

 1     public int[] productExceptSelf(int[] nums) {
 2         int len = nums.length;
 3         int[] res = new int[len];
 4         res[0] = 1;
 5         for(int i = 1; i < len; i++) {
 6             res[i] = nums[i-1] * res[i-1];
 7         }
 8         int right = nums[len-1];
 9         for(int i = len - 2; i >= 0; i--) {
10             res[i] *= right;
11             right *= nums[i];
12         }
13         return res;
14     }

 

转载于:https://www.cnblogs.com/warmland/p/5716929.html

import sys import os import pandas as pd from PyQt5.QtWidgets import * from PyQt5.QtCore import * from PyQt5.QtGui import * from ollama import Client import numpy as np import re import torch import torch.nn as nn # LSTM模型类 class LSTMModel(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size): super(LSTMModel, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) out, _ = self.lstm(x, (h0, c0)) out = self.fc(out[:, -1, :]) return out # 向量相似度计算工具 class VectorSimilarity: @staticmethod def cosine_similarity(vec1, vec2): if len(vec1) == 0 or len(vec2) == 0: return 0 dot_product = np.dot(vec1, vec2) norm_vec1 = np.linalg.norm(vec1) norm_vec2 = np.linalg.norm(vec2) if norm_vec1 == 0 or norm_vec2 == 0: return 0 return dot_product / (norm_vec1 * norm_vec2) @staticmethod def euclidean_distance(vec1, vec2): return np.linalg.norm(vec1 - vec2) @staticmethod def combined_similarity(vec1, vec2): cosine = VectorSimilarity.cosine_similarity(vec1, vec2) euclidean = VectorSimilarity.euclidean_distance(vec1, vec2) if euclidean > 0: euclidean_sim = 1 / (1 + euclidean) else: euclidean_sim = 1 return 0.7 * cosine + 0.3 * euclidean_sim # 知识库处理类 class KnowledgeBase: def __init__(self): self.texts = {} self.vectors = {} self.metadata = {} self.similarity_threshold = 0.3 self.file_paths = {} def load_from_excel(self, file_path, kb_type): try: df = pd.read_excel(file_path) vector_cols = [col for col in df.columns if col.startswith('向量化特征_')] if not vector_cols: raise ValueError("未找到向量化特征列") self.texts[kb_type] = [] self.vectors[kb_type] = [] self.metadata[kb_type] = [] for _, row in df.iterrows(): text = str(row.get('文本块', '')) if not text: continue 
vector = row[vector_cols].values.astype(np.float32) metadata = { '来源': row.get('来源', ''), '类别': row.get('类别', ''), } self.texts[kb_type].append(text) self.vectors[kb_type].append(vector) self.metadata[kb_type].append(metadata) self._calculate_similarity_threshold(kb_type) self.file_paths[kb_type] = file_path return len(self.texts[kb_type]) except Exception as e: print(f"加载知识库失败: {e}") return 0 def _calculate_similarity_threshold(self, kb_type): if not self.vectors[kb_type] or len(self.vectors[kb_type]) < 2: return similarities = [] for i in range(min(100, len(self.vectors[kb_type]))): for j in range(i + 1, min(100, len(self.vectors[kb_type]))): sim = VectorSimilarity.combined_similarity(self.vectors[kb_type][i], self.vectors[kb_type][j]) similarities.append(sim) if similarities: self.similarity_threshold = max(0.2, np.mean(similarities) * 0.7) print(f"相似度阈值: {self.similarity_threshold:.4f}") def search(self, query_vector, kb_type, top_k=10): if kb_type not in self.vectors or not self.vectors[kb_type]: return [] results = [] for i, vector in enumerate(self.vectors[kb_type]): sim = VectorSimilarity.combined_similarity(query_vector, vector) if sim >= self.similarity_threshold: results.append((i, sim)) results.sort(key=lambda x: x[1], reverse=True) return results[:top_k] def add_entry(self, kb_type, text, source, category, vector): if kb_type not in self.texts: self.texts[kb_type] = [] self.vectors[kb_type] = [] self.metadata[kb_type] = [] self.texts[kb_type].append(text) self.vectors[kb_type].append(vector) self.metadata[kb_type].append({ '来源': source, '类别': category }) self.save_to_excel(kb_type) def save_to_excel(self, kb_type): if kb_type not in self.file_paths: return file_path = self.file_paths[kb_type] df = pd.DataFrame({ '文本块': self.texts[kb_type], **{f'向量化特征_{i}': [vec[i] for vec in self.vectors[kb_type]] for i in range(len(self.vectors[kb_type][0]))}, '来源': [meta['来源'] for meta in self.metadata[kb_type]], '类别': [meta['类别'] for meta in self.metadata[kb_type]] 
}) df.to_excel(file_path, index=False) def edit_entry(self, kb_type, index, text, source, category, vector): if kb_type in self.texts and 0 <= index < len(self.texts[kb_type]): self.texts[kb_type][index] = text self.vectors[kb_type][index] = vector self.metadata[kb_type][index] = { '来源': source, '类别': category } self.save_to_excel(kb_type) # AI处理线程 class AIWorker(QThread): thinking_signal = pyqtSignal(str) # 发送思考过程 answer_signal = pyqtSignal(str) # 发送最终回答 finish_signal = pyqtSignal() def __init__(self, client, query, knowledge_base, kb_type): super().__init__() self.client = client self.query = query self.knowledge_base = knowledge_base self.kb_type = kb_type def run(self): try: # 发送思考过程开始的信号 self.thinking_signal.emit("<b>知识库检索中...</b>") # 模拟生成查询向量 query_vector = np.random.rand(len(self.knowledge_base.vectors[self.kb_type][0]) if self.knowledge_base.vectors[self.kb_type] else 50) # 执行搜索 search_results = self.knowledge_base.search(query_vector, self.kb_type, top_k=5) if not search_results: self.thinking_signal.emit("<p style=\"color:#888888;\">未找到匹配的知识库内容</p>") thinking_process = "无" else: # 构建HTML格式的思考过程 thinking_html = "<p><b>匹配到的知识库内容:</b></p><ul>" for i, (idx, sim) in enumerate(search_results): text = self.knowledge_base.texts[self.kb_type][idx] metadata = self.knowledge_base.metadata[self.kb_type][idx] thinking_html += f"<li><b>匹配项 {i + 1} (相似度: {sim:.4f})</b><br>" thinking_html += f"<span style=\"color:#6c757d;\">来源: {metadata['来源']} | 类别: {metadata['类别']}</span><br>" # 提取前300个字符作为摘要,并保留段落结构 summary = text[:300] + ('...' 
if len(text) > 300 else '') # 将换行符转换为<br>标签 summary = summary.replace('\n', '<br>') thinking_html += f"{summary}</li><br>" thinking_html += "</ul>" self.thinking_signal.emit(thinking_html) # 构建LLM提示词中的思考过程 thinking_process = "以下是知识库中的相关参考资料:\n" for i, (idx, sim) in enumerate(search_results): text = self.knowledge_base.texts[self.kb_type][idx] metadata = self.knowledge_base.metadata[self.kb_type][idx] thinking_process += f"\n[{i + 1}] 相似度: {sim:.4f}\n" thinking_process += f"来源: {metadata['来源']} | 类别: {metadata['类别']}\n" # 提取前200个字符作为摘要 summary = text[:200] + ('...' if len(text) > 200 else '') thinking_process += f"{summary}\n\n" # 构建提示词 prompt = f""" 你是一位专业的涂装工程师。用户问题:{self.query}。 {thinking_process} 请根据上述资料提供准确、专业的回答。 如果资料不足,请补充涂装领域的通用知识和最佳实践。 请确保回答条理清晰,适当分段分点。 """ # 发送开始生成回答的信号 self.answer_signal.emit("<b>正在生成回答...</b>") # 调用模型生成回答 answer_html = "" stream = self.client.chat( model='deepseek-r1:1.5b', messages=[{"role": "user", "content": prompt}], stream=True ) for chunk in stream: content = chunk['message']['content'] answer_html += content # 处理回答内容,添加分段分点格式 formatted_answer = self._format_answer(answer_html) self.answer_signal.emit(formatted_answer) except Exception as e: self.answer_signal.emit(f"<p style=\"color:#FF0000;\">[Error] {str(e)}</p>") finally: self.finish_signal.emit() def _format_answer(self, answer): # 简单的文本格式化处理 # 将段落分隔(空行)转换为<p>标签 paragraphs = answer.split('\n\n') formatted = "" for para in paragraphs: para = para.strip() if not para: continue # 处理列表项(如果以数字+点或短横线开头) if re.match(r'^\d+\.', para): # 有序列表 if not formatted.endswith('</ol>'): formatted += '<ol>' else: formatted = formatted[:-5] # 移除最后的</ol>标签以便继续添加 # 提取序号和内容 match = re.match(r'^(\d+\.)\s*(.*)', para) if match: formatted += f'<li><b>{match.group(1)}</b> {match.group(2)}</li>' else: formatted += f'<li>{para}</li>' formatted += '</ol>' elif re.match(r'^[-*•]', para): # 无序列表 if not formatted.endswith('</ul>'): formatted += '<ul>' else: formatted = formatted[:-5] # 移除最后的</ul>标签 # 提取标记和内容 
match = re.match(r'^[-*•]\s*(.*)', para) if match: formatted += f'<li>{match.group(1)}</li>' else: formatted += f'<li>{para}</li>' formatted += '</ul>' else: # 普通段落 formatted += f'<p>{para}</p>' return formatted # 主窗口 class PaintChatWindow(QWidget): def __init__(self): super().__init__() self.knowledge_base = KnowledgeBase() self.client = Client(host="http://localhost:11435") self.current_kb_type = None self.init_ui() def init_ui(self): self.setWindowTitle("涂装知识助手") self.setGeometry(100, 100, 1200, 800) # 主布局 main_layout = QVBoxLayout() # 顶部状态栏 status_bar = QWidget() status_layout = QHBoxLayout(status_bar) self.status_label = QLabel("知识库未加载") self.status_label.setStyleSheet("color: #888888; font-size: 12px; padding: 5px;") status_layout.addWidget(self.status_label) main_layout.addWidget(status_bar) # 知识库选择区域 kb_selection_layout = QHBoxLayout() kb_types = ["涂料信息库", "涂装设备库", "涂装工艺库", "涂装环境库", "行业标准与法规库"] for kb_type in kb_types: btn = QPushButton(kb_type) btn.setStyleSheet(""" QPushButton { background-color: #6c757d; color: white; font-size: 12px; padding: 5px; border-radius: 3px; } QPushButton:hover { background-color: #5a6268; } """) btn.clicked.connect(lambda _, t=kb_type: self.select_knowledge_base(t)) kb_selection_layout.addWidget(btn) main_layout.addLayout(kb_selection_layout) # 中间内容区域 - 分为思考过程和最终回答两栏 content_splitter = QSplitter(Qt.Horizontal) # 思考过程区域 self.thinking_area = QTextEdit() self.thinking_area.setReadOnly(True) self.thinking_area.setStyleSheet(""" QTextEdit { background-color: #f8f9fa; font-family: SimHei, sans-serif; font-size: 14px; padding: 15px; border: 1px solid #e9ecef; border-radius: 4px; } """) self.thinking_area.setHtml("<b>思考过程将显示在这里...</b>") content_splitter.addWidget(self.thinking_area) # 最终回答区域 self.answer_area = QTextEdit() self.answer_area.setReadOnly(True) self.answer_area.setStyleSheet(""" QTextEdit { background-color: #ffffff; font-family: SimHei, sans-serif; font-size: 14px; padding: 15px; border: 1px solid #e9ecef; border-radius: 
4px; } """) self.answer_area.setHtml("<b>回答将显示在这里...</b>") content_splitter.addWidget(self.answer_area) # 设置两栏的初始大小比例 content_splitter.setSizes([400, 800]) main_layout.addWidget(content_splitter) # 底部控制区域 bottom_layout = QHBoxLayout() # 文件加载区域 file_layout = QVBoxLayout() self.file_label = QLabel("未选择知识库文件") self.file_label.setStyleSheet("color: #6c757d; font-size: 12px;") file_layout.addWidget(self.file_label) self.load_btn = QPushButton("加载知识库") self.load_btn.setStyleSheet(""" QPushButton { background-color: #6c757d; color: white; font-size: 12px; padding: 5px; border-radius: 3px; } QPushButton:hover { background-color: #5a6268; } """) self.load_btn.clicked.connect(self.load_knowledge_base) file_layout.addWidget(self.load_btn) bottom_layout.addLayout(file_layout, 1) # 输入区域 input_layout = QVBoxLayout() self.input_box = QTextEdit() self.input_box.setMaximumHeight(60) self.input_box.setPlaceholderText("输入您的问题...") self.input_box.setStyleSheet(""" QTextEdit { border: 1px solid #ced4da; border-radius: 4px; padding: 8px; font-family: SimHei, sans-serif; font-size: 14px; } """) input_layout.addWidget(self.input_box) self.send_btn = QPushButton("提问") self.send_btn.setStyleSheet(""" QPushButton { background-color: #007bff; color: white; font-size: 14px; padding: 8px; border-radius: 4px; } QPushButton:hover { background-color: #0069d9; } """) self.send_btn.clicked.connect(self.send_message) input_layout.addWidget(self.send_btn) bottom_layout.addLayout(input_layout, 3) # 添加、查看和编辑按钮 action_layout = QHBoxLayout() self.add_btn = QPushButton("添加知识库内容") self.add_btn.setStyleSheet(""" QPushButton { background-color: #28a745; color: white; font-size: 12px; padding: 5px; border-radius: 3px; } QPushButton:hover { background-color: #218838; } """) self.add_btn.clicked.connect(self.add_knowledge_entry) action_layout.addWidget(self.add_btn) self.view_btn = QPushButton("查看知识库内容") self.view_btn.setStyleSheet(""" QPushButton { background-color: #17a2b8; color: white; font-size: 12px; 
padding: 5px; border-radius: 3px; } QPushButton:hover { background-color: #138496; } """) self.view_btn.clicked.connect(self.view_knowledge_entries) action_layout.addWidget(self.view_btn) self.edit_btn = QPushButton("编辑知识库内容") self.edit_btn.setStyleSheet(""" QPushButton { background-color: #ffc107; color: white; font-size: 12px; padding: 5px; border-radius: 3px; } QPushButton:hover { background-color: #e0a800; } """) self.edit_btn.clicked.connect(self.edit_knowledge_entry) action_layout.addWidget(self.edit_btn) bottom_layout.addLayout(action_layout, 2) main_layout.addLayout(bottom_layout) self.setLayout(main_layout) def select_knowledge_base(self, kb_type): self.current_kb_type = kb_type if kb_type in self.knowledge_base.texts: self.status_label.setText(f"当前知识库: {kb_type} ({len(self.knowledge_base.texts[kb_type])} 条记录)") self.status_label.setStyleSheet("color: #28a745; font-size: 12px; padding: 5px;") else: self.status_label.setText(f"请加载 {kb_type} 知识库") self.status_label.setStyleSheet("color: #dc3545; font-size: 12px; padding: 5px;") def load_knowledge_base(self): if not self.current_kb_type: QMessageBox.warning(self, "警告", "请先选择知识库类型") return file_path, _ = QFileDialog.getOpenFileName( self, f"选择 {self.current_kb_type} 文件", "", "Excel Files (*.xlsx *.xls)" ) if file_path: self.file_label.setText(f"已加载: {os.path.basename(file_path)}") self.status_label.setText("正在加载知识库...") self.status_label.setStyleSheet("color: #007bff; font-size: 12px; padding: 5px;") # 在单独线程中加载知识库 QTimer.singleShot(0, lambda: self._load_knowledge_base_thread(file_path, self.current_kb_type)) def _load_knowledge_base_thread(self, file_path, kb_type): try: count = self.knowledge_base.load_from_excel(file_path, kb_type) if count > 0: self.status_label.setText(f"{kb_type} 已加载 ({count} 条记录)") self.status_label.setStyleSheet("color: #28a745; font-size: 12px; padding: 5px;") self.thinking_area.setHtml(f"<b>{kb_type} 加载成功</b>: 共{count}条知识条目") self.answer_area.setHtml("<b>回答将显示在这里...</b>") else: 
self.status_label.setText(f"{kb_type} 加载失败") self.status_label.setStyleSheet("color: #dc3545; font-size: 12px; padding: 5px;") except Exception as e: self.status_label.setText(f"{kb_type} 加载错误: {str(e)}") self.status_label.setStyleSheet("color: #dc3545; font-size: 12px; padding: 5px;") def send_message(self): user_input = self.input_box.toPlainText().strip() if not user_input: return if not self.current_kb_type: QMessageBox.warning(self, "警告", "请先选择知识库类型") return if self.current_kb_type not in self.knowledge_base.texts or not self.knowledge_base.texts[self.current_kb_type]: QMessageBox.warning(self, "警告", f"请先加载 {self.current_kb_type} 知识库") return # 清空之前的回答和思考过程 self.thinking_area.setHtml(f"<b>用户问题:</b> {user_input}<br><br><b>思考过程:</b>") self.answer_area.setHtml("<b>正在生成回答...</b>") self.input_box.clear() # 创建AI处理线程 self.ai_worker = AIWorker(self.client, user_input, self.knowledge_base, self.current_kb_type) self.ai_worker.thinking_signal.connect(self.update_thinking) self.ai_worker.answer_signal.connect(self.update_answer) self.ai_worker.finish_signal.connect(self.on_ai_finished) self.ai_worker.start() def update_thinking(self, message): # 将思考过程追加到思考区域 cursor = self.thinking_area.textCursor() cursor.movePosition(QTextCursor.End) cursor.insertHtml(f"<br><br>{message}") self.thinking_area.setTextCursor(cursor) self.thinking_area.ensureCursorVisible() def update_answer(self, message): # 更新回答区域 self.answer_area.setHtml(message) def on_ai_finished(self): pass def add_knowledge_entry(self): if not self.current_kb_type: QMessageBox.warning(self, "警告", "请先选择知识库类型") return dialog = QDialog(self) dialog.setWindowTitle(f"添加 {self.current_kb_type} 条目") layout = QVBoxLayout() text_label = QLabel("文本块:") text_input = QTextEdit() layout.addWidget(text_label) layout.addWidget(text_input) source_label = QLabel("来源:") source_input = QLineEdit() layout.addWidget(source_label) layout.addWidget(source_input) category_label = QLabel("类别:") category_input = QLineEdit() 
layout.addWidget(category_label) layout.addWidget(category_input) vector_label = QLabel("向量化特征 (以逗号分隔):") vector_input = QLineEdit() layout.addWidget(vector_label) layout.addWidget(vector_input) button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) button_box.accepted.connect(dialog.accept) button_box.rejected.connect(dialog.reject) layout.addWidget(button_box) dialog.setLayout(layout) if dialog.exec_() == QDialog.Accepted: text = text_input.toPlainText() source = source_input.text() category = category_input.text() vector_str = vector_input.text() try: vector = np.array([float(x) for x in vector_str.split(',')], dtype=np.float32) self.knowledge_base.add_entry(self.current_kb_type, text, source, category, vector) QMessageBox.information(self, "成功", "条目已添加到知识库") except ValueError: QMessageBox.warning(self, "错误", "向量化特征输入无效,请输入有效的浮点数,以逗号分隔") def view_knowledge_entries(self): if not self.current_kb_type: QMessageBox.warning(self, "警告", "请先选择知识库类型") return if self.current_kb_type not in self.knowledge_base.texts or not self.knowledge_base.texts[self.current_kb_type]: QMessageBox.warning(self, "警告", f"请先加载 {self.current_kb_type} 知识库") return dialog = QDialog(self) dialog.setWindowTitle(f"{self.current_kb_type} 内容") layout = QVBoxLayout() list_widget = QListWidget() for i, text in enumerate(self.knowledge_base.texts[self.current_kb_type]): metadata = self.knowledge_base.metadata[self.current_kb_type][i] item_text = f"[{i + 1}] 来源: {metadata['来源']} | 类别: {metadata['类别']}\n{text[:200]}" list_widget.addItem(item_text) layout.addWidget(list_widget) dialog.setLayout(layout) dialog.exec_() def edit_knowledge_entry(self): if not self.current_kb_type: QMessageBox.warning(self, "警告", "请先选择知识库类型") return if self.current_kb_type not in self.knowledge_base.texts or not self.knowledge_base.texts[self.current_kb_type]: QMessageBox.warning(self, "警告", f"请先加载 {self.current_kb_type} 知识库") return dialog = QDialog(self) dialog.setWindowTitle(f"编辑 
{self.current_kb_type} 条目") layout = QVBoxLayout() index_label = QLabel("请输入要编辑的条目编号 (从1开始):") index_input = QLineEdit() layout.addWidget(index_label) layout.addWidget(index_input) button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) button_box.accepted.connect(dialog.accept) button_box.rejected.connect(dialog.reject) layout.addWidget(button_box) dialog.setLayout(layout) if dialog.exec_() == QDialog.Accepted: try: index = int(index_input.text()) - 1 if 0 <= index < len(self.knowledge_base.texts[self.current_kb_type]): text = self.knowledge_base.texts[self.current_kb_type][index] metadata = self.knowledge_base.metadata[self.current_kb_type][index] vector = self.knowledge_base.vectors[self.current_kb_type][index] edit_dialog = QDialog(self) edit_dialog.setWindowTitle(f"编辑 {self.current_kb_type} 条目 {index + 1}") edit_layout = QVBoxLayout() text_label = QLabel("文本块:") text_input = QTextEdit() text_input.setPlainText(text) edit_layout.addWidget(text_label) edit_layout.addWidget(text_input) source_label = QLabel("来源:") source_input = QLineEdit() source_input.setText(metadata['来源']) edit_layout.addWidget(source_label) edit_layout.addWidget(source_input) category_label = QLabel("类别:") category_input = QLineEdit() category_input.setText(metadata['类别']) edit_layout.addWidget(category_label) edit_layout.addWidget(category_input) vector_label = QLabel("向量化特征 (以逗号分隔):") vector_input = QLineEdit() vector_str = ','.join([str(x) for x in vector]) vector_input.setText(vector_str) edit_layout.addWidget(vector_label) edit_layout.addWidget(vector_input) edit_button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) edit_button_box.accepted.connect(edit_dialog.accept) edit_button_box.rejected.connect(edit_dialog.reject) edit_layout.addWidget(edit_button_box) edit_dialog.setLayout(edit_layout) if edit_dialog.exec_() == QDialog.Accepted: new_text = text_input.toPlainText() new_source = source_input.text() new_category = category_input.text() 
new_vector_str = vector_input.text() try: new_vector = np.array([float(x) for x in new_vector_str.split(',')], dtype=np.float32) self.knowledge_base.edit_entry(self.current_kb_type, index, new_text, new_source, new_category, new_vector) QMessageBox.information(self, "成功", "条目已更新到知识库") except ValueError: QMessageBox.warning(self, "错误", "向量化特征输入无效,请输入有效的浮点数,以逗号分隔") else: QMessageBox.warning(self, "错误", "输入的条目编号无效") except ValueError: QMessageBox.warning(self, "错误", "请输入有效的整数作为条目编号") if __name__ == "__main__": app = QApplication(sys.argv) app.setFont(QFont("SimHei")) window = PaintChatWindow() window.show() sys.exit(app.exec_())把这段代码里的“相似度”全部修改为“匹配度”,ollama调用窗口11435修改为11434,其余内容不变,完整代码给我
06-22
lass RRTStar3D: def __init__(self, start, goal, builds, bounds, max_iter=RRT_MAX_ITER, step_size=RRT_STEP, neighbor_radius=RRT_NEIGHBOR_RADIUS): self.start = np.array(start) self.goal = np.array(goal) # Pre-calculate builds with safety height buffer self.builds_with_safety = builds.copy() self.builds_with_safety[:, 4] -= SAFE_HEIGHT # Decrease zmin self.builds_with_safety[:, 5] += SAFE_HEIGHT # Increase zmax # Ensure zmin is not negative if SAFE_HEIGHT is large self.builds_with_safety[:, 4] = np.maximum(0, self.builds_with_safety[:, 4]) self.bounds = np.array(bounds) # Ensure bounds is numpy array self.max_iter = max_iter self.step_size = step_size self.neighbor_radius = neighbor_radius self.nodes = [self.start] self.parent = {tuple(self.start): None} self.cost = {tuple(self.start): 0.0} # Initialize KDTree with the start node self.kdtree = cKDTree(np.array([self.start])) # Ensure it's a 2D array def sample(self): # Biased sampling towards goal occasionally if np.random.rand() < 0.1: # 10% chance to sample goal return self.goal # Sample within bounds return np.random.uniform(self.bounds[:, 0], self.bounds[:, 1]) def nearest(self, q): _, idx = self.kdtree.query(q) # Handle case where KDTree might have only one node initially if isinstance(idx, (int, np.integer)): return self.nodes[idx] else: # Should not happen if tree has >= 1 node, but safety check return self.nodes[0] def steer(self, q_near, q_rand): delta = q_rand - q_near dist = np.linalg.norm(delta) if dist == 0: # Avoid division by zero return q_near ratio = self.step_size / dist if ratio >= 1.0: return q_rand return q_near + delta * ratio def near_neighbors(self, q_new): # Ensure nodes list is not empty before querying if not self.nodes: return [] # Ensure kdtree has points before querying if self.kdtree.n == 0: return [] # Use query_ball_point which is efficient for radius searches indices = self.kdtree.query_ball_point(q_new, self.neighbor_radius) # Filter out the index of q_new itself if it's already in 
nodes (might happen during rewiring) q_new_tuple = tuple(q_new) neighbors = [] for i in indices: # Check bounds and ensure it's not the node itself if already added # This check might be redundant if q_new isn't added before calling this if i < len(self.nodes): # Ensure index is valid node = self.nodes[i] if tuple(node) != q_new_tuple: neighbors.append(node) return neighbors def plan(self): for i in range(self.max_iter): q_rand = self.sample() # Check if nodes list is empty (shouldn't happen after init) if not self.nodes: print("Warning: Node list empty during planning.") continue # Or handle appropriately q_near = self.nearest(q_rand) q_new = self.steer(q_near, q_rand) # Check collision for the new segment using the pre-calculated safe builds if check_segment_collision(q_near, q_new, self.builds_with_safety) > 0: continue # If collision-free, add the node and update KD-Tree periodically q_new_tuple = tuple(q_new) q_near_tuple = tuple(q_near) # Choose parent with minimum cost among neighbors min_cost = self.cost[q_near_tuple] + np.linalg.norm(q_new - q_near) best_parent_node = q_near neighbors = self.near_neighbors(q_new) # Find neighbors first for q_neighbor in neighbors: q_neighbor_tuple = tuple(q_neighbor) # Check connectivity collision if check_segment_collision(q_neighbor, q_new, self.builds_with_safety) == 0: new_cost = self.cost[q_neighbor_tuple] + np.linalg.norm(q_new - q_neighbor) if new_cost < min_cost: min_cost = new_cost best_parent_node = q_neighbor # Add the new node with the best parent found self.nodes.append(q_new) q_best_parent_tuple = tuple(best_parent_node) self.parent[q_new_tuple] = q_best_parent_tuple self.cost[q_new_tuple] = min_cost # Rebuild KDTree periodically if len(self.nodes) % KD_REBUILD_EVERY == 0 or i == self.max_iter - 1: # Important: Ensure nodes is a list of arrays before creating KDTree if self.nodes: # Check if nodes is not empty self.kdtree = cKDTree(np.array(self.nodes)) # Rewire neighbors to go through q_new if it provides a 
shorter path for q_neighbor in neighbors: q_neighbor_tuple = tuple(q_neighbor) # Check if rewiring through q_new is shorter and collision-free cost_via_new = min_cost + np.linalg.norm(q_neighbor - q_new) if cost_via_new < self.cost[q_neighbor_tuple]: if check_segment_collision(q_new, q_neighbor, self.builds_with_safety) == 0: self.parent[q_neighbor_tuple] = q_new_tuple self.cost[q_neighbor_tuple] = cost_via_new # Check if goal is reached if np.linalg.norm(q_new - self.goal) < self.step_size: # Check final segment collision if check_segment_collision(q_new, self.goal, self.builds_with_safety) == 0: goal_tuple = tuple(self.goal) self.nodes.append(self.goal) # Add goal node self.parent[goal_tuple] = q_new_tuple self.cost[goal_tuple] = min_cost + np.linalg.norm(self.goal - q_new) print(f"RRT*: Goal reached at iteration {i+1}") # Rebuild KDTree one last time if goal is reached self.kdtree = cKDTree(np.array(self.nodes)) break # Exit planning loop else: # Loop finished without reaching goal condition print(f"RRT*: Max iterations ({self.max_iter}) reached. Connecting nearest node to goal.") # Find node closest to goal among existing nodes if not self.nodes: print("Error: No nodes generated by RRT*.") return None # Or raise error nodes_arr = np.array(self.nodes) distances_to_goal = np.linalg.norm(nodes_arr - self.goal, axis=1) nearest_node_idx = np.argmin(distances_to_goal) q_final_near = self.nodes[nearest_node_idx] q_final_near_tuple = tuple(q_final_near) goal_tuple = tuple(self.goal) # Try connecting nearest found node to goal if check_segment_collision(q_final_near, self.goal, self.builds_with_safety) == 0: self.nodes.append(self.goal) self.parent[goal_tuple] = q_final_near_tuple self.cost[goal_tuple] = self.cost[q_final_near_tuple] + np.linalg.norm(self.goal - q_final_near) print("RRT*: Connected nearest node to goal.") else: print("RRT*: Could not connect nearest node to goal collision-free. 
Returning path to nearest node.") # Path will be constructed to q_final_near instead of goal goal_tuple = q_final_near_tuple # Target for path reconstruction # Backtrack path from goal (or nearest reachable node) path = [] # Start backtracking from the actual last node added (goal or nearest) curr_tuple = goal_tuple if curr_tuple not in self.parent and curr_tuple != tuple(self.start): print(f"Warning: Target node {curr_tuple} not found in parent dict. Path reconstruction might fail.") # Fallback to the last added node if goal wasn't reachable/added correctly if self.nodes: curr_tuple = tuple(self.nodes[-1]) else: return None # No path possible while curr_tuple is not None: # Ensure the node corresponding to the tuple exists # This requires searching self.nodes, which is inefficient. # A better approach is to store nodes in the dict or use indices. # For now, let's assume tuple keys match numpy arrays. path.append(np.array(curr_tuple)) curr_tuple = self.parent.get(curr_tuple, None) if not path: print("Error: Path reconstruction failed.") return None if tuple(path[-1]) != tuple(self.start): print("Warning: Path does not end at start node.") return np.array(path[::-1]) # Reverse to get start -> goal order # --------------------- Path Cost Function (Use safe builds) --------------------- def path_cost(path_pts, builds_with_safety, drone_speed=DRONE_SPEED, penalty_k=PENALTY_K): total_time = 0.0 total_penalty = 0.0 num_segments = len(path_pts) - 1 if num_segments < 1: return 0.0 # No cost for a single point path # Vectorized calculations where possible p = path_pts[:-1] # Start points of segments q = path_pts[1:] # End points of segments segments = q - p distances = np.linalg.norm(segments, axis=1) # Avoid division by zero for zero-length segments valid_segments = distances > 1e-6 if not np.any(valid_segments): return 0.0 # Path has no length p = p[valid_segments] q = q[valid_segments] segments = segments[valid_segments] distances = distances[valid_segments] dir_unit = 
segments / distances[:, np.newaxis] # Interpolate wind at midpoints for better average (optional, could use start/end) midpoints = p + segments / 2.0 # try: # wind_vectors = interp(midpoints) # except ValueError as e: # print(f"Interpolation error: {e}") # print(f"Midpoints shape: {midpoints.shape}") # # Handle error, e.g., return a large cost or use zero wind # wind_vectors = np.zeros_like(midpoints) # Calculate ground speed component along each segment # Ensure wind_vectors and dir_unit have compatible shapes for dot product # np.einsum is efficient for row-wise dot products # wind_along_path = np.einsum('ij,ij->i', wind_vectors, dir_unit) ground_speeds = np.maximum(drone_speed , 1e-3) # Avoid zero/negative speed # Calculate time for each segment segment_times = distances / ground_speeds total_time = np.sum(segment_times) # Calculate collision penalty (iterate segments as check_segment_collision is per-segment) for i in range(len(p)): # Pass pre-calculated builds_with_safety penetration = check_segment_collision(p[i], q[i], builds_with_safety) if penetration > 0: total_penalty += penalty_k * penetration**2 # Quadratic penalty return total_time + total_penalty生成注释
05-14
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值