import torch
import networkx as nx
import matplotlib.pyplot as plt
### Plotting: render each Graph as an image (one matplotlib figure per graph)
def drawGraph(edge_index: torch.Tensor, batch: torch.Tensor, batch_size=0):
    """Draw every graph of a batched mini-batch, one matplotlib figure each.

    Args:
        edge_index: tensor indexed as ``edge_index[0]`` (source node) and
            ``edge_index[1]`` (target node), one column per edge, holding
            global node indices across the whole batch.
        batch: per-node graph id; ``batch[n]`` is the graph that global node
            ``n`` belongs to. Assumed 1-D, with the nodes of graph 0 stored
            first, then graph 1, and so on (PyTorch Geometric's batching
            layout — TODO confirm for other callers). Node labels are made
            local by subtracting the node count of all earlier graphs, which
            relies on that ordering.
        batch_size: number of graphs to draw (ids ``0 .. batch_size - 1``).
            With the default of 0, nothing is drawn.

    Each graph is shown with ``plt.show()`` before the next one is built.
    """
    options = {
        "node_color": "white",
        "edgecolors": "black",
    }
    # Work on CPU copies once: plotting needs host data anyway, and this
    # avoids a device sync per edge. Indices must be integral for indexing.
    edge_index = edge_index.detach().cpu().long()
    batch = batch.detach().cpu()
    offset = 0  # global index of the current graph's first node
    for i in range(batch_size):
        G = nx.Graph()
        count = int((batch == i).sum().item())
        G.add_nodes_from(range(count))
        # Pick this graph's edges by the graph id of their source node.
        # The previous sequential scan advanced its cursor by
        # 2 * len(G.edges). That is only right when every edge is stored
        # exactly twice; with self-loops, duplicate edges or one-directional
        # edges it skipped or re-read edges and attached them to the wrong
        # graph.
        local = edge_index[:, batch[edge_index[0]] == i] - offset
        G.add_edges_from(zip(local[0].tolist(), local[1].tolist()))
        nx.draw(G, with_labels=True, **options)
        plt.show()
        offset += count
def drawGraphByMetric(graph:list):
    """Draw each matrix in ``graph`` as a separate graph figure.

    ``graph[k]`` is read as an ``n x n`` matrix with ``n = len(graph[k])``
    (presumably an adjacency matrix; "Metric" in the name likely means
    "matrix" — TODO confirm).

    * Node ``i`` is added when the diagonal entry ``graph[k][i][i]`` is
      non-zero.
    * Edge ``(i, j)`` is added for every entry equal to exactly 1, so a
      diagonal entry of 1 also produces a self-loop.

    NOTE(review): an entry of 1 adds its endpoints as nodes even when their
    diagonal is 0. Entries other than exactly 1 (e.g. real-valued weights)
    are not drawn as edges. Confirm both behaviors are intended.

    Each figure is shown with ``plt.show()`` before the next one is built.
    """
    style = {"node_color": "white", "edgecolors": "black"}
    for matrix in graph:
        indices = range(len(matrix))
        G = nx.Graph()
        G.add_nodes_from(i for i in indices if matrix[i][i] != 0)
        # Row-major scan, same insertion order as a nested row/column loop.
        G.add_edges_from(
            (row, col)
            for row in indices
            for col in indices
            if matrix[row][col] == 1
        )
        nx.draw(G, with_labels=True, **style)
        plt.show()
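# 深度学习图池化中的画节点代码 (Node-drawing code for graph pooling in deep learning)
# 最新推荐文章于 2025-01-06 21:30:00 发布 (scraped page footer: "latest recommended article published 2025-01-06 21:30:00")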