from math import pi as PI
import torch
# .F封装了很多函数, 没有可学习参数
import torch.nn.functional as F
# torch.nn封装了很多神经网络单元, 比如Sequential序列, lstm, conv, Embedding单元等(当然自带参数追踪等)
from torch.nn import Embedding, Sequential, Linear
from torch_scatter import scatter
from torch_geometric.nn import radius_graph
# Interaction block ("yellow" module in the SchNet diagram): updates node
# (atom) features from the messages carried on incoming edges.
# SchNet defaults: energy_and_force=False, cutoff=10.0, num_layers=6,
# hidden_channels=128, out_channels=1, num_filters=128, num_gaussians=50.
class update_v(torch.nn.Module):
    """Residual node update: ``v + lin2(act(lin1(sum of incoming messages)))``.

    Args:
        hidden_channels: size of the node feature vectors ``v``.
        num_filters: size of the edge messages ``e``.
    """

    def __init__(self, hidden_channels, num_filters):
        # Required: initialises nn.Module state so parameters and submodules
        # assigned below are registered and tracked.
        super().__init__()
        self.act = ShiftedSoftplus()
        # Linear = affine layer (with bias, no activation), mapping
        # in_features -> out_features on the last dimension of row vectors.
        self.lin1 = Linear(num_filters, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights and zero biases for both linear layers."""
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, v, e, edge_index):
        """Aggregate edge messages onto their target nodes and add them to ``v``.

        Args:
            v: node features, expected [num_nodes, hidden_channels].
            e: edge messages, expected [num_edges, num_filters].
            edge_index: [2, num_edges]; row 1 holds each edge's target node.

        Returns:
            Updated node features, same shape as ``v``.
        """
        _, i = edge_index
        # Sum messages per target node. dim_size pins the result to one row per
        # node: without it, trailing nodes with no incoming edge (e.g. an atom
        # with no neighbour inside the cutoff) would shrink the output and make
        # ``v + out`` fail with a shape mismatch.
        out = scatter(e, i, dim=0, dim_size=v.size(0))  # [num_nodes, num_filters]
        out = self.lin1(out)  # [num_nodes, hidden_channels]
        out = self.act(out)  # activation does not change the shape
        out = self.lin2(out)  # [num_nodes, hidden_channels]
        return v + out  # residual connection
#%%
# Filter generator ("orange" module in the SchNet diagram).
# A neighbour's influence on a centre atom is a learnable function of their
# distance (self.mlp(dist_emb)), damped by a cosine cutoff C(r) that falls
# smoothly from 1 at r = 0 to 0 at r = cutoff.
# Shapes (per the original notes): v [num_nodes, hidden_channels],
# dist_emb [num_edges, num_gaussians] -> e [num_edges, num_filters].
class update_e(torch.nn.Module):
    """Compute the message each source node sends along each edge.

    Args:
        hidden_channels: size of the node feature vectors.
        num_filters: size of the produced edge messages.
        num_gaussians: size of the Gaussian distance expansion.
        cutoff: distance at which the cosine envelope reaches zero.
    """

    def __init__(self, hidden_channels, num_filters, num_gaussians, cutoff):
        super().__init__()
        self.cutoff = cutoff
        self.lin = Linear(hidden_channels, num_filters, bias=False)
        self.mlp = Sequential(Linear(num_gaussians, num_filters), ShiftedSoftplus(), Linear(num_filters, num_filters))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights; zero biases on both MLP linear layers."""
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        # Zero the second layer's bias (mlp[2]), not mlp[0] a second time.
        self.mlp[2].bias.data.fill_(0)

    def forward(self, v, dist, dist_emb, edge_index):
        """Return per-edge messages ``lin(v)[source] * filter(dist)``.

        Args:
            v: node features.
            dist: per-edge distances (assumed 1-D, one value per edge).
            dist_emb: Gaussian expansion of ``dist``.
            edge_index: [2, num_edges]; row 0 holds each edge's source node.
        """
        j, _ = edge_index
        # Cosine envelope: 1 at distance 0, 0 at distance == cutoff.
        C = 0.5 * (torch.cos(dist * PI / self.cutoff) + 1.0)
        # Continuous filter per edge; C.view(-1, 1) broadcasts the scalar
        # envelope across all filters. (.view reshapes without copying and
        # requires contiguous memory; .reshape would copy if needed.)
        W = self.mlp(dist_emb) * C.view(-1, 1)
        v = self.lin(v)  # [num_nodes, num_filters]
        e = v[j] * W  # gather source-node features, modulate by the filter
        return e
#%%
# Output head: atom-wise MLP followed by pooling over each graph.
class update_u(torch.nn.Module):
    """Map node features to one output vector per graph.

    Every node is passed independently through Linear -> ShiftedSoftplus ->
    Linear (atom-wise: no information is exchanged between nodes here). The
    per-node outputs are then scattered by the ``batch`` assignment vector,
    which sums them for each graph (torch_scatter's default reduction).

    Args:
        hidden_channels: size of incoming node features.
        out_channels: size of the per-graph output.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        bottleneck = hidden_channels // 2
        self.lin1 = Linear(hidden_channels, bottleneck)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(bottleneck, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights and zero biases, lin1 first then lin2."""
        for layer in (self.lin1, self.lin2):
            torch.nn.init.xavier_uniform_(layer.weight)
            layer.bias.data.fill_(0)

    def forward(self, v, batch):
        """Return per-graph outputs; row g aggregates the nodes with batch == g."""
        per_node = self.lin2(self.act(self.lin1(v)))  # [num_nodes, out_channels]
        return scatter(per_node, batch, dim=0)
#%%
class SchNet(torch.nn.Module):
    r"""SchNet: continuous-filter convolutional network on 3D atomic graphs.

    Args:
        energy_and_force (bool, optional): If set to :obj:`True`, ``forward``
            enables gradients on the atomic positions so the caller can obtain
            forces as the negative derivative of the predicted energy with
            respect to ``pos``. ``forward`` itself only returns the energy.
            (default: :obj:`False`)
        cutoff (float, optional): Cutoff distance for interatomic
            interactions. (default: :obj:`10.0`)
        num_layers (int, optional): The number of interaction layers.
            (default: :obj:`6`)
        hidden_channels (int, optional): Hidden embedding size.
            (default: :obj:`128`)
        out_channels (int, optional): Output embedding size.
            (default: :obj:`1`)
        num_filters (int, optional): The number of filters to use.
            (default: :obj:`128`)
        num_gaussians (int, optional): The number of gaussians :math:`\mu`.
            (default: :obj:`50`)
    """

    def __init__(self, energy_and_force=False, cutoff=10.0, num_layers=6, hidden_channels=128, out_channels=1, num_filters=128, num_gaussians=50):
        super(SchNet, self).__init__()
        self.energy_and_force = energy_and_force
        self.cutoff = cutoff
        self.num_layers = num_layers
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_filters = num_filters
        self.num_gaussians = num_gaussians
        # v: vertex (node) features. Atom-type embedding table with 100
        # rows, so atomic numbers z must lie in [0, 99].
        self.init_v = Embedding(100, hidden_channels)
        # Gaussian expansion with centres spread over [0, cutoff].
        self.dist_emb = emb(0.0, cutoff, num_gaussians)
        self.update_vs = torch.nn.ModuleList([update_v(hidden_channels, num_filters) for _ in range(num_layers)])
        self.update_es = torch.nn.ModuleList([update_e(hidden_channels, num_filters, num_gaussians, cutoff) for _ in range(num_layers)])
        self.update_u = update_u(hidden_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding and every sub-module."""
        self.init_v.reset_parameters()
        # Distinct loop names so the module-level classes update_e/update_v
        # are not shadowed inside this method.
        for edge_layer in self.update_es:
            edge_layer.reset_parameters()
        for node_layer in self.update_vs:
            node_layer.reset_parameters()
        self.update_u.reset_parameters()

    def forward(self, batch):
        """Predict one output per graph.

        Args:
            batch: a mini-batch object exposing ``z`` (atomic numbers),
                ``pos`` (coordinates) and ``batch`` (node -> graph index).

        Returns:
            Tensor of per-graph outputs from ``update_u``.
        """
        # batch_idx keeps the node->graph vector separate from the argument.
        z, pos, batch_idx = batch.z, batch.pos, batch.batch
        if self.energy_and_force:
            # Lets the caller differentiate the output w.r.t. positions.
            pos.requires_grad_()
        # NOTE(review): radius_graph also caps neighbours per node through its
        # max_num_neighbors argument (library default used here) - confirm the
        # cap is large enough for dense structures.
        edge_index = radius_graph(pos, r=self.cutoff, batch=batch_idx)
        row, col = edge_index
        # Euclidean length of each edge vector (norm over the last dim).
        dist = (pos[row] - pos[col]).norm(dim=-1)
        dist_emb = self.dist_emb(dist)  # Gaussian expansion of distances
        v = self.init_v(z)  # initial node features from atom types
        for edge_layer, node_layer in zip(self.update_es, self.update_vs):
            e = edge_layer(v, dist, dist_emb, edge_index)
            v = node_layer(v, e, edge_index)
        # Atom-wise output MLP followed by per-graph pooling.
        u = self.update_u(v, batch_idx)
        return u
# Activation used throughout the model.
class ShiftedSoftplus(torch.nn.Module):
    """Softplus shifted down by log(2): ``ssp(x) = softplus(x) - log(2)``.

    The shift makes ``ssp(0) == 0``. ``shift`` is kept as a Python float
    rather than a tensor, so it carries no autograd state and is not tied to
    a device.
    """

    def __init__(self):
        super().__init__()
        # float32 log(2) taken from torch and converted to a Python scalar.
        two = torch.tensor(2.0)
        self.shift = two.log().item()

    def forward(self, x):
        activated = F.softplus(x)
        return activated - self.shift
#%% md
1. 为什么这里要加.item()?
偏移量不需要计算梯度. 尽管用 torch.tensor 创建的张量默认 requires_grad=False, 转为 Python 标量可以确保它不会被纳入计算图, 以防万一造成资源浪费.
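torch.tensor 默认创建在 CPU 上, 而 self.shift 只是普通属性(不是 buffer 或 parameter), model.to(device) 不会移动它. 转为 Python 标量后, 它在后续计算中会自动与 x 的设备对齐(运算时被隐式当作一个不需要梯度的常数).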
PyTorch 中, 叶节点(用户直接创建且 requires_grad=True 的张量)在反向传播后会保留梯度 .grad; 非叶节点(运算生成的中间变量)的梯度会被计算但不会保留, 其 .grad 为 None, 必须先调用 .retain_grad() 才能保留.
2. 为什么激活函数需要写成类,继承nn.Module?
自定义的 PyTorch 模块都需要继承 nn.Module, 它提供了 .parameters() 等一系列属性和方法. 不继承 nn.Module 的类也无法放进 nn.Sequential 中使用, 因为 Sequential 通过 Module 机制管理子模块.
#%%
# gaussian_emb
import torch
class emb(torch.nn.Module):
    """Expand scalar distances onto a fixed grid of Gaussian basis functions.

    Centres are ``num_gaussians`` evenly spaced points in [start, stop]. Each
    Gaussian is ``exp(coeff * (d - centre)**2)`` with
    ``coeff = -0.5 / spacing**2``, i.e. the width sigma equals the spacing
    between adjacent centres. ``num_gaussians`` must be at least 2.

    Args:
        start: first centre.
        stop: last centre.
        num_gaussians: number of centres / output features.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super().__init__()
        centres = torch.linspace(start, stop, num_gaussians)
        spacing = (centres[1] - centres[0]).item()
        self.coeff = -0.5 / spacing ** 2
        # Registered as a buffer: saved in state_dict and moved by .to(device).
        self.register_buffer('offset', centres)

    def forward(self, dist):
        """Return a [len(flattened dist), num_gaussians] tensor of activations."""
        column = dist.view(-1, 1)
        row = self.offset.view(1, -1)
        delta = column - row  # broadcasts to [num_dist, num_gaussians]
        return (self.coeff * delta.pow(2)).exp()
# Quick shape check: expand 3 distances onto 10 Gaussians spanning [0, 5].
# The instance gets its own name: rebinding `emb` would replace the class, and
# SchNet.__init__ (which calls emb(0.0, cutoff, num_gaussians)) would then
# invoke this instance's forward instead of building a new module.
gaussian_demo = emb(0.0, 5.0, 10)
print(gaussian_demo(torch.tensor([0.0, 5.0, 5.0])).shape)
# prints torch.Size([3, 10])
print(torch.tensor([0.0, 5.0, 5.0]).shape)
# prints torch.Size([3])
#%%
# Plots of the Gaussian radial basis functions (RBFs) used for the distance
# expansion.
import numpy as np
import matplotlib.pyplot as plt

# --- A single Gaussian RBF ---
c = 0  # centre
sigma = 0.1  # width: smaller sigma gives a narrower, sharper peak
x = np.linspace(-5, 5, 400)
rbf = np.exp(-(x - c) ** 2 / (2 * sigma ** 2))
plt.figure(figsize=(8, 4))
plt.plot(x, rbf, label=f'Gaussian RBF\n$c$={c}, $σ$={sigma}')
plt.title('Gaussian RBF')
plt.xlabel('x')
plt.ylabel('RBF(x)')
plt.legend()  # the curve is labelled; without this the label never shows
plt.show()

# --- Several Gaussian RBFs ---
# The `emb` module sets every width sigma equal to the spacing between
# neighbouring centres (coeff = -0.5 / spacing**2); mirror that here.
# Override sigma to explore narrower/wider bases.
start = -5.0  # first RBF centre
stop = 5.0  # last RBF centre
num_gaussians = 10  # number of RBFs
centers = np.linspace(start, stop, num_gaussians)
sigma = centers[1] - centers[0]
x = np.linspace(start, stop, 400)
# One row per centre: rbfs[k] is the k-th Gaussian evaluated on x.
rbfs = np.exp(-(x[np.newaxis, :] - centers[:, np.newaxis]) ** 2 / (2 * sigma ** 2))
plt.figure(figsize=(12, 6))
for i, rbf in enumerate(rbfs):
    plt.plot(x, rbf, label=f'RBF {i+1}: c={centers[i]:.2f}')
plt.title('Multiple Gaussian Radial Basis Functions')
plt.xlabel('x')
plt.ylabel('RBF(x)')
plt.legend(loc='upper right', bbox_to_anchor=(1.15, 1.05), ncol=1, fontsize='small')
plt.grid(True)
plt.tight_layout()
plt.show()
#%%
# Cosine cutoff function (same formula as update_e.forward) compared with a
# simple Gaussian decay.
import torch
import matplotlib.pyplot as plt
import numpy as np

cutoff = 10.0  # cutoff distance
# np.pi is used inline rather than rebinding the module-level name PI, which
# update_e.forward reads (imported from math at the top of the file).
r = torch.linspace(0, cutoff, steps=100)
# Cosine cutoff: C(r) = 0.5 * (cos(pi * r / cutoff) + 1); 1 at r = 0 and
# exactly 0 at r = cutoff.
C_cosine = 0.5 * (torch.cos(r * np.pi / cutoff) + 1.0)
# Gaussian decay: C(r) = exp(-(r / cutoff)^2). Unlike the cosine cutoff it
# does not reach 0 at r = cutoff (exp(-1) ~= 0.37).
C_gaussian = torch.exp(-(r / cutoff) ** 2)
# Convert to NumPy arrays for matplotlib.
r_np = r.numpy()
C_cosine_np = C_cosine.numpy()
C_gaussian_np = C_gaussian.numpy()
plt.figure(figsize=(8, 5))
plt.plot(r_np, C_cosine_np, label='cos', color='blue')
plt.plot(r_np, C_gaussian_np, label='gau', color='red', linestyle='--')
plt.title('cos vs gau', fontsize=14)
plt.xlabel('r', fontsize=12)
plt.ylabel('C(r)', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend(fontsize=12)
plt.tight_layout()
plt.show()