所有讨论均以以下代码为基础进行:
import dgl
import torch as th
from dglfn import dgl_fn
import dgl.function as fn
# https://blog.youkuaiyun.com/CY19980216/article/details/110629996?
device=th.device("cuda" if th.cuda.is_available() else 'cpu')
u= th.tensor([0,1,2,3,4,3,4])
v= th.tensor([2,0,1,5,3,2,1])
g=dgl.graph((u,v)).to(device)
#print(g.edges())
g.ndata["value"]=th.ones(g.num_nodes(),3).to(device) #每个值 都是3维向量
#print(g.ndata["value"])
#g.ndata["ft"]=th.tensor([0,1,2,3,4,5],dtype=th.float32).to(device)
g.edata["weights"]=th.ones(g.num_edges(),3).to(device)
dgl_fn.StartGraph(size=(12,10),title="DGL网络")
dgl_fn.AddGraph(221,g,"value",'weights')
g.apply_edges(fn.u_add_v('value','value','weights')) #计算边的两个节点的相关运算,并存于边上。
# 计算方法: 源节点+目标节点值 , 例如: 0,1 的wegiths 为1, 2,得出边1 的值 为 3
#g.update_all(fn.u_add_v("value","value","m"),fn.sum("m","weights"))
dgl_fn.AddGraph(222,g,"value",'weights')
#print("the weights of edgs",g.edata["weights"]) #边的权重变为每个元素为2的矩阵
def updata_all_example(graph):
# store the result in graph.ndata['ft']
graph.update_all(fn.u_mul_e('value', 'weights', 'm'), fn.sum('m', 'ft')) # torch.sum(x, 1)
# Call update function outside of update_all
final_ft = graph.ndata['ft']
return final_ft
all=updata_all_example(g)
# 从源节点到它指向的边,相乘后的值 存放于m中,再对m和ft中的值 进行相加
#
#
#print(g.edges())
print(all)
dgl_fn.AddGraph(223,g,"value",'weights')
dgl_fn.ShowGraph()
在这张图中,有6个节点,总共有7条边。
如图所示:
为简化起见,预设原始的每节点的value特征是 [1.1,1], 原始的 每条边的权重 wegiths为[1,1,1].
在执行
g.apply_edges(fn.u_add_v('value','value','weights')) #这个函数 没啥可说的,就是直接向量相加,存到边上
后,所有边分别的权重,即每条边的特征 weights 为 [2,2,2], 没错,是个向量。
记住上面的值 ,即边的weights 权重为 [2,2,2], 每个节点的value为[1,1,1]
最复杂的是下面这句
def updata_all_example(graph):
# store the result in graph.ndata['ft']
graph.update_all(fn.u_mul_e('value', 'weights', 'm'), fn.sum('m', 'ft')) # torch.sum(x, 1)
# Call update function outside of update_all
final_ft = graph.ndata['ft']
return final_ft
调用以上函数:
all=updata_all_example(g)
上面的操作是对每条边进行如下操作:
取源节点的value特征乘以与它相连的边的权重weghts, 这里的乘是向量的按位乘,即相同位置的分量相乘后作为同一个位置的结果,相乘后的值 存放于m中,再对m中的值 进行求和存放于目标节点的mailbox的ft特征中。
print(all)
怎么得到的上面的值 呢?
我们先看最前面的图。 先说一个结论,我们得到的值all是跟节点数一致的一个矩阵。行数即节点数量(这里是6个节点),第一行表示第一个节点0,其它类推。列数是特征维度,即3.
我们有5个节点: 0,1,2,3,4,5
需要详细考察每个节点的入度(是的,dgl是有向图,只算入度。update_all 只影响有入度的顶点,下面标的出度只是为了形象起见,并不影响值 )
0: 1 入 (1出)
1: 2 入 (1出)
2: 2入 (1出)
3:1 入(2出)
4:0入(2出)
5: 1入(0出)
再对照 上图中的值 ,有没有想明白是啥意思?
对: 只计算入度。 以节点1为例 ,每个入度得到的值是 [1,1,1] * [2,2,2] -> [2,2,2], 总共有两个入度,作sum后,即是 [4,4,4], 对应上面结果中的第二行。
看明白没?
对,出度不会影响这个计算值 。所以,对于节点4, 它的sum后的结果就 [0, 0,0]
当然,我们上面的程序中没有更新函数。如果需要,可以直接更新目标节点的相应的特征值 。
有一点要注意: mailbox其实是一个抽象的概念,在消息函数运行后,每条边都有一个mailbox以及相关的特征,如上面的 mailbox 中的特征 m, 聚合函数就是把这个点的所有入度的mailbox中的m合并起来,合并的方法函数叫聚合函数。
这点刚开始有点费解。
比如顶点1, 它有两个入度,在执行完消息函数后,其实有两个mailbox ,里面都有一个m 特征, 聚合函数的作用就是把这些多个m 合并起来,然后存在当前节点 1的ft 特征中。