这段代码定义了图神经网络中的两类编码器：数据集编码器和位置编码器。代码通过 `DataSetEncoders` 和 `PosEncoder` 两个枚举类，为节点（及边）特征提供不同的编码方案。数据集编码器支持基于分子特征的原子嵌入和键嵌入（`AtomEncoder` / `BondEncoder`）。位置编码器则提供拉普拉斯位置编码（LAP）和随机游走结构编码（RWSE），用于增强模型性能。
`EncoderLinear` 类继承自 `torch.nn.Linear`。其 `forward` 方法额外接受一个可选参数 `pestat`，但并不使用它，只对输入 `x` 做线性变换。`DataSetEncoders` 和 `PosEncoder` 枚举类则用于根据不同的数据集或任务选择合适的编码器。
from helpers.encoders import DataSetEncoders, PosEncoder
from enum import Enum, auto
from torch.nn import Linear
from torch import Tensor
from torch_geometric.data import Data
from lrgb.encoders.laplace import LapPENodeEncoder, LAP_DIM_PE
from lrgb.encoders.kernel import RWSENodeEncoder, KER_DIM_PE
from lrgb.encoders.mol_encoder import AtomEncoder, BondEncoder
class EncoderLinear(Linear):
    """Plain linear projection; ``DataSetEncoders.NONE.node_encoder`` returns one.

    ``forward`` takes an optional second argument ``pestat`` and discards it,
    so only ``x`` passes through the inherited affine transform.
    """

    def forward(self, x: Tensor, pestat=None) -> Tensor:
        """Return the inherited ``Linear`` output for ``x``.

        ``pestat`` is accepted for call compatibility but never affects the
        result.
        """
        projected = super().forward(x)
        return projected
class DataSetEncoders(Enum):
    """Dataset-specific encoders for raw node and edge features.

    Members:
        NONE: node features go through a plain linear projection
            (``EncoderLinear``); no edge encoder is built.
        MOL: nodes use ``AtomEncoder`` and edges use ``BondEncoder``
            (no edge encoder when the model reports ``is_gcn()``).
    """

    NONE = auto()
    MOL = auto()

    @staticmethod
    def from_string(s: str) -> "DataSetEncoders":
        """Look up a member by its exact (case-sensitive) name, e.g. ``"MOL"``.

        Raises:
            ValueError: if ``s`` is not a member name. The message names the
                offending value and lists the accepted names.
        """
        try:
            return DataSetEncoders[s]
        except KeyError:
            valid = ", ".join(member.name for member in DataSetEncoders)
            # `from None` hides the internal KeyError; the ValueError below
            # already carries all the useful information.
            raise ValueError(
                f"Unknown DataSetEncoders {s!r}; expected one of: {valid}"
            ) from None

    def node_encoder(self, in_dim: int, emb_dim: int):
        """Build the node-feature encoder for this dataset type.

        Args:
            in_dim: input node-feature dimension; only used by ``NONE``.
            emb_dim: output embedding dimension.

        Returns:
            An ``EncoderLinear`` for ``NONE`` or an ``AtomEncoder`` for ``MOL``.

        Raises:
            ValueError: for a member without a branch here (defensive).
        """
        if self is DataSetEncoders.NONE:
            return EncoderLinear(in_features=in_dim, out_features=emb_dim)
        elif self is DataSetEncoders.MOL:
            return AtomEncoder(emb_dim)
        else:
            # Only reachable if a new member is added without a branch here.
            raise ValueError(f'DataSetEncoders {self.name} not supported')

    def edge_encoder(self, emb_dim: int, model_type):
        """Build the edge-feature encoder, or return ``None`` when none is used.

        Args:
            emb_dim: output embedding dimension.
            model_type: object exposing ``is_gcn()``; only consulted for
                ``MOL``. NOTE(review): presumably the project's model-type
                enum — confirm against callers.

        Returns:
            ``None`` for ``NONE``, and for ``MOL`` when
            ``model_type.is_gcn()`` is true; otherwise a ``BondEncoder``.

        Raises:
            ValueError: for a member without a branch here (defensive).
        """
        if self is DataSetEncoders.NONE:
            return None
        elif self is DataSetEncoders.MOL:
            if model_type.is_gcn():
                return None
            else:
                return BondEncoder(emb_dim)
        else:
            # Only reachable if a new member is added without a branch here.
            raise ValueError(f'DataSetEncoders {self.name} not supported')

    def use_encoders(self) -> bool:
        """Return ``True`` for every member except ``NONE``."""
        return self is not DataSetEncoders.NONE
class PosEncoder(Enum):
    """Optional positional/structural encodings added to node features.

    Members:
        NONE: no positional encoding.
        LAP: Laplacian PE via ``LapPENodeEncoder``; reads ``data.EigVals``
            and ``data.EigVecs``.
        RWSE: random-walk structural encoding via ``RWSENodeEncoder``;
            reads ``data.pestat_RWSE``.
    """

    NONE = auto()
    LAP = auto()
    RWSE = auto()

    @staticmethod
    def from_string(s: str) -> "PosEncoder":
        """Look up a member by its exact (case-sensitive) name, e.g. ``"LAP"``.

        Raises:
            ValueError: if ``s`` is not a member name. The message names the
                offending value and lists the accepted names.
        """
        try:
            return PosEncoder[s]
        except KeyError:
            valid = ", ".join(member.name for member in PosEncoder)
            # `from None` hides the internal KeyError; the ValueError below
            # already carries all the useful information.
            raise ValueError(
                f"Unknown PosEncoder {s!r}; expected one of: {valid}"
            ) from None

    def get(self, in_dim: int, emb_dim: int, expand_x: bool):
        """Build the node encoder for this positional encoding.

        ``in_dim``, ``emb_dim`` and ``expand_x`` are forwarded unchanged to
        the encoder constructor as ``dim_in``, ``dim_emb`` and ``expand_x``.

        Returns:
            ``None`` for ``NONE``, a ``LapPENodeEncoder`` for ``LAP`` or a
            ``RWSENodeEncoder`` for ``RWSE``.

        Raises:
            ValueError: for a member without a branch here (defensive).
        """
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return LapPENodeEncoder(dim_in=in_dim, dim_emb=emb_dim, expand_x=expand_x)
        elif self is PosEncoder.RWSE:
            return RWSENodeEncoder(dim_in=in_dim, dim_emb=emb_dim, expand_x=expand_x)
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')

    def DIM_PE(self):
        """Return the PE dimension constant for this encoding.

        ``None`` for ``NONE``, ``LAP_DIM_PE`` for ``LAP``, ``KER_DIM_PE`` for
        ``RWSE``. The method name is kept as-is for existing callers.

        Raises:
            ValueError: for a member without a branch here (defensive).
        """
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return LAP_DIM_PE
        elif self is PosEncoder.RWSE:
            return KER_DIM_PE
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')

    def get_pe(self, data: "Data", device):
        """Move this encoding's precomputed tensors from ``data`` to ``device``.

        Args:
            data: graph object carrying the precomputed attributes.
            device: target passed to ``Tensor.to`` (presumably a
                ``torch.device`` or device string — confirm with callers).

        Returns:
            ``None`` for ``NONE``; ``[EigVals, EigVecs]`` (a list) for
            ``LAP``; the ``pestat_RWSE`` tensor for ``RWSE``.

        Raises:
            ValueError: for a member without a branch here (defensive).
        """
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return [data.EigVals.to(device), data.EigVecs.to(device)]
        elif self is PosEncoder.RWSE:
            return data.pestat_RWSE.to(device)
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')
其中，`LapPENodeEncoder` 和 `LAP_DIM_PE` 通过 `from lrgb.encoders.laplace import LapPENodeEncoder, LAP_DIM_PE` 导入，其具体实现可以在 `lrgb.encoders.laplace` 模块中查看。