import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.patches import Circle
from scipy.spatial.distance import cdist
from scipy.sparse import lil_matrix
# Optional dependency: scikit-learn provides the K-Means based ball generation.
# When it is missing, a warning is printed and `has_sklearn` is set to False so
# that the rest of the module can fall back to the simplified method.
try:
    from sklearn.cluster import KMeans
except ImportError:
    print("警告: scikit-learn 未安装,将使用简化的粒球生成方法")
    has_sklearn = False
else:
    has_sklearn = True
class GranularBallGraph:
    """Cover a point cloud with "granular balls" and link the balls that overlap.

    Each ball is stored in ``self.balls`` as a ``(center, radius, member_points)``
    tuple.  ``self.graph`` is ``None`` until :meth:`build_graph` succeeds, after
    which it is a ``networkx.Graph`` with one node per ball (node attributes
    ``center`` and ``radius``) and an edge between every pair of sufficiently
    overlapping balls.
    """

    # Fallback ball generation: index, in ascending distance order, of the
    # distance used as the radius.  Index 0 is the seed point itself.
    _FALLBACK_NEIGHBOR_RANK = 5

    def __init__(self, points, n_clusters=10, overlap_threshold=0.1):
        """Store the input points and the ball / graph parameters.

        Args:
            points: Array-like of point coordinates, one point per row.
            n_clusters: Number of balls to generate (upper bound in the
                fallback method, which never creates more balls than points).
            overlap_threshold: Minimum overlap ratio (see :meth:`build_graph`)
                for two balls to be connected by an edge.
        """
        # Coerce once so that list input works with the array arithmetic below;
        # an ndarray argument is kept as-is (no copy).
        self.points = np.asarray(points)
        self.n_clusters = n_clusters
        self.overlap_threshold = overlap_threshold
        self.balls = []
        self.graph = None

    def generate_granular_balls(self):
        """Populate ``self.balls``, replacing any previously generated balls.

        Two strategies are supported:

        * K-Means (when scikit-learn is importable and there are more points
          than ``n_clusters``): one ball per non-empty cluster, centred on the
          cluster centre with the radius reaching its farthest member.
        * Simplified fallback: randomly chosen input points act as centres and
          the radius is the distance at index ``_FALLBACK_NEIGHBOR_RANK`` of the
          ascending distance list (clamped to the last index for tiny inputs).
        """
        n_points = len(self.points)
        self.balls = []

        if has_sklearn and n_points > self.n_clusters:
            kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_init=10)
            labels = kmeans.fit_predict(self.points)
            for i, center in enumerate(kmeans.cluster_centers_):
                members = self.points[labels == i]
                if len(members) == 0:
                    continue
                radius = np.max(np.linalg.norm(members - center, axis=1))
                self.balls.append((center, radius, members))
            return

        # Simplified generation (no scikit-learn, or too few points).
        print(f"使用简化的粒球生成方法(点数量: {len(self.points)})")
        n_balls = min(self.n_clusters, n_points)
        seed_indices = np.random.choice(n_points, n_balls, replace=False)
        for idx in seed_indices:
            center = self.points[idx]
            dists = np.linalg.norm(self.points - center, axis=1)
            rank = min(self._FALLBACK_NEIGHBOR_RANK, len(dists) - 1)
            # Sort a copy: `dists` must stay aligned with the rows of
            # `self.points`, otherwise the membership mask below would select
            # the wrong points.
            radius = np.sort(dists)[rank]
            self.balls.append((center, radius, self.points[dists <= radius]))

    def build_graph(self):
        """Build ``self.graph`` from the balls currently in ``self.balls``.

        Two balls ``i`` and ``j`` are connected when::

            (r_i + r_j - distance(c_i, c_j)) / min(r_i, r_j) > overlap_threshold

        Pairs whose smaller radius is not positive are skipped to avoid a
        division by zero.  If no balls have been generated an error message is
        printed and ``self.graph`` is left untouched.
        """
        if not self.balls:
            print("错误: 未生成粒球,请先调用 generate_granular_balls()")
            return

        centers = np.array([ball[0] for ball in self.balls])
        radii = np.array([ball[1] for ball in self.balls])
        n_balls = len(self.balls)
        # Pairwise distances between ball centres.
        dist_matrix = cdist(centers, centers)

        graph = nx.Graph()
        for i in range(n_balls):
            graph.add_node(i, center=centers[i], radius=radii[i])

        # Visit each unordered pair once (i < j) and add the edge directly.
        for i in range(n_balls):
            for j in range(i + 1, n_balls):
                min_radius = min(radii[i], radii[j])
                if min_radius <= 0:
                    continue
                overlap_ratio = (radii[i] + radii[j] - dist_matrix[i, j]) / min_radius
                if overlap_ratio > self.overlap_threshold:
                    graph.add_edge(i, j)

        # Assign only once the graph is complete.
        self.graph = graph

    def visualize_2d(self, show_points=True, show_balls=True, show_graph=True):
        """Plot the point cloud, the balls and the overlap graph (2D data only).

        Args:
            show_points: Draw the raw points as a scatter plot.
            show_balls: Draw every ball as a dashed circle labelled ``B<i>``.
            show_graph: Draw the overlap graph with nodes at the ball centres
                (only when :meth:`build_graph` has produced a graph).

        Prints a warning and returns without plotting when the stored points
        are not an ``(n, 2)`` array.
        """
        if self.points.ndim != 2 or self.points.shape[1] != 2:
            print("警告: 可视化仅支持2D数据")
            return

        plt.figure(figsize=(10, 8))

        # Raw point cloud.
        if show_points:
            plt.scatter(self.points[:, 0], self.points[:, 1],
                        c='blue', s=10, alpha=0.6, label='点云')

        # Ball outlines and their index labels.
        if show_balls and self.balls:
            axes = plt.gca()
            for i, (center, radius, _) in enumerate(self.balls):
                axes.add_patch(Circle(center, radius, fill=False,
                                      edgecolor='red', alpha=0.4, linestyle='--'))
                plt.text(center[0], center[1], f'B{i}',
                         fontsize=9, ha='center', va='center')

        # Graph structure, positioned at the ball centres.
        if show_graph and self.graph is not None:
            pos = {i: ball[0] for i, ball in enumerate(self.balls)}
            nx.draw(self.graph, pos, node_size=50, node_color='green',
                    edge_color='gray', width=1.5, alpha=0.7)

        plt.title('粒球图 (Granular Ball Graph)')
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.grid(alpha=0.3)
        # The point scatter is the only labelled artist; skip the legend when it
        # is not drawn so matplotlib does not warn about an empty legend.
        if show_points:
            plt.legend()
        plt.axis('equal')
        plt.show()

    def reconstruct_surface(self, resolution=0.1):
        """Plot a simple "surface": every grid cell covered by at least one ball.

        A regular grid with spacing ``resolution`` is laid over the bounding box
        of the points; grid nodes lying inside any ball are drawn together with
        the original points and the ball centres.  2D data only.

        Args:
            resolution: Grid spacing; must be positive.

        Raises:
            ValueError: If ``resolution`` is not positive.
        """
        if self.points.ndim != 2 or self.points.shape[1] != 2:
            print("警告: 表面重建目前仅支持2D数据")
            return
        if not self.balls:
            print("错误: 未生成粒球,请先调用 generate_granular_balls()")
            return
        if resolution <= 0:
            raise ValueError("resolution must be positive")

        # Regular grid over the bounding box of the point cloud.
        x_min, y_min = np.min(self.points, axis=0)
        x_max, y_max = np.max(self.points, axis=0)
        xx, yy = np.meshgrid(np.arange(x_min, x_max, resolution),
                             np.arange(y_min, y_max, resolution))
        grid_points = np.c_[xx.ravel(), yy.ravel()]

        # A grid node belongs to the surface if it lies inside any ball.
        surface_mask = np.zeros(len(grid_points), dtype=bool)
        for center, radius, _ in self.balls:
            surface_mask |= np.linalg.norm(grid_points - center, axis=1) <= radius

        plt.figure(figsize=(10, 8))
        plt.scatter(grid_points[surface_mask, 0], grid_points[surface_mask, 1],
                    c='blue', s=1, alpha=0.3, label='重建表面')
        plt.scatter(self.points[:, 0], self.points[:, 1],
                    c='red', s=15, alpha=0.7, label='原始点')

        # Ball centres.
        centers = np.array([ball[0] for ball in self.balls])
        plt.scatter(centers[:, 0], centers[:, 1],
                    c='green', s=50, marker='x', label='粒球中心')

        plt.title('基于粒球图的表面重建')
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.legend()
        plt.grid(alpha=0.3)
        plt.axis('equal')
        plt.show()
# Example usage
if __name__ == "__main__":
    # Sample 2D point cloud: a ring with radial spread and Gaussian jitter.
    np.random.seed(42)
    sample_count = 700
    angles = np.linspace(0, 2 * np.pi, sample_count)
    ring_radius = 5 + np.random.rand(sample_count) * 2
    jitter = np.random.randn(sample_count, 2) * 0.3
    ring = np.column_stack((ring_radius * np.cos(angles),
                            ring_radius * np.sin(angles)))
    points = ring + jitter

    # Create the granular balls and the graph on top of them.
    print("创建粒球图...")
    gbg = GranularBallGraph(points, n_clusters=10)
    gbg.generate_granular_balls()
    gbg.build_graph()

    # Report a short summary.
    if gbg.balls:
        print(f"生成粒球数量: {len(gbg.balls)}")
        avg_radius = np.mean([radius for _, radius, _ in gbg.balls])
        print(f"平均半径: {avg_radius:.2f}")
    if gbg.graph:
        print(f"图结构: {gbg.graph.number_of_nodes()} 个节点, {gbg.graph.number_of_edges()} 条边")

    # Visualise the result.
    gbg.visualize_2d(show_points=True, show_balls=True, show_graph=True)
    gbg.reconstruct_surface()
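# NOTE: stray page text ("最新发布", i.e. "Latest release") left over from a web copy; not part of the program.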