函数介绍
在使用dgl框架自定义卷积算子或者看dgl实现函数的源代码时,总会碰到如下的代码形式:
with graph.local_scope():
if isinstance(feat, tuple):
feat, _ = feat # dst feature not used
cj = graph.srcdata['cj']
ci = graph.dstdata['ci']
if self.device is not None:
cj = cj.to(self.device)
ci = ci.to(self.device)
with graph.local_scope()
函数到底起到什么作用,这里做一下介绍,graph使用dgl定义的一个图,官方的介绍如下:
Enter a local scope context for the graph.
By entering a local scope, any out-place mutation to the feature data will not reflect to the original graph, thus making it easier to use in a function scope (e.g. forward computation of a model).
If set, the local scope will use same initializers for node features and edge features.
Notes
Inplace operations do reflect to the original graph. This function also has little overhead when the number of feature tensors in this graph is small.
大致意思为:
使用local_scope() 范围时,任何对节点或边的修改在脱离这个局部范围后将不会影响图中的原始特征值 。直白点就是:在这个范围中,你可以引用、修改图中的特征值 ,但是只是临时的,出了这个范围,一切恢复原样。这样做的目的是方便计算。毕竟我们在图上作消息传递和聚合,有时仅是为了计算值并不想改变原始图。
例子
下面看个例子:
import dgl
import torch
def foo(g):
with g.local_scope():
g.edata['h'] = torch.ones((g.num_edges(), 3))
g.edata['h2'] = torch.ones((g.num_edges(), 3))
return g.edata['h']
# local_scope avoids changing the graph features when exiting the function.
g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))
g.edata['h'] = torch.zeros((g.num_edges(), 3))
newh = foo(g)
print(g.edata['h']) # still get tensor of all zeros
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
'h2' in g.edata # new feature set in the function scope is not found
False
# 下面这里例子为什么会改变原图的特征还不太理解,大家理解的可以在评论区留言
# In-place operations will still reflect to the original graph.
def foo(g):
with g.local_scope():
# in-place operation
g.edata['h'] += 1
return g.edata['h']
g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([0, 0, 2])))
g.edata['h'] = torch.zeros((g.num_edges(), 1))
newh = foo(g)
print(g.edata['h']) # the result changes
tensor([[1.],
[1.],
[1.]])