DataSetEncoders, PosEncoder(helpers文件中的encoders.py)

这段代码定义了图神经网络中的两类编码器:数据集编码器位置编码器。通过 DataSetEncodersPosEncoder 两个枚举类,代码提供了不同类型的节点和边的编码方案。编码器支持基于分子特征的原子、键嵌入,或通过拉普拉斯(LAP)和随机游走(RWSE)位置编码增强模型的性能。

EncoderLinear 类扩展了 torch.nn.Linear,用于处理可选的额外输入,而 DataSetEncodersPosEncoder 枚举类用于根据不同的数据集或任务选择适当的编码器。

from helpers.encoders import DataSetEncoders, PosEncoder

from enum import Enum, auto
from torch.nn import Linear
from torch import Tensor
from torch_geometric.data import Data

from lrgb.encoders.laplace import LapPENodeEncoder, LAP_DIM_PE
from lrgb.encoders.kernel import RWSENodeEncoder, KER_DIM_PE
from lrgb.encoders.mol_encoder import AtomEncoder, BondEncoder


class EncoderLinear(Linear):
    def forward(self, x: Tensor, pestat=None) -> Tensor:
        return super().forward(x)


class DataSetEncoders(Enum):
    """
        an object for the different encoders
    """
    NONE = auto()
    MOL = auto()

    @staticmethod
    def from_string(s: str):
        try:
            return DataSetEncoders[s]
        except KeyError:
            raise ValueError()

    def node_encoder(self, in_dim: int, emb_dim: int):
        if self is DataSetEncoders.NONE:
            return EncoderLinear(in_features=in_dim, out_features=emb_dim)
        elif self is DataSetEncoders.MOL:
            return AtomEncoder(emb_dim)
        else:
            raise ValueError(f'DataSetEncoders {self.name} not supported')

    def edge_encoder(self, emb_dim: int, model_type):
        if self is DataSetEncoders.NONE:
            return None
        elif self is DataSetEncoders.MOL:
            if model_type.is_gcn():
                return None
            else:
                return BondEncoder(emb_dim)
        else:
            raise ValueError(f'DataSetEncoders {self.name} not supported')

    def use_encoders(self) -> bool:
        return self is not DataSetEncoders.NONE


class PosEncoder(Enum):
    """
        an object for the different encoders
    """
    NONE = auto()
    LAP = auto()
    RWSE = auto()

    @staticmethod
    def from_string(s: str):
        try:
            return PosEncoder[s]
        except KeyError:
            raise ValueError()

    def get(self, in_dim: int, emb_dim: int, expand_x: bool):
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return LapPENodeEncoder(dim_in=in_dim, dim_emb=emb_dim, expand_x=expand_x)
        elif self is PosEncoder.RWSE:
            return RWSENodeEncoder(dim_in=in_dim, dim_emb=emb_dim, expand_x=expand_x)
        else:
            raise ValueError(f'DataSetEncoders {self.name} not supported')

    def DIM_PE(self):
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return LAP_DIM_PE
        elif self is PosEncoder.RWSE:
            return KER_DIM_PE
        else:
            raise ValueError(f'DataSetEncoders {self.name} not supported')

    def get_pe(self, data: Data, device):
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return [data.EigVals.to(device), data.EigVecs.to(device)]
        elif self is PosEncoder.RWSE:
            return data.pestat_RWSE.to(device)
        else:
            raise ValueError(f'DataSetEncoders {self.name} not supported')

其中包括

from lrgb.encoders.laplace import LapPENodeEncoder, LAP_DIM_PE可以查看

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值