import os
import logging
import numpy as np
import torch
import open3d as o3d
from pathlib import Path
from torch_cluster import fps
import sys
# 🔧 Configuration ====================================================
config = {
    # Data parameters
    "input_file": r"D:\project\Pointnet_Pointnet2_pytorch-master\data\tomato_WUR_cloud\000101\6.6-0.6 - Cloud.txt",
    "model_path": r"D:\project\Pointnet_Pointnet2_pytorch-master\logs\part_seg\tomato_segnext_focal(3_0.9)_random\best_model.pth",
    # Preprocessing parameters
    "npoints": 102400,
    "use_normals": True,
    "sampling_method": "random",  # 'random', 'fps', 'grid', 'semantic', 'hybrid', 'curvature'
    "voxel_size": 0.001,
    "keep_ratio": 0.6,
    "has_labels": True,
    "iou_classes": [1],
    # Model parameters
    "num_classes": 1,
    "num_part": 2,
    # Visualization parameters
    "show_visualization": True,
    "color_map": {
        0: [255, 0, 255],  # background
        1: [0, 255, 0]     # stem
    }
}
# 🎯 Class definitions ====================================================
SEG_CLASSES = {'stem': 1, 'background': 0}
# 🛠️ Logger setup =================================================
logger = logging.getLogger("CloudSegTest")
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# 🔄 Data loading and preprocessing =============================================
def load_pointcloud(file_path: str) -> tuple:
    """Load a point-cloud text file, trying several encodings in turn."""
    try:
        # Binary formats are omitted here to keep the main logic simple
        encodings = ['utf-8', 'utf-16', 'latin-1', 'gb18030', 'gbk']
        for encoding in encodings:
            try:
                with open(file_path, 'r', encoding=encoding, errors='replace') as f:
                    # Skip header lines (starting with a letter) but keep numeric
                    # rows, including scientific notation such as 1.5e-03
                    lines = [line for line in f
                             if line.strip() and not line.strip()[0].isalpha()]
                data = np.loadtxt(lines, dtype=np.float32)
                cols = data.shape[1]
                min_col = 6 if config["use_normals"] else 3
                if cols < min_col:
                    raise ValueError(f"Not enough columns: at least {min_col} required")
                points = data[:, :min_col]
                labels = data[:, -1] if (config["has_labels"] and cols > min_col) else None
                if len(points) == 0:
                    raise ValueError("Loaded an empty point cloud")
                return points, labels
            except (ValueError, OSError, IndexError):
                continue  # parsing failed; try the next encoding
        raise ValueError("Could not parse the file with any known encoding")
    except Exception as e:
        logger.error(f"Failed to load file: {e}")
        raise
def pc_normalize(pc: np.ndarray) -> np.ndarray:
"""点云归一化"""
centroid = np.mean(pc[:, :3], axis=0)
pc[:, :3] -= centroid
max_dist = np.max(np.linalg.norm(pc[:, :3], axis=1))
if max_dist > 1e-8:
pc[:, :3] /= max_dist
return pc
# 🧠 Model wrapper ======================================================
class PartSegmentationModel:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self._init_model()
def _init_model(self):
"""初始化模型"""
sys.path.append("models")
try:
from models.pointnet2_part_seg_msg import get_model
self.model = get_model(
config["num_part"],
normal_channel=config["use_normals"]
).to(self.device)
self._load_weights()
self.model.eval()
except Exception as e:
logger.error(f"模型初始化失败: {str(e)}")
raise
    def _load_weights(self):
        """Load checkpoint weights, tolerating different key conventions."""
        try:
            checkpoint = torch.load(config["model_path"], map_location=self.device)
            logger.info(f"Checkpoint keys: {list(checkpoint.keys())}")
            # Older and newer training scripts store weights under different keys
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            elif 'model_state' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state'])
            else:
                raise KeyError("No valid model parameters found in checkpoint")
            logger.info("Model weights loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load weights: {e}")
            raise
    def predict(self, points: np.ndarray) -> np.ndarray:
        """Run inference and return per-point part labels."""
        try:
            points_tensor = torch.tensor(points, dtype=torch.float32)
            if points_tensor.dim() == 2:
                points_tensor = points_tensor.unsqueeze(0)  # add batch dimension
            # (B, N, C) -> (B, C, N), the layout PointNet++ expects; move to
            # device unconditionally (previously skipped for batched input)
            points_tensor = points_tensor.transpose(2, 1).to(self.device)
            cls_label = torch.zeros(1, dtype=torch.long).to(self.device)
            one_hot = torch.eye(config["num_classes"], device=self.device)[cls_label]
            with torch.no_grad():
                seg_pred, _ = self.model(points_tensor, one_hot)
            return seg_pred.argmax(dim=2).squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            raise
# 📊 Visualization ==================================================
def visualize_results(points: np.ndarray, pred_labels: np.ndarray):
    """Color points by predicted label and show them in an Open3D window."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points[:, :3])
    colors = np.array([config["color_map"][l] for l in pred_labels]) / 255.0
    pcd.colors = o3d.utility.Vector3dVector(colors)
    if config["show_visualization"]:
        o3d.visualization.draw_geometries(
            [pcd],
            window_name="Tomato stem segmentation",
            width=800,
            height=600
        )
# 🛠️ Point-cloud processor =================================================
class PointCloudProcessor:
def __init__(self, config):
self.config = config
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def _farthest_point_sample(self, points: np.ndarray) -> np.ndarray:
        """Farthest-point sampling via torch_cluster.fps."""
        points_tensor = torch.tensor(points, device=self.device)
        ratio = min(1.0, self.config["npoints"] / len(points))
        indices = fps(points_tensor, ratio=ratio, random_start=True)
        # fps returns roughly ratio * N indices; truncate to npoints
        return indices.cpu().numpy()[:self.config["npoints"]]
    def _grid_subsample(self, points: np.ndarray) -> np.ndarray:
        """Voxel-grid subsampling: keep one point per occupied voxel."""
        coords = points[:, :3]
        voxel_coords = np.floor((coords - coords.min(0)) / self.config["voxel_size"])
        _, indices = np.unique(voxel_coords, axis=0, return_index=True)
        return indices
    def _semantic_aware_sample(self, seg_labels: np.ndarray) -> np.ndarray:
        """Stratified sampling: draw from each class in proportion to its size."""
        indices = []
        for label in np.unique(seg_labels):
            mask = seg_labels == label
            count = max(1, int(self.config["npoints"] * np.sum(mask) / len(seg_labels)))
            # np.random.choice defaults to replace=True, so small classes may repeat
            indices.extend(np.random.choice(np.where(mask)[0], count))
        return np.array(indices[:self.config["npoints"]])
def _hybrid_sampling(self, points: np.ndarray) -> np.ndarray:
"""修复后的混合采样方法"""
grid_indices = self._grid_subsample(points)
if len(grid_indices) == 0:
return np.arange(len(points))[:self.config["npoints"]]
sub_points = points[grid_indices]
valid_npoints = min(self.config["npoints"], len(sub_points))
ratio = valid_npoints / len(sub_points)
fps_indices = fps(
torch.tensor(sub_points[:, :3], device=self.device),
ratio=ratio,
random_start=True
).cpu().numpy()
final_indices = grid_indices[fps_indices]
return final_indices[:self.config["npoints"]]
def calculate_iou(pred: np.ndarray, true: np.ndarray, classes: list) -> dict:
"""计算指定类别的IOU"""
iou_dict = {}
for cls in classes:
pred_mask = (pred == cls)
true_mask = (true == cls)
intersection = np.logical_and(pred_mask, true_mask).sum()
union = np.logical_or(pred_mask, true_mask).sum()
        # A class absent from both prediction and ground truth has undefined IoU
if union == 0:
iou = np.nan
else:
iou = intersection / union
iou_dict[cls] = iou
return iou_dict
# 🚀 Main pipeline =====================================================
def main():
try:
logger.info("=== 开始处理 ===")
# 1. 加载数据
raw_points, raw_labels = load_pointcloud(config["input_file"])
logger.info(f"原始点云数量: {len(raw_points)}")
# 2. 预处理
processed_points = pc_normalize(raw_points)
assert processed_points.ndim == 2, "点云数据必须是二维数组"
# 3. 下采样
processor = PointCloudProcessor(config)
if len(processed_points) > config["npoints"]:
method = config["sampling_method"]
if method == 'random':
choice = np.random.choice(len(processed_points), config["npoints"], replace=False)
elif method == 'fps':
choice = processor._farthest_point_sample(processed_points[:, :3])
elif method == 'grid':
choice = processor._grid_subsample(processed_points)
elif method == 'semantic':
                assert raw_labels is not None, "Semantic sampling requires label data"
choice = processor._semantic_aware_sample(raw_labels)
elif method == 'hybrid':
choice = processor._hybrid_sampling(processed_points)
else:
choice = np.arange(len(processed_points))[:config["npoints"]]
else:
choice = np.concatenate([
np.arange(len(processed_points)),
np.random.choice(len(processed_points),
config["npoints"] - len(processed_points),
replace=True)
])
        sampled_points = processed_points[choice]
        logger.info(f"Points after sampling: {len(sampled_points)}")
if config["has_labels"] and raw_labels is not None:
sampled_labels = raw_labels[choice]
else:
sampled_labels = None
        # 4. Model inference
        model = PartSegmentationModel()
        pred_labels = model.predict(sampled_points)
        logger.info(f"Prediction distribution: {np.unique(pred_labels, return_counts=True)}")
        if config["has_labels"] and sampled_labels is not None:
            # Flatten extra label dimensions
            if sampled_labels.ndim > 1:
                sampled_labels = sampled_labels.squeeze()
            # Sanity check: predictions and labels must align
            assert len(pred_labels) == len(sampled_labels), "Prediction and label lengths differ"
class_list = config["iou_classes"]
iou_results = calculate_iou(pred_labels, sampled_labels, class_list)
logger.info("=== IOU计算结果 ===")
for cls, iou in iou_results.items():
cls_name = [k for k, v in SEG_CLASSES.items() if v == cls][0]
logger.info(f"{cls_name} ({cls}) IOU: {iou:.4f}")
# 计算均值(忽略NaN值)
valid_ious = [v for v in iou_results.values() if not np.isnan(v)]
mean_iou = np.mean(valid_ious) if valid_ious else np.nan
logger.info(f"平均IOU: {mean_iou:.4f}")
        # 5. Visualize
        visualize_results(sampled_points, pred_labels)
        logger.info("=== Processing finished ===")
    except Exception as e:
        logger.error(f"Pipeline aborted: {e}")
raise
if __name__ == "__main__":
main()
This is a model I trained for segmenting tomato plants into stem sections. Unlike ordinary stem segmentation, the goal of this model is to segment the stem into individual sections, and the final objective is to measure the length and diameter of each section.
Now this script needs a second stage of development: cluster the semantically segmented stem points (label 1) into instances, fit each stem-section instance to a regular shape, and automatically derive each section's height and stem diameter. For display, arrange the sections along the z-axis from top to bottom (highest section first), in centimeters. Please choose the most suitable clustering algorithm, and provide visualization. One possible starting point is sketched below.
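
A minimal sketch of the clustering-plus-fitting step, under these assumptions: Euclidean clustering with DBSCAN (exposed in Open3D as cluster_dbscan) is a reasonable first choice, since points of one stem section are spatially contiguous and the number of sections is not known in advance; the eps and min_points values are placeholders to tune; and fit_cylinder_pca is a lightweight PCA approximation (principal axis plus radial distances), not a full least-squares cylinder fit.

import numpy as np
import open3d as o3d

def cluster_stem_instances(points_xyz: np.ndarray, eps: float = 0.02,
                           min_points: int = 30) -> np.ndarray:
    """Cluster stem points (label 1) into instances with DBSCAN.

    eps and min_points are guesses; tune them to the point spacing of
    the normalized clouds. Noise points get the label -1."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points_xyz)
    return np.asarray(pcd.cluster_dbscan(eps=eps, min_points=min_points))

def fit_cylinder_pca(segment: np.ndarray) -> tuple:
    """Approximate a straight stem section as a cylinder.

    Axis = first principal component, height = extent along the axis,
    diameter = 2 * median radial distance from the axis."""
    centroid = segment.mean(axis=0)
    centered = segment - centroid
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    axis = vt[0]
    proj = centered @ axis                     # position along the axis
    height = proj.max() - proj.min()
    radial = np.linalg.norm(centered - np.outer(proj, axis), axis=1)
    return centroid, axis, height, 2.0 * np.median(radial)

DBSCAN fits here because it needs no preset cluster count and flags stray mis-segmented points as noise (-1) for easy removal; per-instance visualization can reuse visualize_results by assigning each cluster its own color.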
Note that the segmentation model's section splitting may not be ideal, though it should still help, and the break points are always where side branches grow out. So consider splitting based on inter-point distance (points of one section are necessarily contiguous), density, curvature, and changes in tangent direction and slope. Also, stem sections are always straight, so there must be a way to dynamically estimate the current growth direction; a sketch of such a direction-tracking split follows.
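
Where DBSCAN alone cannot separate two sections that touch at a branch point, a direction-tracking split can help. The sketch below walks a cluster from the top down and starts a new section when the local PCA direction turns away from the trailing direction by more than a threshold; window and angle_thresh_deg are assumptions to be tuned on real clouds.

def _main_direction(pts: np.ndarray) -> np.ndarray:
    """First principal component of a small point set."""
    centered = pts - pts.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return vt[0]

def split_by_direction(segment: np.ndarray, window: int = 50,
                       angle_thresh_deg: float = 15.0) -> list:
    """Split one cluster into straight sub-sections by tracking direction."""
    order = np.argsort(-segment[:, 2])          # walk from the highest z down
    pts = segment[order]
    cos_thresh = np.cos(np.radians(angle_thresh_deg))
    sections, current = [], [pts[0]]
    for p in pts[1:]:
        current.append(p)
        if len(current) >= 2 * window:
            d_new = _main_direction(np.asarray(current[-window:]))
            d_old = _main_direction(np.asarray(current[-2 * window:-window]))
            # abs() so a sign flip of the PCA axis does not trigger a cut
            if abs(d_new @ d_old) < cos_thresh:
                sections.append(np.asarray(current[:-window]))
                current = current[-window:]
    sections.append(np.asarray(current))
    return sections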
First fully analyze and summarize these requirements, then implement accordingly. The reporting step is sketched below.
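
Finally, a sketch of the top-down report in centimeters, reusing fit_cylinder_pca from the first sketch. Because pc_normalize rescales the cloud into a unit sphere, the script would have to be changed to keep the normalization scale (max_dist); that value and the unit_to_cm factor below are assumptions (unit_to_cm = 100 presumes the raw cloud is in meters).

def report_sections(sections: list, max_dist: float, unit_to_cm: float = 100.0):
    """Print stem sections from the highest down, measurements in cm.

    max_dist is the scale removed by pc_normalize and must be captured
    there; unit_to_cm converts the cloud's native unit to centimeters."""
    # Sort by each section's highest z: inspect the topmost stem first
    ranked = sorted(sections, key=lambda s: s[:, 2].max(), reverse=True)
    for i, seg in enumerate(ranked):
        _, _, height, diameter = fit_cylinder_pca(seg)
        print(f"Section {i}: height {height * max_dist * unit_to_cm:.2f} cm, "
              f"diameter {diameter * max_dist * unit_to_cm:.2f} cm")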