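I. Model diagram
II. Graph construction
A sliding window moves over the text, and words that co-occur inside a window are connected by an edge.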
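The sliding-window idea can be sketched in a few lines of plain Python. This is only an illustration, not the code from prebuild.py: the function name `build_cooccurrence_edges`, the window size of 3, and the choice of unweighted, undirected edges are all assumptions made for the example.

```python
def build_cooccurrence_edges(tokens, window_size=3):
    """Return the vocabulary and the undirected co-occurrence edges of one text.

    Two distinct words are connected if they appear together in at least one
    window of `window_size` consecutive tokens. (Hypothetical helper.)
    """
    vocab = sorted(set(tokens))
    index = {word: i for i, word in enumerate(vocab)}
    edges = set()
    # If the text is shorter than the window, the whole text is one window.
    last_start = max(len(tokens) - window_size, 0)
    for start in range(last_start + 1):
        window = tokens[start:start + window_size]
        for i in range(len(window)):
            for j in range(i + 1, len(window)):
                a, b = index[window[i]], index[window[j]]
                if a != b:  # skip self-loops from repeated words
                    edges.add((min(a, b), max(a, b)))
    return vocab, sorted(edges)


tokens = "the cat sat on the mat".split()
vocab, edges = build_cooccurrence_edges(tokens, window_size=3)
print(vocab)       # ['cat', 'mat', 'on', 'sat', 'the']
print(len(edges))  # 8
```

In this toy sentence "cat" and "mat" never share a window, so no edge connects them. A larger window would link more distant words.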
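III. Message passing
Each node aggregates information from its neighbors, and the node states are then updated through a GRU whose weights are learned during training.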
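To make "aggregate, then update with a GRU" concrete, here is a dependency-free sketch of one propagation step. It assumes the usual gated graph update (an update gate `z`, a reset gate `r`, and a tanh candidate state), leaves out bias terms, and uses random weights instead of trained ones. The parameter names (`W_a`, `W_z`, `U_z`, ...) and the toy adjacency matrix are made up for the example and may not match the repository.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def hadamard(a, b):
    return [[x * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def elementwise(fn, a):
    return [[fn(x) for x in row] for row in a]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def gated_step(adj, h, p):
    """One step: aggregate neighbor states, then apply a GRU-style update."""
    a = matmul(matmul(adj, h), p["W_a"])  # aggregated neighbor message
    z = elementwise(sigmoid, add(matmul(a, p["W_z"]), matmul(h, p["U_z"])))  # update gate
    r = elementwise(sigmoid, add(matmul(a, p["W_r"]), matmul(h, p["U_r"])))  # reset gate
    candidate = elementwise(
        math.tanh, add(matmul(a, p["W_h"]), matmul(hadamard(r, h), p["U_h"])))
    keep = elementwise(lambda v: 1.0 - v, z)
    # New state: gate-weighted mix of the candidate and the previous state.
    return add(hadamard(candidate, z), hadamard(h, keep))


rng = random.Random(0)
dim = 4
# Row-normalized adjacency of a 3-node path graph with self-loops (toy data).
adj = [[0.5, 0.5, 0.0],
       [1 / 3, 1 / 3, 1 / 3],
       [0.0, 0.5, 0.5]]
h = random_matrix(3, dim, rng)  # initial node states (word embeddings in practice)
params = {name: random_matrix(dim, dim, rng)
          for name in ("W_a", "W_z", "U_z", "W_r", "U_r", "W_h", "U_h")}

for _ in range(2):  # two propagation steps reach two-hop neighbors
    h = gated_step(adj, h, params)
print(h)
```

Stacking the step t times lets a node see information from neighbors up to t hops away, while the gates decide how much of the old state to keep.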
IV. Readout
V. Resource download
TextING
The implementation of TextING
Requirements
Python 3.7.6
torch 1.5.1
torch-geometric 1.6.1
Source
Download corpus.zip
https://download.youkuaiyun.com/download/qq_28969139/13061605
Unzip into ./corpus
corpus/20ng.labels.txt
corpus/20ng.texts.txt
...
--------------------------
Download glove.6B.zip
https://download.youkuaiyun.com/download/qq_28969139/14033550
Unzip into ./source
source/glove.6B.50d.txt
source/glove.6B.100d.txt
...
Run 'python handle_glove.py'
Preprocess
Run 'python preprocess.py'
Run 'python prebuild.py'
Run
Run 'python train.py'
1 Directory structure
2 Code
config.py
# static params
args = {
    '20ng': {
        'train_size': 11314,
        'test_size': 7532,
        'valid_size': 1131,
        "num_classes": 20
    },
    'aclImdb': {
        'train_size': 25000,
        'test_size': 25000,
        'valid_size': 2500
    },
    'ag_news': {
        'train_size': 120000,
        'test_size': 7600,
        'valid_size': 12000
    },
    'dblp': {
        'train_size': 61479,
        'test_size': 20000,
        'valid_size': 6148
    },
    'mr': {
        'train_size': 7108,
        'test_size': 3554,
        'valid_size': 711,
        "num_classes": 2
    },
    'ohsumed': {
        'train_size': 3357,
        'test_size': 4043,
        'valid_size': 336,
        "num_classes": 23
    },
    'R8': {
        'train_size': 5485,
        'test_size': 2189,
        'valid_size': 548,
        "num_classes": 8
    },
    'R52': {
        'train_size': 6532,
        'test_size': 2568,
        'valid_size': 653,
        "num_classes": 52
    },
    'TREC': {
        'train_size': 5452,
        'test_size': 500,
        'valid_size': 545
    },
    'WebKB': {
        'train_size': 2803,
        'test_size': 1396,
        'valid_size': 280
    },
    'wiki': {
        'train_size': 3000,
        'test_size': 127000,
        'valid_size': 300
    }
}
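A quick way to read this table: in every entry `valid_size` is about 10% of `train_size`, which lines up with the default `valid_part=0.1` of `split_train_valid_test` in dataset.py, and only some entries define `num_classes`. The sketch below copies two entries and uses a hypothetical helper `summarize` (not part of the repository) to show a lookup that tolerates the missing key.

```python
# Two entries copied from config.py; the real dictionary has more datasets.
args = {
    'mr': {'train_size': 7108, 'test_size': 3554, 'valid_size': 711, 'num_classes': 2},
    'TREC': {'train_size': 5452, 'test_size': 500, 'valid_size': 545},
}


def summarize(dataset):
    """Hypothetical helper: report the validation ratio and class count."""
    cfg = args[dataset]
    return {
        'valid_ratio': cfg['valid_size'] / cfg['train_size'],
        # Not every entry defines num_classes, so fall back to None.
        'num_classes': cfg.get('num_classes'),
    }


for name in args:
    print(name, summarize(name))
```

Using `cfg.get('num_classes')` rather than `cfg['num_classes']` avoids a `KeyError` for datasets such as TREC, where the class count would have to come from the label file instead.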
dataset.py
from config import args
import joblib
import numpy as np
from torch_geometric.data import Data, DataLoader
import torch
import random
from tqdm import tqdm


class MyDataLoader(object):
    # Serves the dataset in chunks of batch_size; each chunk is wrapped in a
    # torch_geometric DataLoader that yields mini-batches of mini_batch_size.
    def __init__(self, dataset, batch_size, mini_batch_size=0):
        self.total = len(dataset)
        self.dataset = dataset
        self.batch_size = batch_size
        self.mini_batch_size = mini_batch_size
        if mini_batch_size == 0:
            # Default: the whole chunk is a single mini-batch
            self.mini_batch_size = self.batch_size

    def __getitem__(self, item):
        ceil = (item + 1) * self.batch_size
        sub_dataset = self.dataset[ceil - self.batch_size:ceil]
        if ceil >= self.total:
            # Last chunk of the pass: reshuffle in place for the next pass
            random.shuffle(self.dataset)
        return DataLoader(sub_dataset, batch_size=self.mini_batch_size)

    def __len__(self):
        if self.total == 0:
            return 0
        # Ceiling division: number of chunks needed to cover the dataset
        return (self.total - 1) // self.batch_size + 1
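The index arithmetic of `MyDataLoader` can be checked without torch-geometric. The sketch below mirrors the formulas from `__len__` and `__getitem__` on a plain list. The helper names `num_batches` and `batch_bounds` are made up for this illustration.

```python
def num_batches(total, batch_size):
    """Ceiling division, the same formula as MyDataLoader.__len__."""
    if total == 0:
        return 0
    return (total - 1) // batch_size + 1


def batch_bounds(item, batch_size):
    """Slice bounds that MyDataLoader.__getitem__ uses for chunk `item`."""
    ceil = (item + 1) * batch_size
    return ceil - batch_size, ceil


data = list(range(10))
batch_size = 4
batches = [data[slice(*batch_bounds(i, batch_size))]
           for i in range(num_batches(len(data), batch_size))]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Because `__getitem__` never raises `IndexError`, a bare `for chunk in loader` would not stop on its own. Looping over `range(len(loader))` and indexing, as in the sketch, is the safe way to walk through the chunks.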
def split_train_valid_test(data, train_size, valid_part=0.1):
train_data