Training on the ogbn-products Graph Dataset in Python: A Hands-on Guide to Large-Scale Graph Training

I. Task Overview

Train a node-classification model on the large-scale ogbn-products dataset.

About ogbn-products:

ogbn-products is one of the Open Graph Benchmark (OGB) datasets: an undirected, unweighted graph of the Amazon product co-purchasing network. Nodes represent products sold on Amazon, and an edge between two nodes indicates that the two products were purchased together. Node features are built from the words in the product descriptions and reduced to 100 dimensions with Principal Component Analysis (PCA).

The task is multi-class classification of products into 47 top-level categories, which serve as the target labels. The split is not random: products are ranked by sales, with the top 8% by sales used for training, the next 2% for validation, and the remainder for testing. This split is more challenging and closer to real-world application scenarios.

The graph structure of ogbn-products is stored by default as an edge list (edge_index). For graph neural network (GNN) training, the dataset can be loaded with the PygNodePropPredDataset class from the OGB library, which returns data in PyTorch Geometric (PyG) format, and models can be evaluated with the OGB Evaluator.

ogbn-products is a medium-sized dataset among the OGB Node Property Prediction datasets, with 2,449,029 nodes and 61,859,140 edges. It is a good testbed for studying the performance and generalization of GNN models on large-scale data.
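
As a quick sanity check, here is a minimal loading sketch (assuming the ogb package is installed and the data root is data/, matching the configs below) that shows the dataset object, the sales-rank split, and the official Evaluator:

from ogb.nodeproppred import Evaluator, PygNodePropPredDataset

# Downloads ogbn-products on the first run and returns it in PyTorch Geometric format.
dataset = PygNodePropPredDataset(name='ogbn-products', root='data/')
graph = dataset[0]                      # a single torch_geometric.data.Data object
split_idx = dataset.get_idx_split()     # {'train': ..., 'valid': ..., 'test': ...} node indices

print(graph.num_nodes)                              # 2449029
print(graph.x.shape)                                # [2449029, 100] PCA features
print({k: v.numel() for k, v in split_idx.items()}) # sales-rank split sizes

# The official metric for ogbn-products is accuracy.
evaluator = Evaluator(name='ogbn-products')
result = evaluator.eval({
    'y_true': graph.y[split_idx['valid']],
    'y_pred': graph.y[split_idx['valid']],  # perfect "predictions", just to show the API
})
print(result)                               # {'acc': 1.0}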

II. Hands-on Code

1. Model architecture

The model design has two branches, a GCN and an MLP; each is trained as a separate model, with the MLP serving as a baseline.

models/gcn.py (despite the name, it actually uses GraphSAGE SAGEConv layers)

import logging
import typing as tp

import pytorch_lightning as pl

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv
from torchmetrics import Accuracy, F1Score, Precision, Recall
from tqdm import tqdm

from configs.base import LRScheduler, Optimizer


class GraphConvNetwork(nn.Module):
    """Graph Convolutional Network."""

    def __init__(
        self,
        num_features: tp.List[int],
        num_classes: int,
        dropout: float
    ):
        """
        Initialize GraphConvNetwork.

        Parameters
        ----------
        num_features : tp.List[int]
            Per-layer feature dimensions: the input dimension followed by the hidden
            dimensions. One SAGEConv is created per consecutive pair, plus a final
            SAGEConv mapping to num_classes.
        num_classes : int
            Number of classes in data.
        dropout : float
            Dropout rate.
        """
        super().__init__()
        self.layers = nn.ModuleList()
        num_layers = len(num_features)
        for i in range(num_layers - 1):
            self.layers.append(
                SAGEConv(in_channels=num_features[i], out_channels=num_features[i + 1]),
            )
        self.layers.append(
            SAGEConv(in_channels=num_features[-1], out_channels=num_classes),
        )
        self.dropout = dropout

    def forward(self, x: torch.tensor, edge_index: torch.tensor) -> torch.tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x, edge_index)
            if i != len(self.layers) - 1:
                # No activation/dropout after the last layer (mirrors inference below);
                # dropout is only active in training mode.
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return torch.log_softmax(x, dim=-1)

    def inference(self, x_all: torch.tensor, subgraph_loader: NeighborSampler) -> torch.tensor:
        """Layer-wise full-graph inference on CPU, using a NeighborSampler with sizes=[-1]."""
        for i, conv in enumerate(self.layers):
            conv = conv.cpu()
            xs = []
            # Compute the current layer's representation for all nodes, batch by batch.
            for batch_size, n_id, adj in tqdm(subgraph_loader):
                edge_index, _, size = adj
                x = x_all[n_id]             # features of all sampled source nodes
                x_target = x[:size[1]]      # target nodes come first within n_id
                x = conv((x, x_target), edge_index)
                if i != len(self.layers) - 1:
                    x = F.relu(x)
                xs.append(x.cpu())
            x_all = torch.cat(xs, dim=0)    # becomes the input of the next layer
        return x_all


class GCNModule(pl.LightningModule):
    def __init__(
        self,
        model_params: dict,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ):
        """
        Initialize GCNModule.

        Parameters
        ----------
        model_params : dict
            Dictionary with GraphConvNetwork parameters.
        optimizer : Optimizer
            Optimizer.
        lr_scheduler : LRScheduler
            Learning rate scheduler.
        """
        super().__init__()
        self.save_hyperparameters(ignore=['criterion'])  # criterion is already saved during checkpointing
        self.learning_rate = optimizer.opt_params['lr']

        self.model = GraphConvNetwork(**model_params)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self.criterion = nn.CrossEntropyLoss()

        num_classes = model_params['num_classes']
        task = 'multiclass'
        self.accuracy = Accuracy(task=task, num_classes=num_classes)
        self.f1_score = F1Score(task=task, num_classes=num_classes)
        self.precision_score = Precision(task=task, num_classes=num_classes)
        self.recall_score = Recall(task=task, num_classes=num_classes)

        self.test_probs = None

    def forward(self, x: torch.tensor, edge_index: torch.tensor):
        return self.model(x, edge_index)

    def training_step(self, batch, batch_idx):
        targets = batch.y.squeeze(1)[batch.train_mask]
        probs = self.forward(batch.x, batch.edge_index)[batch.train_mask]
        batch_size = len(targets)

        loss = self.criterion(probs, targets)
        self.log('train_loss', loss, on_epoch=False, on_step=True, batch_size=batch_size)

        self.accuracy(probs, targets.long())
        self.log('train_acc', self.accuracy, on_epoch=False, on_step=True, batch_size=batch_size)

        self.f1_score(probs, targets.long())
        self.log('train_f1', self.f1_score, on_epoch=False, on_step=True, batch_size=batch_size)

        self.precision_score(probs, targets.long())
        self.log('train_precision', self.precision_score, on_epoch=False, on_step=True, batch_size=batch_size)

        self.recall_score(probs, targets.long())
        self.log('train_recall_score', self.recall_score, on_epoch=False, on_step=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        targets = batch.y.squeeze(1)[batch.valid_mask]
        probs = self.forward(batch.x, batch.edge_index)[batch.valid_mask]
        batch_size = len(targets)

        loss = self.criterion(probs, targets)
        self.log('val_loss', loss, on_epoch=True, on_step=True, batch_size=batch_size)

        self.accuracy(probs, targets.long())
        self.log('val_acc', self.accuracy, on_epoch=True, on_step=False, batch_size=batch_size)

        self.f1_score(probs, targets.long())
        self.log('val_f1', self.f1_score, on_epoch=True, on_step=False, batch_size=batch_size)

        self.precision_score(probs, targets.long())
        self.log('val_precision', self.precision_score, on_epoch=True, on_step=False, batch_size=batch_size)

        self.recall_score(probs, targets.long())
        self.log('val_recall_score', self.recall_score, on_epoch=True, on_step=False, batch_size=batch_size)
        return loss

    def on_test_start(self) -> None:
        logging.info('Starting testing...')
        self.test_probs = self.model.inference(
            x_all=self.trainer.datamodule.data.x,
            subgraph_loader=self.trainer.test_dataloaders[0],
        )

    def test_step(self, batch, batch_idx):
        # Full-graph inference is done once in on_test_start; nothing to do per batch.
        return 0

    def on_test_end(self) -> None:
        targets = self.trainer.datamodule.data.y
        y_pred = self.test_probs.argmax(dim=-1, keepdim=True)

        acc = self.accuracy(y_pred, targets.long())
        f1_score = self.f1_score(y_pred, targets.long())
        precision_score = self.precision_score(y_pred, targets.long())
        recall_score = self.recall_score(y_pred, targets.long())

        self.logger.experiment.log(
            {
                'test_acc': acc,
                'test_f1': f1_score,
                'test_precision': precision_score,
                'test_recall_score': recall_score,
            }
        )

    def configure_optimizers(self):
        optimizer = getattr(torch.optim, self.optimizer.name)(
            self.parameters(),
            **self.optimizer.opt_params,
        )

        optim_dict = {
            'optimizer': optimizer,
            'monitor': 'val_loss',
        }

        if self.lr_scheduler is not None:
            lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler.name)(
                optimizer,
                **self.lr_scheduler.lr_sched_params,
            )

            optim_dict.update({'lr_scheduler': lr_scheduler})
        return optim_dict
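
A quick sanity-check sketch of the network defined above (the toy tensors are made up for illustration and the import assumes the repo layout models/gcn.py; real batches come from the ClusterLoader shown later):

import torch
from models.gcn import GraphConvNetwork

# Same hyperparameters as configs/config_gcn.py below.
net = GraphConvNetwork(num_features=[100, 256, 256, 256], num_classes=47, dropout=0.15)

x = torch.randn(1000, 100)                      # 1000 nodes with 100-dim PCA features
edge_index = torch.randint(0, 1000, (2, 5000))  # 5000 random directed edges
out = net(x, edge_index)

print(out.shape)  # torch.Size([1000, 47]) -- per-node log-probabilities over 47 classes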

models/mlp.py — this model serves as the baseline for comparison

import typing as tp

import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import Accuracy, F1Score, Precision, Recall

from configs.base import LRScheduler, Optimizer


class MultiLayerPerceptron(nn.Module):
    """Base MLP model using here as a baseline."""

    def __init__(
        self,
        num_neurons: tp.List[int],
        num_classes: int,
        dropout: float = 0,
    ):
        """
        Initialize MultiLayerPerceptron.

        Parameters
        ----------
        num_neurons : tp.List[int]
            Number of neurons in each layer.
        num_classes : int
            Number of classes in data.
        dropout : float, optional
           Dropout rate, by default 0.
        """
        super().__init__()

        self.lin_layers = []
        num_layers = len(num_neurons)
        for i in range(num_layers - 1):
            self.lin_layers.extend(
                [
                    nn.Linear(num_neurons[i], num_neurons[i + 1]),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                ],
            )
        self.lin_layers.append(
            nn.Linear(num_neurons[-1], num_classes),
        )
        self.model = nn.Sequential(*self.lin_layers)

    def forward(self, x: torch.tensor) -> torch.tensor:
        x = self.model(x)
        return torch.log_softmax(x, dim=-1)


class MLPModule(pl.LightningModule):
    def __init__(
        self,
        model_params: dict,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler | None,
    ):
        """
        Initialize MLPModule.

        Parameters
        ----------
        model_params : dict
            Dictionary with MultiLayerPerceptron parameters.
        optimizer : Optimizer
            Optimizer.
        lr_scheduler : LRScheduler
            Learning rate scheduler.
        """
        super().__init__()
        self.save_hyperparameters(ignore=['criterion'])  # criterion is already saved during checkpointing
        self.learning_rate = optimizer.opt_params['lr']

        self.model = MultiLayerPerceptron(**model_params)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self.criterion = nn.CrossEntropyLoss()

        num_classes = model_params['num_classes']
        task = 'multiclass'
        self.accuracy = Accuracy(task=task, num_classes=num_classes)
        self.f1_score = F1Score(task=task, num_classes=num_classes)
        self.precision_score = Precision(task=task, num_classes=num_classes)
        self.recall_score = Recall(task=task, num_classes=num_classes)

    def forward(self, x: torch.tensor):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, targets = batch
        targets = targets.squeeze(1)
        probs = self.forward(x)

        loss = self.criterion(probs, targets)
        self.log('train_loss', loss, on_epoch=False, on_step=True)

        self.accuracy(probs, targets.long())
        self.log('train_acc', self.accuracy, on_epoch=False, on_step=True)

        self.f1_score(probs, targets.long())
        self.log('train_f1', self.f1_score, on_epoch=False, on_step=True)

        self.precision_score(probs, targets.long())
        self.log('train_precision', self.precision_score, on_epoch=False, on_step=True)

        self.recall_score(probs, targets.long())
        self.log('train_recall_score', self.recall_score, on_epoch=False, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, targets = batch
        targets = targets.squeeze(1)
        probs = self.forward(x)

        loss = self.criterion(probs, targets)
        self.log('val_loss', loss, on_epoch=True, on_step=True)

        self.accuracy(probs, targets.long())
        self.log('val_acc', self.accuracy, on_epoch=True, on_step=False)

        self.f1_score(probs, targets.long())
        self.log('val_f1', self.f1_score, on_epoch=True, on_step=False)

        self.precision_score(probs, targets.long())
        self.log('val_precision', self.precision_score, on_epoch=True, on_step=False)

        self.recall_score(probs, targets.long())
        self.log('val_recall_score', self.recall_score, on_epoch=True, on_step=False)
        return loss

    def test_step(self, batch, batch_idx):
        x, targets = batch
        targets = targets.squeeze(1)
        probs = self.forward(x)

        loss = self.criterion(probs, targets)
        self.log('test_loss', loss)

        self.accuracy(probs, targets.long())
        self.log('test_acc', self.accuracy)

        self.f1_score(probs, targets.long())
        self.log('test_f1', self.f1_score)

        self.precision_score(probs, targets.long())
        self.log('test_precision', self.precision_score)

        self.recall_score(probs, targets.long())
        self.log('test_recall_score', self.recall_score)
        return loss

    def configure_optimizers(self):
        optimizer = getattr(torch.optim, self.optimizer.name)(
            self.parameters(),
            **self.optimizer.opt_params,
        )

        optim_dict = {
            'optimizer': optimizer,
            'monitor': 'val_loss',
        }

        if self.lr_scheduler is not None:
            lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler.name)(
                optimizer,
                **self.lr_scheduler.lr_sched_params,
            )

            optim_dict.update({'lr_scheduler': lr_scheduler})
        return optim_dict
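
For contrast with the GCN, the MLP baseline consumes node features only and ignores the graph structure entirely. A quick sketch (the toy tensor is made up for illustration; the import assumes the repo layout models/mlp.py):

import torch
from models.mlp import MultiLayerPerceptron

# Same hyperparameters as configs/config_mlp.py below; no edge_index is needed.
mlp = MultiLayerPerceptron(num_neurons=[100, 256, 256, 256], num_classes=47, dropout=0.15)

x = torch.randn(512, 100)   # a batch of 512 node feature vectors
out = mlp(x)
print(out.shape)            # torch.Size([512, 47]) -- log-probabilities per node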

2. Data handling

The data code is split across two files: data/embeddings_dataset.py (used by the MLP baseline) and data/graph_dataset.py (used by the GCN). First, data/embeddings_dataset.py:

import logging
import typing as tp

import pytorch_lightning as pl
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch.utils.data import DataLoader, Dataset


class EmbeddingsDataset(Dataset):
    def __init__(
        self,
        root: str,
        mode: tp.Literal['train', 'valid', 'test'],
    ):
        self.root = root
        embeddings, targets = self._get_embeddings_and_targets(mode)
        self.embeddings = embeddings
        self.targets = targets

    def _get_embeddings_and_targets(self, mode: str) -> tp.Tuple[torch.tensor, torch.tensor]:
        dataset = PygNodePropPredDataset(name='ogbn-products', root=self.root)
        graph = dataset[0]
        indices = dataset.get_idx_split()[mode]
        embeddings = graph.x[indices]
        targets = graph.y[indices]
        return embeddings, targets

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, index):
        embedding = self.embeddings[index]
        target = self.targets[index]
        return embedding, target


class EmbeddingsDataModule(pl.LightningDataModule):
    def __init__(
        self,
        root: str,
        batch_size: int,
        num_workers: int,
    ):
        """
        Create Data Module for EmbeddingsDataset.

        Parameters
        ----------
        root : str
            Path to root dir with dataset.
        batch_size : int
            Batch size for dataloaders.
        num_workers : int
            Number of workers in dataloaders.
        """
        super().__init__()
        self.save_hyperparameters()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: tp.Optional[str] = None):
        if stage == 'fit' or stage is None:
            self.train_dataset = EmbeddingsDataset(
                root=self.root,
                mode='train',
            )
            num_train_files = len(self.train_dataset)
            logging.info(f'Mode: train, number of nodes: {num_train_files}')

            self.val_dataset = EmbeddingsDataset(
                root=self.root,
                mode='valid',
            )
            num_val_files = len(self.val_dataset)
            logging.info(f'Mode: val, number of nodes: {num_val_files}')

        elif stage == 'test':
            self.test_dataset = EmbeddingsDataset(
                root=self.root,
                mode='test',
            )
            num_test_files = len(self.test_dataset)
            logging.info(f'Mode: test, number of nodes: {num_test_files}')

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, drop_last=False,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, drop_last=False,
        )

data/graph_dataset.py

For the train and valid stages, sampling is done with ClusterLoader.

For the test stage, sampling is done with NeighborSampler.

ClusterLoader and NeighborSampler are both sampling strategies for learning node embeddings on large-scale graphs, but they differ in how they work and where they apply.

  1. ClusterLoader

    • ClusterLoader is designed for large-scale graphs: it partitions the graph into clusters and then trains on each cluster independently. This reduces memory consumption because the entire graph never has to be held in memory at once.
    • ClusterLoader is typically used with the Cluster-GCN method, which trains the model by running forward and backward passes within clusters.
    • Each cluster can be viewed as a subgraph; the model trains on each subgraph independently, and the results are combined to obtain the final node embeddings.
  2. NeighborSampler

    • NeighborSampler is the sampling method used by GraphSAGE: it recursively samples node neighborhoods to build a directed multi-hop subgraph.
    • Its core idea: given a mini-batch of nodes, the number of convolution layers L, and the number of neighbors to sample at each layer (sizes), it samples neighbors layer by layer from layer 1 to layer L and returns one bipartite subgraph per layer.
    • At each layer, sampling is performed over all nodes involved in the previous layer's sample, which guarantees that after L layers the result contains every node encountered during the L layers of convolution.
    • NeighborSampler is suited to node-level mini-batch training: it makes mini-batch GNN training feasible on large graphs where full-batch training is impossible.

In short, ClusterLoader targets very large graphs by partitioning them into clusters to cut memory usage, while NeighborSampler supports node-level mini-batch training by recursively sampling neighborhoods to build subgraphs. Both are effective for learning node embeddings on large-scale graphs; the right choice depends on the graph's characteristics and the available compute.
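
To make the two loaders' outputs concrete, here is a minimal sketch on a tiny synthetic graph (the toy sizes are arbitrary; the datamodule below runs on the full ogbn-products graph). Note that ClusterData requires a torch-sparse build with METIS support.

import torch
from torch_geometric.data import Data
from torch_geometric.loader import ClusterData, ClusterLoader, NeighborSampler

# A tiny random graph, purely to illustrate what each loader yields.
data = Data(
    x=torch.randn(100, 8),
    y=torch.randint(0, 4, (100, 1)),
    edge_index=torch.randint(0, 100, (2, 400)),
)

# ClusterLoader: partition the graph (METIS), then iterate over merged clusters as Data objects.
cluster_data = ClusterData(data, num_parts=10)
train_loader = ClusterLoader(cluster_data, batch_size=2, shuffle=True)
for sub_data in train_loader:                   # each batch is a subgraph of ~2 merged clusters
    print(sub_data.num_nodes, sub_data.edge_index.shape)

# NeighborSampler: layer-wise neighbor sampling; sizes=[-1] means one full-neighborhood hop.
subgraph_loader = NeighborSampler(
    data.edge_index, sizes=[-1], num_nodes=data.num_nodes, batch_size=32, shuffle=False,
)
for batch_size, n_id, adj in subgraph_loader:   # adj is a single bipartite block here
    edge_index, _, size = adj                   # size = (num_source_nodes, num_target_nodes)
    print(batch_size, n_id.shape, size)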

import logging
import typing as tp

import torch
import pytorch_lightning as pl
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.loader import ClusterLoader, ClusterData, NeighborSampler


class OGBGProductsDatamodule(pl.LightningDataModule):
    """DataModule for mini-batch GCN training using the Cluster-GCN algorithm."""

    def __init__(
        self,
        root: str,
        batch_size: int,
        num_workers: int,
        save_dir: tp.Optional[str] = 'data/ogbn_products/processed',
        num_partitions: int = 15000,
    ):
        """
        Create Data Module for OGBG Product task.

        Parameters
        ----------
        root : str
            Path to the root directory.
        batch_size : int
            Batch size for dataloaders.
        num_workers : int
            Number of workers in dataloaders.
        save_dir : tp.Optional[str]
            Directory where already partitioned dataset is stored.
        num_partitions : int
            Number of partitions.
        """
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.root = root
        self.num_partitions = num_partitions
        self.save_dir = save_dir

        self.data = None
        self.split_idx = None
        self.cluster_data = None

    def prepare_data(self):
        dataset = PygNodePropPredDataset(name='ogbn-products', root=self.root)
        self.split_idx = dataset.get_idx_split()
        self.data = dataset[0]

        # Convert split indices to boolean masks and add them to `data`
        for key, idx in self.split_idx.items():
            mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
            mask[idx] = True
            self.data[f'{key}_mask'] = mask

        self.cluster_data = ClusterData(
            self.data,
            num_parts=self.num_partitions,
            recursive=False,
            save_dir=self.save_dir,
        )

    def setup(self, stage: tp.Optional[str] = None):
        if stage == 'fit' or stage is None:
            self.train_split = self.split_idx['train']
            num_train_files = len(self.train_split)
            logging.info(f'Mode: train, number of nodes: {num_train_files}')

            self.valid_split = self.split_idx['valid']
            num_valid_files = len(self.valid_split)
            logging.info(f'Mode: valid, number of nodes: {num_valid_files}')

        elif stage == 'test':
            self.test_split = self.split_idx['test']
            num_test_files = len(self.test_split)
            logging.info(f'Mode: test, number of nodes: {num_test_files}')

    def train_dataloader(self):
        return ClusterLoader(
            self.cluster_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return ClusterLoader(
            self.cluster_data,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return NeighborSampler(
            self.data.edge_index,
            sizes=[-1],
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

3. Configuration files

configs/base.py

import typing as tp
from dataclasses import dataclass

import pytorch_lightning as pl


@dataclass
class Project:
    project_name: str
    run_name: str
    notes: str
    tags: str
    log_freq: int = 100


@dataclass
class Common:
    seed: int = 8


@dataclass
class Dataset:
    module: pl.LightningDataModule
    root: str  # path to root directory with the dataset
    batch_size: int
    num_workers: int


@dataclass
class Model:
    module: pl.LightningModule
    model_params: dict


@dataclass
class Callbacks:
    model_checkpoint: pl.callbacks.ModelCheckpoint
    early_stopping: tp.Optional[pl.callbacks.EarlyStopping] = None
    lr_monitor: tp.Optional[pl.callbacks.LearningRateMonitor] = None
    model_summary: tp.Optional[tp.Union[pl.callbacks.ModelSummary, pl.callbacks.RichModelSummary]] = None
    timer: tp.Optional[pl.callbacks.Timer] = None


@dataclass
class Optimizer:
    name: str
    opt_params: dict


@dataclass
class LRScheduler:
    name: str
    lr_sched_params: dict


@dataclass
class Train:
    trainer_params: dict
    callbacks: Callbacks
    optimizer: Optimizer
    lr_scheduler: LRScheduler
    ckpt_path: tp.Optional[str] = None


@dataclass
class Config:
    project: Project
    common: Common
    dataset: Dataset
    model: Model
    train: Train

configs/config_gcn.py 

import os
import uuid

import pytorch_lightning as pl

from configs.base import (Callbacks, Common, Config, Dataset, LRScheduler,
                          Model, Optimizer, Project, Train)
from data.graph_dataset import OGBGProductsDatamodule
from models.gcn import GCNModule

RUN_NAME = 'gcn_' + uuid.uuid4().hex[:6]  # unique run id


CONFIG = Config(
    project=Project(
        log_freq=10,
        project_name='OGBN Product',
        run_name=RUN_NAME,
        tags='gcn',
        notes='',
    ),

    common=Common(seed=8),

    dataset=Dataset(
        module=OGBGProductsDatamodule,
        root='data/',
        batch_size=512,
        num_workers=6,
    ),

    model=Model(
        module=GCNModule,
        model_params={
            'num_features': [100, 256, 256, 256],
            'num_classes': 47,
            'dropout': 0.15,
        },
    ),

    train=Train(
        trainer_params={
            'devices': 1,
            'accelerator': 'auto',
            'accumulate_grad_batches': 1,
            'auto_scale_batch_size': None,
            'gradient_clip_val': 0.0,
            'benchmark': True,
            'precision': 32,
            'max_epochs': 200,
            'auto_lr_find': None,
        },

        callbacks=Callbacks(
            model_checkpoint=pl.callbacks.ModelCheckpoint(
                dirpath=os.path.join('checkpoints', RUN_NAME),
                save_top_k=2,
                monitor='val_loss',
                mode='min',
            ),

            lr_monitor=pl.callbacks.LearningRateMonitor(logging_interval='step'),
        ),

        optimizer=Optimizer(
            name='Adam',
            opt_params={
                'lr': 0.001,
                'weight_decay': 0.0001,
            },
        ),

        lr_scheduler=LRScheduler(
            name='CosineAnnealingWarmRestarts',
            lr_sched_params={
                'T_0': 200,
                'T_mult': 1,
                'eta_min': 0.00001,
            },
        ),

        ckpt_path=None,
    ),
)

configs/config_mlp.py 

import os
import uuid

import pytorch_lightning as pl

from configs.base import (Callbacks, Common, Config, Dataset, LRScheduler,
                          Model, Optimizer, Project, Train)
from data.embeddings_dataset import EmbeddingsDataModule
from models.mlp import MLPModule

RUN_NAME = 'mlp_baseline_' + uuid.uuid4().hex[:6]  # unique run id


CONFIG = Config(
    project=Project(
        log_freq=500,
        project_name='OGBN Product',
        run_name=RUN_NAME,
        tags='mlp, baseline',
        notes='',
    ),

    common=Common(seed=8),

    dataset=Dataset(
        module=EmbeddingsDataModule,
        root='data/',
        batch_size=512,
        num_workers=6,
    ),

    model=Model(
        module=MLPModule,
        model_params={
            'num_neurons': [100, 256, 256, 256],
            'num_classes': 47,
            'dropout': 0.15,
        },
    ),

    train=Train(
        trainer_params={
            'devices': 1,
            'accelerator': 'auto',
            'accumulate_grad_batches': 1,
            'auto_scale_batch_size': None,
            'gradient_clip_val': 0.0,
            'benchmark': True,
            'precision': 32,
            'max_epochs': 135,
            'auto_lr_find': None,
        },

        callbacks=Callbacks(
            model_checkpoint=pl.callbacks.ModelCheckpoint(
                dirpath=os.path.join('checkpoints', RUN_NAME),
                save_top_k=2,
                monitor='val_loss',
                mode='min',
            ),

            lr_monitor=pl.callbacks.LearningRateMonitor(logging_interval='step'),
        ),

        optimizer=Optimizer(
            name='Adam',
            opt_params={
                'lr': 0.001,
                'weight_decay': 0.0001,
            },
        ),

        lr_scheduler=LRScheduler(
            name='CosineAnnealingWarmRestarts',
            lr_sched_params={
                'T_0': 120,
                'T_mult': 1,
                'eta_min': 0.00001,
            },
        ),

        ckpt_path=None,
    ),
)

configs/utils.py

Helper code that flattens the configs into a single dict for experiment logging.

import pytorch_lightning as pl

from configs.base import Config, LRScheduler, Optimizer

SEP = '_'


def get_dict_for_optimizer(config_optimizer: Optimizer) -> dict:
    opt_dict = {}
    opt_dict[SEP.join(['optimizer', 'name'])] = config_optimizer.name
    for key, value in config_optimizer.opt_params.items():
        opt_dict[SEP.join(['optimizer', key])] = value
    return opt_dict


def get_dict_for_lr_scheduler(config_lr_scheduler: LRScheduler) -> dict:
    lr_dict = {}
    lr_dict[SEP.join(['lr_scheduler', 'name'])] = config_lr_scheduler.name
    for key, value in config_lr_scheduler.lr_sched_params.items():
        lr_dict[SEP.join(['lr_scheduler', key])] = value
    return lr_dict


def get_config_dict(model: pl.LightningModule, datamodule: pl.LightningDataModule, config: Config) -> dict:
    config_dict = {}
    datamodule_dict = dict(datamodule.hparams)
    trainer_dict = config.train.trainer_params
    optimizer_dict = get_dict_for_optimizer(config.train.optimizer)
    lr_sched_dict = get_dict_for_lr_scheduler(config.train.lr_scheduler)

    # using update to avoid doubling keys
    # config_dict.update(model_dict)
    config_dict.update(datamodule_dict)
    config_dict.update(trainer_dict)
    config_dict.update(optimizer_dict)
    config_dict.update(lr_sched_dict)
    return config_dict
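
A small illustration of what the flattening produces (the values are taken from the GCN config above):

from configs.base import LRScheduler, Optimizer
from configs.utils import get_dict_for_lr_scheduler, get_dict_for_optimizer

opt = Optimizer(name='Adam', opt_params={'lr': 0.001, 'weight_decay': 0.0001})
sched = LRScheduler(name='CosineAnnealingWarmRestarts', lr_sched_params={'T_0': 200})

print(get_dict_for_optimizer(opt))
# {'optimizer_name': 'Adam', 'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.0001}
print(get_dict_for_lr_scheduler(sched))
# {'lr_scheduler_name': 'CosineAnnealingWarmRestarts', 'lr_scheduler_T_0': 200}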

4. Training entry point

train.py

import argparse
import logging
import os
import typing as tp
from runpy import run_path

import pytorch_lightning as pl
import wandb
from kornia.losses.focal import FocalLoss
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import RichModelSummary
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger

from configs.base import Config
from configs.utils import get_config_dict


def get_wandb_logger(model: pl.LightningModule, datamodule: pl.LightningDataModule, config: Config) -> WandbLogger:
    config_dict = get_config_dict(model=model, datamodule=datamodule, config=config)
    wandb_logger = WandbLogger(
        project=config.project.project_name,
        name=config.project.run_name,
        tags=config.project.tags.split(','),
        notes=config.project.notes,
        config=config_dict,
        log_model='all',
    )
    wandb_logger.watch(
        model=model,
        log='all',
        log_freq=config.project.log_freq,  # log gradients and parameters every log_freq batches
    )
    return wandb_logger


def parse() -> tp.Any:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        default=os.path.join('src', 'configs', 'config_mlp.py'),
        type=str,
        help='Path to experiment config file (*.py)',
    )
    parser.add_argument(
        '--loss',
        default='ce',
        type=str,
        choices=['ce', 'focal'],
        help='Loss type (ce keeps the default CrossEntropyLoss)',
    )
    return parser.parse_args()


def main(args: tp.Any, config: Config):
    # model
    model = config.model.module(
        config.model.model_params,
        optimizer=config.train.optimizer,
        lr_scheduler=config.train.lr_scheduler,
    )

    if args.loss == 'focal':
        model.criterion = FocalLoss(alpha=1, gamma=2, reduction='mean')
        logging.info('Using FocalLoss')

    # data module
    datamodule = config.dataset.module(
        root=config.dataset.root,
        batch_size=config.dataset.batch_size,
        num_workers=config.dataset.num_workers,
    )

    # logger
    logger = get_wandb_logger(model, datamodule, config)
    base_path = os.path.split(args.config)[0]
    wandb.save(args.config, base_path=base_path, policy='now')

    # trainer
    trainer_params = config.train.trainer_params
    callbacks = list(config.train.callbacks.__dict__.values())
    callbacks = filter(lambda callback: callback is not None, callbacks)
    trainer = Trainer(
        callbacks=[
            TQDMProgressBar(refresh_rate=1),
            RichModelSummary(),
            *callbacks,
        ],
        logger=logger,
        **trainer_params,
    )

    if trainer_params['auto_scale_batch_size'] is not None or trainer_params['auto_lr_find'] is not None:
        trainer.tune(model=model, datamodule=datamodule)

    trainer.fit(
        model=model,
        datamodule=datamodule,
        ckpt_path=config.train.ckpt_path,
    )

    trainer.test(
        model=model,
        datamodule=datamodule,
        ckpt_path='best',
    )


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse()
    config_module = run_path(args.config)
    exp_config = config_module['CONFIG']
    seed_everything(exp_config.common.seed, workers=True)
    main(args, exp_config)

III. Environment Setup

ogb==1.3.5
pytorch_lightning==1.8.6
torch==1.13.1
torch_geometric==2.2.0
torchmetrics==0.11.0
tqdm==4.62.3
wandb==0.13.7 

Installing torch_geometric also requires its companion packages (e.g. torch-scatter and torch-sparse); choose the builds that match your CUDA and PyTorch versions.
