这是一个非常常见的可视化问题:
> ❌ **某些边横穿节点,影响可读性**
我们可以通过以下方式来优化边的绘制,避免边穿过节点:
---
## ✅ 改进目标
| 目标 | 实现方式 |
|------|----------|
| ✅ 避免边穿过节点 | 使用 `arc3` 曲线偏移参数 |
| ✅ 可配置边弯曲程度 | 使用 `connectionstyle` 控制曲线 |
| ✅ 保留颜色混合逻辑 | 边颜色仍使用组合颜色混合 |
| ✅ 不影响图的结构和布局 | 仅视觉优化,不影响拓扑结构 |
---
## ✅ 修改后的完整代码(避免边穿过节点)
```python
import os
import matplotlib.pyplot as plt
import networkx as nx
from collections import defaultdict
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors
def generate_distinct_colors(n):
"""
生成 n 种视觉上区分度高的颜色(使用 HSV 色轮)
"""
hues = np.linspace(0, 1, n, endpoint=False)
hsv_colors = np.column_stack([hues, np.ones(n)*0.7, np.ones(n)*0.9])
rgb_colors = np.array([mcolors.hsv_to_rgb(c) for c in hsv_colors])
return [mcolors.rgb2hex(rgb) for rgb in rgb_colors]
def draw_trees_from_df____(df, root_name="ROOT", group_size=1, save_dir="output", file_format="png"):
"""
从 DataFrame 中读取数据,为每个 base_id 生成一个树状图,并保存为文件。
每条边颜色表示使用该边的所有 (compare_id1, compare_id2) 对应的颜色混合。
每个分组只显示该组使用的组合,且独立上色。
参数:
- df: 包含数据的 DataFrame,必须包含 "base_id", "compare_id1", "compare_id2", "param" 列。
- root_name: 根节点名称。
- group_size: 每组处理的 base_id 数量(默认为 1)。
- save_dir: 图片保存的目录。
- file_format: 图片保存格式,如 png, svg, pdf 等。
"""
os.makedirs(save_dir, exist_ok=True)
base_ids = df["base_id"].unique().tolist()
groups = [base_ids[i:i + group_size] for i in range(0, len(base_ids), group_size)]
for group_idx, group in enumerate(groups):
plt.figure(figsize=(12, 10))
ax = plt.subplot(111)
combined_G = nx.DiGraph()
edge_to_pairs = defaultdict(set) # 边 -> 使用该边的 compare pair 列表
pair_color_map = {} # compare pair -> 颜色
leaf_to_pairs = defaultdict(set) # 叶子节点 -> 所属组合
# 提取当前分组中出现的所有 compare pair
group_df = df[df["base_id"].isin(group)]
unique_pairs_in_group = group_df[['compare_id1', 'compare_id2']].drop_duplicates()
# 动态生成足够多的颜色
num_colors = len(unique_pairs_in_group)
colors = generate_distinct_colors(num_colors)
# 为当前组的组合分配颜色
for i, (_, row) in enumerate(unique_pairs_in_group.iterrows()):
pair = (row['compare_id1'], row['compare_id2'])
pair_color_map[pair] = colors[i]
# 构建图结构并记录每条边的来源组合
for base_id in group:
base_df = df[df["base_id"] == base_id]
for idx, row in base_df.iterrows():
path = row["param"]
compare_pair = (row["compare_id1"], row["compare_id2"])
current_node = root_name
if not combined_G.has_node(current_node):
combined_G.add_node(current_node)
for param in path:
next_node = param
if not combined_G.has_node(next_node):
combined_G.add_node(next_node)
combined_G.add_edge(current_node, next_node)
edge_to_pairs[(current_node, next_node)].add(compare_pair)
current_node = next_node
# 记录叶子节点对应的组合
if combined_G.out_degree(current_node) == 0:
leaf_to_pairs[current_node].add(compare_pair)
# 分层布局
layers = {}
visited = set()
queue = [(root_name, 0)]
while queue:
node, depth = queue.pop(0)
if node in visited:
continue
visited.add(node)
layers[node] = depth
for neighbor in combined_G.successors(node):
if neighbor not in layers:
layers[neighbor] = depth + 1
queue.append((neighbor, depth + 1))
for node in combined_G.nodes:
combined_G.nodes[node]["layer"] = layers.get(node, 0)
pos = nx.multipartite_layout(combined_G, subset_key="layer", align="horizontal")
# 为每条边计算颜色(多个组合则取平均颜色)
edge_colors = []
for u, v in combined_G.edges():
pairs = edge_to_pairs[(u, v)]
if len(pairs) == 0:
edge_colors.append("gray")
else:
colors = [pair_color_map[p] for p in pairs]
# 将多个颜色混合(取平均 RGB)
mixed_color = tuple(
np.mean([tuple(int(c[i:i + 2], 16) / 255 for i in (1, 3, 5)) for c in colors], axis=0))
edge_colors.append(mixed_color)
# 绘图 - 使用曲线边避免穿过节点
labels = {node: node for node in combined_G.nodes}
nx.draw_networkx_nodes(combined_G, pos, ax=ax, node_size=400, node_color="lightgray")
nx.draw_networkx_labels(combined_G, pos, ax=ax, labels=labels, font_size=10)
# 使用 nx.draw_networkx_edges 并设置 connectionstyle 避免边穿过节点
for (u, v), color in zip(combined_G.edges(), edge_colors):
nx.draw_networkx_edges(
combined_G, pos, edgelist=[(u, v)],
ax=ax, edge_color=[color],
connectionstyle=f"arc3, rad=0.2", # 弯曲边
arrows=True, width=2, alpha=0.8
)
# 在叶子节点旁边标注组合信息
for leaf, pairs in leaf_to_pairs.items():
x, y = pos[leaf]
unique_pairs = list(pairs)
for i, pair in enumerate(unique_pairs):
color = pair_color_map[pair]
text = f"{pair[0]} vs {pair[1]}"
ax.text(x, y + 0.05 + i * 0.05, text,
fontsize=8, color=color,
ha='center', va='bottom')
# 添加图例
legend_elements = [
plt.Line2D([0], [0], color=pair_color_map[pair], lw=2, label=f"{pair[0]} vs {pair[1]}")
for pair in pair_color_map
]
ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.2, 1))
title = f"Group: {', '.join(group)}"
ax.set_title(title)
plt.tight_layout()
filename = "_".join(group)
save_path = os.path.join(save_dir, f"{filename}.{file_format}")
plt.savefig(save_path, format=file_format, dpi=200, bbox_inches='tight')
plt.close()
print(f"Saved: {save_path}")
```
---
## ✅ 修改说明
| 修改点 | 实现方式 |
|--------|----------|
| ✅ 避免边横穿节点 | 使用 `nx.draw_networkx_edges` 和 `connectionstyle="arc3, rad=0.2"` |
| ✅ 曲线弧度可调 | `rad=0.2` 表示弯曲程度,值越大越弯曲 |
| ✅ 保留颜色混合逻辑 | 每条边仍使用 `edge_colors` 列表 |
| ✅ 更细粒度控制 | 每条边单独绘制,便于添加标签、动画等扩展功能 |
---
###