图神经网络学习1-Schnet源码解析

from math import pi as PI
import torch
# .F封装了很多函数, 没有可学习参数
import torch.nn.functional as F
# torch.nn封装了很多神经网络单元, 比如Sequential序列, lstm, conv, Embedding单元等(当然自带参数追踪等)
from torch.nn import Embedding, Sequential, Linear

from torch_scatter import scatter
from torch_geometric.nn import radius_graph

# iteraction 黄色模块, 对节点的信息进行更新
# energy_and_force=False, cutoff=10.0, num_layers=6, hidden_channels=128, out_channels=1, num_filters=128, num_gaussians=50
class update_v(torch.nn.Module):
    def __init__(self, hidden_channels, num_filters):
    #每个类模型默认会添加def __init__(self)假如不需要额外的实例的属性时可以不写, self指实例本身,可以换其他名字
    #__init__ 方法是一个类的初始化方法, 它会在你创建一个新的对象时自动被调用, 它用于初始化该对象的属性.
        super().__init__()
        #必须加,否则继承不了属性和方法(父类没有初始化)
        #super(residual_block, self).__init__() 上面是简化的写法, 这行意思是当前需要继承的类的名字的实例本身
        self.act = ShiftedSoftplus()
        self.lin1 = Linear(num_filters, hidden_channels)
        # Linear是线性层, 有偏置, 没激活
        # 数据按照行理解的, 第一个的列等于第二个的行, 列是特征大小, 后面是神经元的数量, 也是输出维度向量的维度,向量就算是列向量也是横着理解的,比如embedding,nn.embedding(108, 64)矩阵维度108*64, 也是按照行理解
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.reset_parameters()
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
    def forward(self, v, e, edge_index): # self.lin1 = Linear(num_filters, hidden_channels)
        _, i = edge_index          # hidden_channels就是原子特征的维度
        out = scatter(e, i, dim=0) # 聚合
        out = self.lin1(out)       # [num_nodes, num_filters]
        out = self.act(out)
        out = self.lin2(out)       # [num_nodes, hidden_channels]激活函数不改变张量维度,只进行非线性变换
        return v + out
#%%
# filter generator 橙色模块
# 该模块的物理意义是,邻居原子对目标原子的影响力会随着距离的增大而衰减(C矩阵, 余弦矩阵). 邻居原子对目标原子的影响是一个距离的函数,而且该函数是可学的一个MLP模块(self.mlp(dist_emb))
# 这个update_e到底什么意思, 应该该是计算出节点向其邻居传递的信息大小

# 输入维度, 节点特征 v [num_nodes, hidden_channels]
# 距离嵌入 dist_emb:[num_edges, num_gaussians]
# 输出维度, 边特征 e:[num_edges, num_filters]
class update_e(torch.nn.Module):
    def __init__(self, hidden_channels, num_filters, num_gaussians, cutoff):
        super().__init__()
        self.cutoff = cutoff
        self.lin = Linear(hidden_channels, num_filters, bias=False)
        self.mlp = Sequential(Linear(num_gaussians, num_filters), ShiftedSoftplus(), Linear(num_filters, num_filters))
        self.reset_parameters()
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        self.mlp[0].bias.data.fill_(0)
    def forward(self, v, dist, dist_emb, edge_index):
        j, _ = edge_index
        C = 0.5 * (torch.cos(dist * PI / self.cutoff) + 1.0)
        W = self.mlp(dist_emb) * C.view(-1, 1) # dist_emb:边数 * 高斯函数数量, mlp的输出:边数 * 64
        # .view()将张量铺平然后重塑, 比如形状(4,3,2), 按顺序展开. 它和reshape的作用相同, 但是要求内存连续(不能装置, 切片), 效率更高, reshape会检查连续防止出错.
        v = self.lin(v) # v:这批原子的节点数 * 64, 输出: 貌似不变 这层是黄色方块cfconv上的
        e = v[j] * W # W:边数 * 64
        return e # 每个原子对于其中心原子所要传递的信息
#%%
# 最后的全连接, 池化, 输出层, 进行信息聚合
class update_u(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, out_channels)
        self.reset_parameters()
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
    def forward(self, v, batch):
        v = self.lin1(v)
        v = self.act(v)  # v:节点数*特征
        v = self.lin2(v) # v:节点数*1
        u = scatter(v, batch, dim=0) # dim=0, 在第一个维度进行压缩, 如果是(4, 2, 2)在dim=0压缩,变为(2,2), 压缩掉第一个维度, 每个位置累加.
        return u
#%%
class SchNet(torch.nn.Module):
    r"""    energy_and_force (bool, optional): If set to :obj:`True`, will predict energy and take the negative of the derivative of the energy with respect to the atomic positions as predicted forces. (default: :obj:`False`)
            num_layers (int, optional): The number of layers. (default: :obj:`6`)
            hidden_channels (int, optional): Hidden embedding size. (default: :obj:`128`)
            out_channels (int, optional): Output embedding size. (default: :obj:`1`)
            num_filters (int, optional): The number of filters to use. (default: :obj:`128`)
            num_gaussians (int, optional): The number of gaussians :math:`\mu`. (default: :obj:`50`)
            cutoff (float, optional): Cutoff distance for interatomic interactions. (default: :obj:`10.0`)."""
    def __init__(self, energy_and_force=False, cutoff=10.0, num_layers=6, hidden_channels=128, out_channels=1, num_filters=128, num_gaussians=50):
        super(SchNet, self).__init__()
        self.energy_and_force = energy_and_force
        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians
        # v : vertex, node embedding, node feature
        self.init_v = Embedding(100, hidden_channels)
        self.dist_emb = emb(0.0, cutoff, num_gaussians)
        self.update_vs = torch.nn.ModuleList([update_v(hidden_channels, num_filters) for _ in range(num_layers)])
        self.update_es = torch.nn.ModuleList([update_e(hidden_channels, num_filters, num_gaussians, cutoff) for _ in range(num_layers)])
        self.update_u = update_u(hidden_channels, out_channels)
        self.reset_parameters()
    def reset_parameters(self):
        self.init_v.reset_parameters()
        for update_e in self.update_es:
            update_e.reset_parameters()
        for update_v in self.update_vs:
            update_v.reset_parameters()
        self.update_u.reset_parameters()
    def forward(self, batch):
        z, pos, batch = batch.z, batch.pos, batch.batch
        if self.energy_and_force:
            pos.requires_grad_()
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
        row, col = edge_index
        dist = (pos[row] - pos[col]).norm(dim=-1) # 沿着维度1操作, 对每个(行)求模长
        dist_emb = self.dist_emb(dist) # 高斯函数进行展开
        v = self.init_v(z) # v是节点特征,进行初始化embedding
        for update_e, update_v in zip(self.update_es, self.update_vs):
            e = update_e(v, dist, dist_emb, edge_index)
            v = update_v(v,e, edge_index)
        u = self.update_u(v, batch) # atom-wise, 逐原子的, 说wise是因为没有信息的交互, 最后的全连接层
        return u

# act
class ShiftedSoftplus(torch.nn.Module):
    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()  # .item()转换为Python标量

    def forward(self, x):
        return F.softplus(x) - self.shift
#%% md
1. 为什么这里要加.item()?
不需要对偏移量计算梯度, 尽管利用torch.tensor创建的张量默认require_grad=False, 以防万一算了造成资源浪费.
默认设备cpu, 转为标量在后续计算时候会自动与x设备对齐(这时会隐式转为一个不计算梯度的张量).
pytorch默认叶节点(用户直接创建且require_grad=True的张量)保留梯度, 非叶节点(运算生成的中间变量)会计算但是require_grad=None, 必须要用.retain_grad()才能保留.


2. 为什么激活函数需要写成类,继承nn.Module?
自定义的pytorch模块都需要继承nn.Module, 其包含了.parameter等一系列属性方法. 不继承在nn.Sequential中也无法使用, 其通过Module进行管理.
#%%
# gaussian_emb
import torch
class emb(torch.nn.Module):
    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)
    def forward(self, dist):
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))
emb = emb(0.0, 5.0, 10)
print(emb(torch.tensor([0.0, 5.0, 5.0])).shape)
# 输出shape(3, 10)
print(torch.tensor([0.0, 5.0, 5.0]).shape)
#%%
# 用于展开的高斯函数的图像
import numpy as np
import matplotlib.pyplot as plt

# 单个
c = 0       # 中心
sigma = 0.1   # 宽度,越小函数的宽度越小,越尖锐
# 生成x值范围
x = np.linspace(-5, 5, 400)
# 计算高斯RBF
rbf = np.exp(- (x - c)**2 / (2 * sigma**2))
# 绘图
plt.figure(figsize=(8, 4))
plt.plot(x, rbf, label=f'Gaussian RBF\n$c$={c}, $σ$={sigma}')
plt.title('Gaussian RBF')
plt.xlabel('x')
plt.ylabel('RBF(x)')
plt.show()

# 多个
# 实现中所有RBF的宽度σ都等于RBF中心之间的距离Δx。
# 参数定义
start = -5.0     # RBF中心的起始值
stop = 5.0       # RBF中心的终止值
num_gaussians = 10  # 高斯RBF的数量
sigma = 0.5         # 宽度, 越小函数的宽度越小, 越尖锐
# 生成RBF中心
centers = np.linspace(start, stop, num_gaussians)
# 生成x值范围
x = np.linspace(start, stop, 400)
# 计算多个高斯RBF
rbfs = [np.exp(- (x - c)**2 / (2 * sigma**2)) for c in centers]
# 绘图
plt.figure(figsize=(12, 6))
for i, rbf in enumerate(rbfs):
    plt.plot(x, rbf, label=f'RBF {i+1}: c={centers[i]:.2f}')
plt.title('Multiple Gaussian Radial Basis Functions')
plt.xlabel('x')
plt.ylabel('RBF(x)')
plt.legend(loc='upper right', bbox_to_anchor=(1.15, 1.05), ncol=1, fontsize='small')
plt.grid(True)
plt.tight_layout()
plt.show()
#%%
# 余弦截断函数
import torch
import matplotlib.pyplot as plt
import numpy as np
# 设置截断距离
cutoff = 10.0
PI = np.pi
# 生成距离范围
r = torch.linspace(0, cutoff, steps=100)
# 余弦截断函数
C_cosine = 0.5 * (torch.cos(r * PI / cutoff) + 1.0)
# 高斯截断函数 C(r) = exp(- (r/R)^2 )
C_gaussian = torch.exp(- (r / cutoff) ** 2)
# 转换为numpy数组
r_np = r.numpy()
C_cosine_np = C_cosine.numpy()
C_gaussian_np = C_gaussian.numpy()
# 绘制图像
plt.figure(figsize=(8, 5))
plt.plot(r_np, C_cosine_np, label='cos', color='blue')
plt.plot(r_np, C_gaussian_np, label='gau', color='red', linestyle='--')
# 添加标题和标签
plt.title('cos vs gau', fontsize=14)
plt.xlabel('r', fontsize=12)
plt.ylabel('C(r)', fontsize=12)
# 添加网格和图例
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend(fontsize=12)
# 显示图像
plt.tight_layout()
plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值