I. Task type
Node classification on the large-scale ogbn-products dataset.
ogbn-products:
ogbn-products is one of the datasets in the Open Graph Benchmark (OGB). It is an undirected, unweighted graph of an Amazon product co-purchasing network: nodes are products sold on Amazon, and an edge between two nodes means the two products were bought together. Node features are built from the words in the product descriptions and reduced to 100 dimensions with principal component analysis (PCA).
The task is multi-class classification of products into 47 top-level categories, which serve as the target labels. The split is not random but based on sales rank: the top 8% best-selling products form the training set, the next 2% the validation set, and the remainder the test set. This split is more challenging and closer to real-world deployment scenarios.
The graph structure of ogbn-products is stored by default as an edge list (edge_index). For graph neural network (GNN) training, the dataset can be loaded with PygNodePropPredDataset from PyTorch Geometric (PyG), and models can be scored with the OGB Evaluator.
ogbn-products is a medium-scale dataset in OGB's Node Property Prediction track, with 2,449,029 nodes and 61,859,140 edges, which makes it a good testbed for studying the performance and generalization of GNN models on large graphs.
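Before going into the training code, here is a minimal sketch of loading the dataset and calling the OGB Evaluator (the root path and the dummy predictions are illustrative; the first call downloads and processes the dataset):
import torch
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset

dataset = PygNodePropPredDataset(name='ogbn-products', root='data/')
data = dataset[0]                              # one large Data object
split_idx = dataset.get_idx_split()            # {'train': ..., 'valid': ..., 'test': ...}
print(data.num_nodes, data.num_edges)

evaluator = Evaluator(name='ogbn-products')
y_true = data.y[split_idx['test']]             # shape [num_test_nodes, 1]
y_pred = torch.zeros_like(y_true)              # dummy predictions, same shape as y_true
print(evaluator.eval({'y_true': y_true, 'y_pred': y_pred}))   # {'acc': ...}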
II. Code walkthrough
1. Model architecture
The model design has two branches, a GCN and an MLP; each is trained as a separate model, with the GCN as the main model and the MLP as a baseline.
models/gcn.py (despite the name, the layers are GraphSAGE SAGEConv modules rather than classic GCN convolutions):
import logging
import typing as tp
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import NeighborSampler
from torch_geometric.nn import SAGEConv
from torchmetrics import Accuracy, F1Score, Precision, Recall
from tqdm import tqdm
from configs.base import LRScheduler, Optimizer
class GraphConvNetwork(nn.Module):
"""Graph Convolutional Network."""
def __init__(
self,
num_features: tp.List[int],
num_classes: int,
        dropout: float = 0,
):
"""
Initialize GraphConvNetwork.
Parameters
----------
num_features : tp.List[int]
            Feature dimension of each layer: the input size followed by the hidden sizes.
num_classes : int
Number of classes in data.
dropout : float, optional
Dropout rate, by default 0.
"""
super().__init__()
self.layers = nn.ModuleList()
num_layers = len(num_features)
for i in range(num_layers - 1):
self.layers.append(
SAGEConv(in_channels=num_features[i], out_channels=num_features[i + 1]),
)
self.layers.append(
SAGEConv(in_channels=num_features[-1], out_channels=num_classes),
)
self.dropout = dropout
def forward(self, x: torch.tensor, edge_index: torch.tensor) -> torch.tensor:
        for i, layer in enumerate(self.layers):
            x = layer(x, edge_index)
            if i != len(self.layers) - 1:
                # apply ReLU/dropout only on hidden layers, mirroring inference() below
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return torch.log_softmax(x, dim=-1)
def inference(self, x_all: torch.tensor, subgraph_loader: NeighborSampler) -> torch.tensor:
for i, conv in enumerate(self.layers):
conv = conv.cpu()
xs = []
for batch_size, n_id, adj in tqdm(subgraph_loader):
edge_index, _, size = adj
x = x_all[n_id]
x_target = x[:size[1]]
x = conv((x, x_target), edge_index)
if i != len(self.layers) - 1:
x = F.relu(x)
xs.append(x.cpu())
x_all = torch.cat(xs, dim=0)
return x_all
class GCNModule(pl.LightningModule):
def __init__(
self,
model_params: dict,
optimizer: Optimizer,
lr_scheduler: LRScheduler | None,
):
"""
Initialize GCNModule.
Parameters
----------
model_params : dict
Dictionary with GraphConvNetwork parameters.
optimizer : Optimizer
Optimizer.
lr_scheduler : LRScheduler
Learning rate scheduler.
"""
super().__init__()
self.save_hyperparameters(ignore=['criterion']) # criterion is already saved during checkpointing
self.learning_rate = optimizer.opt_params['lr']
self.model = GraphConvNetwork(**model_params)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.criterion = nn.CrossEntropyLoss()
num_classes = model_params['num_classes']
task = 'multiclass'
self.accuracy = Accuracy(task=task, num_classes=num_classes)
self.f1_score = F1Score(task=task, num_classes=num_classes)
self.precision_score = Precision(task=task, num_classes=num_classes)
self.recall_score = Recall(task=task, num_classes=num_classes)
self.test_probs = None
def forward(self, x: torch.tensor, edge_index: torch.tensor):
return self.model(x, edge_index)
def training_step(self, batch, batch_idx):
targets = batch.y.squeeze(1)[batch.train_mask]
probs = self.forward(batch.x, batch.edge_index)[batch.train_mask]
batch_size = len(targets)
loss = self.criterion(probs, targets)
self.log('train_loss', loss, on_epoch=False, on_step=True, batch_size=batch_size)
self.accuracy(probs, targets.long())
self.log('train_acc', self.accuracy, on_epoch=False, on_step=True, batch_size=batch_size)
self.f1_score(probs, targets.long())
self.log('train_f1', self.f1_score, on_epoch=False, on_step=True, batch_size=batch_size)
self.precision_score(probs, targets.long())
self.log('train_precision', self.precision_score, on_epoch=False, on_step=True, batch_size=batch_size)
self.recall_score(probs, targets.long())
self.log('train_recall_score', self.recall_score, on_epoch=False, on_step=True, batch_size=batch_size)
return loss
def validation_step(self, batch, batch_idx):
targets = batch.y.squeeze(1)[batch.valid_mask]
probs = self.forward(batch.x, batch.edge_index)[batch.valid_mask]
batch_size = len(targets)
loss = self.criterion(probs, targets)
self.log('val_loss', loss, on_epoch=True, on_step=True, batch_size=batch_size)
self.accuracy(probs, targets.long())
self.log('val_acc', self.accuracy, on_epoch=True, on_step=False, batch_size=batch_size)
self.f1_score(probs, targets.long())
self.log('val_f1', self.f1_score, on_epoch=True, on_step=False, batch_size=batch_size)
self.precision_score(probs, targets.long())
self.log('val_precision', self.precision_score, on_epoch=True, on_step=False, batch_size=batch_size)
self.recall_score(probs, targets.long())
self.log('val_recall_score', self.recall_score, on_epoch=True, on_step=False, batch_size=batch_size)
return loss
def on_test_start(self) -> None:
logging.info('Starting testing...')
self.test_probs = self.model.inference(
x_all=self.trainer.datamodule.data.x,
subgraph_loader=self.trainer.test_dataloaders[0],
)
def test_step(self, batch, batch_idx):
return 0
def on_test_end(self) -> None:
targets = self.trainer.datamodule.data.y
y_pred = self.test_probs.argmax(dim=-1, keepdim=True)
acc = self.accuracy(y_pred, targets.long())
f1_score = self.f1_score(y_pred, targets.long())
precision_score = self.precision_score(y_pred, targets.long())
recall_score = self.recall_score(y_pred, targets.long())
self.logger.experiment.log(
{
'test_acc': acc,
'test_f1': f1_score,
'test_precision': precision_score,
'test_recall_score': recall_score,
}
)
def configure_optimizers(self):
optimizer = getattr(torch.optim, self.optimizer.name)(
self.parameters(),
**self.optimizer.opt_params,
)
optim_dict = {
'optimizer': optimizer,
'monitor': 'val_loss',
}
if self.lr_scheduler is not None:
lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler.name)(
optimizer,
**self.lr_scheduler.lr_sched_params,
)
optim_dict.update({'lr_scheduler': lr_scheduler})
return optim_dict
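A quick shape check of GraphConvNetwork on a toy graph, using the same layer sizes that config_gcn.py (later in this post) passes in; the toy edge_index is made up:
import torch
from models.gcn import GraphConvNetwork

model = GraphConvNetwork(num_features=[100, 256, 256, 256], num_classes=47, dropout=0.15)
x = torch.randn(8, 100)                                   # 8 nodes with 100-dim PCA features
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],
                           [1, 2, 3, 4, 5, 6, 7, 0]])     # a toy directed ring
out = model(x, edge_index)
print(out.shape)                                          # torch.Size([8, 47]) log-probabilities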
models/mlp.py (the MLP baseline used for comparison):
import typing as tp
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import Accuracy, F1Score, Precision, Recall
from configs.base import LRScheduler, Optimizer
class MultiLayerPerceptron(nn.Module):
"""Base MLP model using here as a baseline."""
def __init__(
self,
num_neurons: tp.List[int],
num_classes: int,
dropout: float = 0,
):
"""
Initialize MultiLayerPerceptron.
Parameters
----------
num_neurons : tp.List[int]
Number of neurons in each layer.
num_classes : int
Number of classes in data.
dropout : float, optional
Dropout rate, by default 0.
"""
super().__init__()
self.lin_layers = []
num_layers = len(num_neurons)
for i in range(num_layers - 1):
self.lin_layers.extend(
[
nn.Linear(num_neurons[i], num_neurons[i + 1]),
nn.ReLU(),
nn.Dropout(dropout),
],
)
self.lin_layers.append(
nn.Linear(num_neurons[-1], num_classes),
)
self.model = nn.Sequential(*self.lin_layers)
def forward(self, x: torch.tensor) -> torch.tensor:
x = self.model(x)
return torch.log_softmax(x, dim=-1)
class MLPModule(pl.LightningModule):
def __init__(
self,
model_params: dict,
optimizer: Optimizer,
lr_scheduler: LRScheduler | None,
):
"""
Initialize MLPModule.
Parameters
----------
model_params : dict
Dictionary with MultiLayerPerceptron parameters.
optimizer : Optimizer
Optimizer.
lr_scheduler : LRScheduler
Learning rate scheduler.
"""
super().__init__()
self.save_hyperparameters(ignore=['criterion']) # criterion is already saved during checkpointing
self.learning_rate = optimizer.opt_params['lr']
self.model = MultiLayerPerceptron(**model_params)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.criterion = nn.CrossEntropyLoss()
num_classes = model_params['num_classes']
task = 'multiclass'
self.accuracy = Accuracy(task=task, num_classes=num_classes)
self.f1_score = F1Score(task=task, num_classes=num_classes)
self.precision_score = Precision(task=task, num_classes=num_classes)
self.recall_score = Recall(task=task, num_classes=num_classes)
def forward(self, x: torch.tensor):
return self.model(x)
def training_step(self, batch, batch_idx):
x, targets = batch
targets = targets.squeeze(1)
probs = self.forward(x)
loss = self.criterion(probs, targets)
self.log('train_loss', loss, on_epoch=False, on_step=True)
self.accuracy(probs, targets.long())
self.log('train_acc', self.accuracy, on_epoch=False, on_step=True)
self.f1_score(probs, targets.long())
self.log('train_f1', self.f1_score, on_epoch=False, on_step=True)
self.precision_score(probs, targets.long())
self.log('train_precision', self.precision_score, on_epoch=False, on_step=True)
self.recall_score(probs, targets.long())
self.log('train_recall_score', self.recall_score, on_epoch=False, on_step=True)
return loss
def validation_step(self, batch, batch_idx):
x, targets = batch
targets = targets.squeeze(1)
probs = self.forward(x)
loss = self.criterion(probs, targets)
self.log('val_loss', loss, on_epoch=True, on_step=True)
self.accuracy(probs, targets.long())
self.log('val_acc', self.accuracy, on_epoch=True, on_step=False)
self.f1_score(probs, targets.long())
self.log('val_f1', self.f1_score, on_epoch=True, on_step=False)
self.precision_score(probs, targets.long())
self.log('val_precision', self.precision_score, on_epoch=True, on_step=False)
self.recall_score(probs, targets.long())
self.log('val_recall_score', self.recall_score, on_epoch=True, on_step=False)
return loss
def test_step(self, batch, batch_idx):
x, targets = batch
targets = targets.squeeze(1)
probs = self.forward(x)
loss = self.criterion(probs, targets)
self.log('test_loss', loss)
self.accuracy(probs, targets.long())
self.log('test_acc', self.accuracy)
self.f1_score(probs, targets.long())
self.log('test_f1', self.f1_score)
self.precision_score(probs, targets.long())
self.log('test_precision', self.precision_score)
self.recall_score(probs, targets.long())
self.log('test_recall_score', self.recall_score)
return loss
def configure_optimizers(self):
optimizer = getattr(torch.optim, self.optimizer.name)(
self.parameters(),
**self.optimizer.opt_params,
)
optim_dict = {
'optimizer': optimizer,
'monitor': 'val_loss',
}
if self.lr_scheduler is not None:
lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler.name)(
optimizer,
**self.lr_scheduler.lr_sched_params,
)
optim_dict.update({'lr_scheduler': lr_scheduler})
return optim_dict
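For reference, a minimal sketch of wiring MLPModule to the Optimizer and LRScheduler dataclasses from configs/base.py (shown in the next section); the values mirror config_mlp.py below:
from configs.base import LRScheduler, Optimizer
from models.mlp import MLPModule

module = MLPModule(
    model_params={'num_neurons': [100, 256, 256, 256], 'num_classes': 47, 'dropout': 0.15},
    optimizer=Optimizer(name='Adam', opt_params={'lr': 1e-3, 'weight_decay': 1e-4}),
    lr_scheduler=LRScheduler(name='CosineAnnealingWarmRestarts', lr_sched_params={'T_0': 120}),
)
optim_dict = module.configure_optimizers()      # dict with 'optimizer', 'monitor', 'lr_scheduler'
print(optim_dict['optimizer'], optim_dict['lr_scheduler'])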
2. Data handling
The data code is split into two files: data/embeddings_dataset.py (plain node-feature embeddings for the MLP baseline) and data/graph_dataset.py (the full graph for the GCN).
data/embeddings_dataset.py:
import logging
import typing as tp
import pytorch_lightning as pl
import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch.utils.data import DataLoader, Dataset
class EmbeddingsDataset(Dataset):
def __init__(
self,
root: str,
mode: tp.Literal['train', 'valid', 'test'],
):
self.root = root
embeddings, targets = self._get_embeddings_and_targets(mode)
self.embeddings = embeddings
self.targets = targets
def _get_embeddings_and_targets(self, mode: str) -> tp.Tuple[torch.tensor, torch.tensor]:
dataset = PygNodePropPredDataset(name='ogbn-products', root=self.root)
graph = dataset[0]
indices = dataset.get_idx_split()[mode]
embeddings = graph.x[indices]
targets = graph.y[indices]
return embeddings, targets
def __len__(self):
return len(self.embeddings)
def __getitem__(self, index):
embedding = self.embeddings[index]
target = self.targets[index]
return embedding, target
class EmbeddingsDataModule(pl.LightningDataModule):
def __init__(
self,
root: str,
batch_size: int,
num_workers: int,
):
"""
Create Data Module for EmbeddingsDataset.
Parameters
----------
root : str
Path to root dir with dataset.
batch_size : int
Batch size for dataloaders.
num_workers : int
Number of workers in dataloaders.
"""
super().__init__()
self.save_hyperparameters()
self.root = root
self.batch_size = batch_size
self.num_workers = num_workers
def setup(self, stage: tp.Optional[str] = None):
if stage == 'fit' or stage is None:
self.train_dataset = EmbeddingsDataset(
root=self.root,
mode='train',
)
num_train_files = len(self.train_dataset)
logging.info(f'Mode: train, number of nodes: {num_train_files}')
self.val_dataset = EmbeddingsDataset(
root=self.root,
mode='valid',
)
num_val_files = len(self.val_dataset)
logging.info(f'Mode: val, number of nodes: {num_val_files}')
elif stage == 'test':
self.test_dataset = EmbeddingsDataset(
root=self.root,
mode='test',
)
num_test_files = len(self.test_dataset)
logging.info(f'Mode: test, number of nodes: {num_test_files}')
def train_dataloader(self):
return DataLoader(
self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, drop_last=True,
)
def val_dataloader(self):
return DataLoader(
self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, drop_last=False,
)
def test_dataloader(self):
return DataLoader(
self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, drop_last=False,
)
data/graph_dataset.py:
Training and validation use ClusterLoader sampling; testing uses NeighborSampler sampling.
ClusterLoader and NeighborSampler are both sampling strategies for learning node embeddings on large graphs, but they differ in how they work and where they fit (a minimal sketch follows this comparison).
ClusterLoader:
- ClusterLoader targets large graphs: it partitions the graph into many clusters and then trains on one cluster (or a small group of clusters) at a time. This reduces memory consumption because the whole graph never has to be held in memory at once.
- ClusterLoader is normally used together with the Cluster-GCN method, which runs the forward and backward passes within the sampled clusters.
- Each cluster can be viewed as a subgraph; the model is trained on each subgraph independently, and the results are combined to obtain the final node embeddings.
NeighborSampler:
- NeighborSampler is the sampling scheme used by GraphSAGE: it recursively samples the neighbors of a set of nodes to build a directed, multi-hop subgraph.
- The core idea: given a mini-batch of nodes, the number of convolution layers L, and the number of neighbors to sample per layer (sizes), it samples neighbors layer by layer from the first to the L-th layer and returns one bipartite subgraph per layer.
- At each layer, NeighborSampler samples neighbors for all nodes involved in the previous layer's sample, which guarantees that after L layers the result contains every node touched by the L-layer convolution.
- NeighborSampler enables node-level mini-batch training, so GNNs can be trained on large graphs when full-batch training is infeasible.
In short, ClusterLoader handles large graphs by partitioning them into clusters to cut memory usage, while NeighborSampler supports node-level mini-batch training by recursively sampling neighbors into subgraphs. Both are effective for node embedding learning at scale; the choice depends on the graph's characteristics and the available compute.
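A minimal sketch of what a single-layer NeighborSampler yields (the toy graph and batch size are made up); the (batch_size, n_id, adj) triple is exactly what GraphConvNetwork.inference() above unpacks:
import torch
from torch_geometric.loader import NeighborSampler

# toy 4-node path graph, stored as a directed edge list in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])

loader = NeighborSampler(edge_index, sizes=[-1], batch_size=2, shuffle=False)

for batch_size, n_id, adj in loader:
    # n_id: global ids of every node needed for this batch (target nodes come first)
    # adj:  (edge_index, e_id, size) of the bipartite subgraph; size = (num_source, num_target)
    edge_index_b, e_id, size = adj
    print(batch_size, n_id.tolist(), size)
The datamodule in data/graph_dataset.py is then: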
import logging
import typing as tp
import torch
import pytorch_lightning as pl
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.loader import ClusterLoader, ClusterData, NeighborSampler
class OGBGProductsDatamodule(pl.LightningDataModule):
"""DataModule for mini-batch GCN training using the Cluster-GCN algorithm."""
def __init__(
self,
root: str,
batch_size: int,
num_workers: int,
save_dir: tp.Optional[str] = 'data/ogbn_products/processed',
num_partitions: int = 15000,
):
"""
Create Data Module for OGBG Product task.
Parameters
----------
root : str
Path to the root directory.
batch_size : int
Batch size for dataloaders.
num_workers : int
Number of workers in dataloaders.
save_dir : tp.Optional[str]
Directory where already partitioned dataset is stored.
num_partitions : int
Number of partitions.
"""
super().__init__()
self.save_hyperparameters()
self.batch_size = batch_size
self.num_workers = num_workers
self.root = root
self.num_partitions = num_partitions
self.save_dir = save_dir
self.data = None
self.split_idx = None
self.cluster_data = None
def prepare_data(self):
dataset = PygNodePropPredDataset(name='ogbn-products', root=self.root)
self.split_idx = dataset.get_idx_split()
self.data = dataset[0]
# Convert split indices to boolean masks and add them to `data`
for key, idx in self.split_idx.items():
mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
mask[idx] = True
self.data[f'{key}_mask'] = mask
self.cluster_data = ClusterData(
self.data,
num_parts=self.num_partitions,
recursive=False,
save_dir=self.save_dir,
)
def setup(self, stage: tp.Optional[str] = None):
if stage == 'fit' or stage is None:
self.train_split = self.split_idx['train']
num_train_files = len(self.train_split)
logging.info(f'Mode: train, number of nodes: {num_train_files}')
self.valid_split = self.split_idx['valid']
num_valid_files = len(self.valid_split)
logging.info(f'Mode: valid, number of nodes: {num_valid_files}')
elif stage == 'test':
self.test_split = self.split_idx['test']
num_test_files = len(self.test_split)
logging.info(f'Mode: test, number of nodes: {num_test_files}')
def train_dataloader(self):
return ClusterLoader(
self.cluster_data,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
def val_dataloader(self):
return ClusterLoader(
self.cluster_data,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
def test_dataloader(self):
return NeighborSampler(
self.data.edge_index,
sizes=[-1],
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
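A quick sanity check of how the datamodule fits together (note: on first run this downloads ogbn-products and partitions it with METIS via torch-sparse, which takes a while and needs several GB of disk; the batch size here is arbitrary):
from data.graph_dataset import OGBGProductsDatamodule

dm = OGBGProductsDatamodule(root='data/', batch_size=32, num_workers=0)
dm.prepare_data()          # loads the dataset, builds the boolean masks, runs the partitioning
dm.setup(stage='fit')

batch = next(iter(dm.train_dataloader()))       # one merged cluster subgraph (a Data object)
print(batch.x.shape, batch.edge_index.shape)
print(batch.train_mask.sum().item(), 'training nodes in this subgraph')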
3. Configuration files
configs/base.py
import typing as tp
from dataclasses import dataclass
import pytorch_lightning as pl
@dataclass
class Project:
project_name: str
run_name: str
notes: str
tags: str
log_freq: int = 100
@dataclass
class Common:
seed: int = 8
@dataclass
class Dataset:
module: pl.LightningDataModule
    root: str  # path to the root directory with the dataset
batch_size: int
num_workers: int
@dataclass
class Model:
module: pl.LightningModule
model_params: dict
@dataclass
class Callbacks:
model_checkpoint: pl.callbacks.ModelCheckpoint
early_stopping: tp.Optional[pl.callbacks.EarlyStopping] = None
lr_monitor: tp.Optional[pl.callbacks.LearningRateMonitor] = None
model_summary: tp.Optional[tp.Union[pl.callbacks.ModelSummary, pl.callbacks.RichModelSummary]] = None
timer: tp.Optional[pl.callbacks.Timer] = None
@dataclass
class Optimizer:
name: str
opt_params: dict
@dataclass
class LRScheduler:
name: str
lr_sched_params: dict
@dataclass
class Train:
trainer_params: dict
callbacks: Callbacks
optimizer: Optimizer
lr_scheduler: LRScheduler
ckpt_path: tp.Optional[str] = None
@dataclass
class Config:
project: Project
common: Common
dataset: Dataset
model: Model
train: Train
configs/config_gcn.py
import os
import uuid
import pytorch_lightning as pl
from configs.base import (Callbacks, Common, Config, Dataset, LRScheduler,
Model, Optimizer, Project, Train)
from data.graph_dataset import OGBGProductsDatamodule
from models.gcn import GCNModule
RUN_NAME = 'gcn_' + uuid.uuid4().hex[:6] # unique run id
CONFIG = Config(
project=Project(
log_freq=10,
project_name='OGBN Product',
run_name=RUN_NAME,
tags='gcn',
notes='',
),
common=Common(seed=8),
dataset=Dataset(
module=OGBGProductsDatamodule,
root='data/',
batch_size=512,
num_workers=6,
),
model=Model(
module=GCNModule,
model_params={
'num_features': [100, 256, 256, 256],
'num_classes': 47,
'dropout': 0.15,
},
),
train=Train(
trainer_params={
'devices': 1,
'accelerator': 'auto',
'accumulate_grad_batches': 1,
'auto_scale_batch_size': None,
'gradient_clip_val': 0.0,
'benchmark': True,
'precision': 32,
'max_epochs': 200,
'auto_lr_find': None,
},
callbacks=Callbacks(
model_checkpoint=pl.callbacks.ModelCheckpoint(
dirpath=os.path.join('checkpoints', RUN_NAME),
save_top_k=2,
monitor='val_loss',
mode='min',
),
lr_monitor=pl.callbacks.LearningRateMonitor(logging_interval='step'),
),
optimizer=Optimizer(
name='Adam',
opt_params={
'lr': 0.001,
'weight_decay': 0.0001,
},
),
lr_scheduler=LRScheduler(
name='CosineAnnealingWarmRestarts',
lr_sched_params={
'T_0': 200,
'T_mult': 1,
'eta_min': 0.00001,
},
),
ckpt_path=None,
),
)
configs/config_mlp.py
import os
import uuid
import pytorch_lightning as pl
from configs.base import (Callbacks, Common, Config, Dataset, LRScheduler,
Model, Optimizer, Project, Train)
from data.embeddings_dataset import EmbeddingsDataModule
from models.mlp import MLPModule
RUN_NAME = 'mlp_baseline_' + uuid.uuid4().hex[:6] # unique run id
CONFIG = Config(
project=Project(
log_freq=500,
project_name='OGBN Product',
run_name=RUN_NAME,
        tags='mlp, baseline',
notes='',
),
common=Common(seed=8),
dataset=Dataset(
module=EmbeddingsDataModule,
root='data/',
batch_size=512,
num_workers=6,
),
model=Model(
module=MLPModule,
model_params={
'num_neurons': [100, 256, 256, 256],
'num_classes': 47,
'dropout': 0.15,
},
),
train=Train(
trainer_params={
'devices': 1,
'accelerator': 'auto',
'accumulate_grad_batches': 1,
'auto_scale_batch_size': None,
'gradient_clip_val': 0.0,
'benchmark': True,
'precision': 32,
'max_epochs': 135,
'auto_lr_find': None,
},
callbacks=Callbacks(
model_checkpoint=pl.callbacks.ModelCheckpoint(
dirpath=os.path.join('checkpoints', RUN_NAME),
save_top_k=2,
monitor='val_loss',
mode='min',
),
lr_monitor=pl.callbacks.LearningRateMonitor(logging_interval='step'),
),
optimizer=Optimizer(
name='Adam',
opt_params={
'lr': 0.001,
'weight_decay': 0.0001,
},
),
lr_scheduler=LRScheduler(
name='CosineAnnealingWarmRestarts',
lr_sched_params={
'T_0': 120,
'T_mult': 1,
'eta_min': 0.00001,
},
),
ckpt_path=None,
),
)
configs/utils.py
Helper code that flattens the config objects into a plain dict for experiment logging.
import pytorch_lightning as pl
from configs.base import Config, LRScheduler, Optimizer
SEP = '_'
def get_dict_for_optimizer(config_optimizer: Optimizer) -> dict:
opt_dict = {}
opt_dict[SEP.join(['optimizer', 'name'])] = config_optimizer.name
for key, value in config_optimizer.opt_params.items():
opt_dict[SEP.join(['optimizer', key])] = value
return opt_dict
def get_dict_for_lr_scheduler(config_lr_scheduler: LRScheduler) -> dict:
lr_dict = {}
lr_dict[SEP.join(['lr_scheduler', 'name'])] = config_lr_scheduler.name
for key, value in config_lr_scheduler.lr_sched_params.items():
lr_dict[SEP.join(['lr_scheduler', key])] = value
return lr_dict
def get_config_dict(model: pl.LightningModule, datamodule: pl.LightningDataModule, config: Config) -> dict:
config_dict = {}
datamodule_dict = dict(datamodule.hparams)
trainer_dict = config.train.trainer_params
optimizer_dict = get_dict_for_optimizer(config.train.optimizer)
lr_sched_dict = get_dict_for_lr_scheduler(config.train.lr_scheduler)
# using update to avoid doubling keys
# config_dict.update(model_dict)
config_dict.update(datamodule_dict)
config_dict.update(trainer_dict)
config_dict.update(optimizer_dict)
config_dict.update(lr_sched_dict)
return config_dict
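As an illustration (my own example values, matching the Adam and CosineAnnealingWarmRestarts settings above), the helpers flatten a config into prefixed keys that the W&B logger can record:
from configs.base import LRScheduler, Optimizer
from configs.utils import get_dict_for_lr_scheduler, get_dict_for_optimizer

opt = Optimizer(name='Adam', opt_params={'lr': 0.001, 'weight_decay': 0.0001})
sched = LRScheduler(name='CosineAnnealingWarmRestarts', lr_sched_params={'T_0': 200, 'eta_min': 1e-5})

print(get_dict_for_optimizer(opt))
# {'optimizer_name': 'Adam', 'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.0001}
print(get_dict_for_lr_scheduler(sched))
# {'lr_scheduler_name': 'CosineAnnealingWarmRestarts', 'lr_scheduler_T_0': 200, 'lr_scheduler_eta_min': 1e-05}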
4. Training entry point
train.py
import argparse
import logging
import os
import typing as tp
from runpy import run_path
import pytorch_lightning as pl
import wandb
from kornia.losses.focal import FocalLoss
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import RichModelSummary
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger
from configs.base import Config
from configs.utils import get_config_dict
def get_wandb_logger(model: pl.LightningModule, datamodule: pl.LightningDataModule, config: Config) -> WandbLogger:
config_dict = get_config_dict(model=model, datamodule=datamodule, config=config)
wandb_logger = WandbLogger(
project=config.project.project_name,
name=config.project.run_name,
tags=config.project.tags.split(','),
notes=config.project.notes,
config=config_dict,
log_model='all',
)
wandb_logger.watch(
model=model,
log='all',
log_freq=config.project.log_freq, # log gradients and parameters every log_freq batches
)
return wandb_logger
def parse() -> tp.Any:
parser = argparse.ArgumentParser()
parser.add_argument(
'--config',
default=os.path.join('src', 'configs', 'config_mlp.py'),
type=str,
help='Path to experiment config file (*.py)',
)
parser.add_argument(
'--loss',
default='bce',
type=str,
choices=['bce', 'focal'],
help='Loss type',
)
return parser.parse_args()
def main(args: tp.Any, config: Config):
# model
model = config.model.module(
config.model.model_params,
optimizer=config.train.optimizer,
lr_scheduler=config.train.lr_scheduler,
)
if args.loss == 'focal':
model.criterion = FocalLoss(alpha=1, gamma=2, reduction='mean')
logging.info('Using FocalLoss')
# data module
datamodule = config.dataset.module(
root=config.dataset.root,
batch_size=config.dataset.batch_size,
num_workers=config.dataset.num_workers,
)
# logger
logger = get_wandb_logger(model, datamodule, config)
base_path = os.path.split(args.config)[0]
wandb.save(args.config, base_path=base_path, policy='now')
# trainer
trainer_params = config.train.trainer_params
callbacks = list(config.train.callbacks.__dict__.values())
callbacks = filter(lambda callback: callback is not None, callbacks)
trainer = Trainer(
callbacks=[
TQDMProgressBar(refresh_rate=1),
RichModelSummary(),
*callbacks,
],
logger=logger,
**trainer_params,
)
if trainer_params['auto_scale_batch_size'] is not None or trainer_params['auto_lr_find'] is not None:
trainer.tune(model=model, datamodule=datamodule)
trainer.fit(
model=model,
datamodule=datamodule,
ckpt_path=config.train.ckpt_path,
)
trainer.test(
model=model,
datamodule=datamodule,
ckpt_path='best',
)
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
args = parse()
config_module = run_path(args.config)
exp_config = config_module['CONFIG']
seed_everything(exp_config.common.seed, workers=True)
main(args, exp_config)
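With the configs in place, a run is started from the command line; the config path depends on where train.py and configs/ sit in your checkout (the argparse default above assumes a src/ layout):
# MLP baseline with the default cross-entropy criterion
python train.py --config configs/config_mlp.py
# GraphSAGE / Cluster-GCN model, switching the criterion to focal loss
python train.py --config configs/config_gcn.py --loss focal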
III. Environment setup
ogb==1.3.5
pytorch_lightning==1.8.6
torch==1.13.1
torch_geometric==2.2.0
torchmetrics==0.11.0
tqdm==4.62.3
wandb==0.13.7
Installing torch_geometric also requires its companion packages (torch-scatter, torch-sparse, etc.), which must match your CUDA and PyTorch versions.
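For example, assuming PyTorch 1.13.1 with CUDA 11.7 (adjust the wheel index to your own torch/CUDA combination, or use the +cpu wheels for a CPU-only machine):
pip install torch==1.13.1
pip install torch_geometric==2.2.0
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.1+cu117.html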