Graph-U-Nets(二)代码分析
Graph U-Nets通过Pytorch进行实现,开源代码Graph U-Nets,原论文链接Graph U-Nets
直接对作者的代码进行解读
class GraphUnet(nn.Module):
def __init__(self, ks, in_dim, out_dim, dim, act, drop_p):
"""
:param ks: 表示pools层进行的节点采样率,数据类型为float型
"""
super(GraphUnet, self).__init__()
self.ks = ks
# 创建底部的gcn
self.bottom_gcn = GCN(dim, dim, act, drop_p)
# 对应原文的encoder中的gcn
self.down_gcns = nn.ModuleList()
# 对应原文的decoder中的gcn
self.up_gcns = nn.ModuleList()
# 对应原文的encoder中的gPool
self.pools = nn.ModuleList()
# 对应原文的decoder中的gUnPool
self.unpools = nn.ModuleList()
# ks的长度表示了gUNets的深度
self.l_n = len(ks)
# 构建l_n个子模块
for i in range(self.l_n):
self.down_gcns