The code is based on the Text-GCN authors' paper and on a Bible-based implementation of it. The original author information is preserved; for the complete code, please see the original repositories:
https://github.com/plkmo/Bible_Text_GCN
https://github.com/yao8839836/text_gcn
This records my process of learning GCN from scratch. I didn't know many of the Python idioms involved, so I annotated parts of the code, hoping it helps beginners like me.
Common questions when getting started:
- What parameters are being trained? In AXW, A is the graph structure (the adjacency matrix), X is the feature matrix holding the node features (analogous to word embeddings), and W is the layer's weight matrix, i.e. the values that gradient descent updates. In PyTorch you define a weight tensor; the code below initializes it with the normal_ method, though many lectures treat this initialization as optional. A minimal sketch of such a layer follows this list.
- In Kipf's paper, X is taken to be the identity matrix. In this text-classification task, TF-IDF and PMI values fill the adjacency matrix, folding edge information into the learning. How well this fusion works is mathematically rather abstract, but in my own runs 300 to 500 epochs already gave a decent result, with accuracy above 80%, which is pretty solid. I don't claim to understand the finer details. (A sketch of the PMI edge computation appears after the nCr helper in the code below.)
- What's the difference between tokenizing Chinese and English? In Chinese, meaning lives in multi-character words rather than in space-separated tokens, so the text must be segmented first. The code below uses jieba's cut function inside a small helper (cut_words). jieba is great.
- Where does GCN come from, and how does it work? See the lecture video made by TA 姜成瀚 for Prof. 李宏毅 (Hung-yi Lee)'s course at NTU; it is on Bilibili and explains things quite clearly.
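To make the first point concrete, here is a minimal sketch of a single GCN layer in PyTorch. This is my own illustration, not the author's exact model, and the class and variable names are made up:

import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    # one graph-convolution layer: out = ReLU(A_hat @ X @ W)
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # W is the trainable parameter that gradient descent updates
        self.weight = nn.Parameter(torch.empty(in_dim, out_dim))
        self.weight.data.normal_(0, 0.1)  # the normal_ initialization mentioned above

    def forward(self, A_hat, X):
        # A_hat: (normalized) adjacency matrix, X: node feature matrix
        return torch.relu(A_hat @ X @ self.weight)

Stacking two such layers and passing the output through a softmax gives the kind of classifier Text-GCN uses.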
# Key parts of the code below are annotated. If you find this helpful, feel free to give me a free like to brighten my day.
# -*- coding: utf-8 -*-
"""
Created on Thu May 9 10:28:24 2019
@author: WT
"""
import os
import pickle
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
import nltk
import numpy as np
import networkx as nx
from collections import OrderedDict
from itertools import combinations
import math
from tqdm import tqdm
import logging
import jieba
logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s',
                    datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)  # configure logging output
logger = logging.getLogger(__file__)
# load stop words into a set for O(1) membership tests (file assumed UTF-8)
with open('stop_words.txt', encoding='utf-8') as f:
    stop = set(line.strip() for line in f)
def cut_words(text):
    # segment Chinese text with jieba and drop stop words;
    # build a new list rather than removing items from the list
    # while iterating over it, which would silently skip elements
    text = str(text)
    return [word for word in jieba.cut(text) if word not in stop]
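# quick sanity check (illustrative; actual output depends on jieba's dictionary
# and on your stop_words.txt): cut_words("我来到北京清华大学") segments to roughly
# ["我", "来到", "北京", "清华大学"], after which stop words such as "我" are dropped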
def load_pickle(filename):
    # read a pickled object from ./data/
    completeName = os.path.join("./data/", filename)
    with open(completeName, 'rb') as pkl_file:
        data = pickle.load(pkl_file)
    return data
def save_as_pickle(filename, data):
    # write an object to ./data/ as a pickle file
    completeName = os.path.join("./data/", filename)
    with open(completeName, 'wb') as output:
        pickle.dump(data, output)
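# usage sketch (the filename is illustrative, not from the original script):
#   save_as_pickle("text_graph.pkl", G)
#   G = load_pickle("text_graph.pkl")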
def nCr(n, r):
    f = math.factorial  # factorial; math.comb(n, r) does the same in Python 3.8+
    return int(f(n) / (f(r) * f(n - r)))
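# The PMI mentioned in the notes above fills the word-word entries of the
# adjacency matrix. Below is a hedged sketch of that computation over sliding
# windows; the function and parameter names (build_pmi_edges, window_size) are
# my own, and the original repo's version differs in detail. Document-word
# entries are weighted by TF-IDF instead (e.g. via the TfidfVectorizer
# imported above).
def build_pmi_edges(docs, window_size=10):
    """docs: list of token lists. Returns {(w1, w2): pmi} for positive PMI only."""
    from collections import Counter
    word_count = Counter()   # number of windows containing word i
    pair_count = Counter()   # number of windows containing both i and j
    n_windows = 0
    for tokens in docs:
        for start in range(max(1, len(tokens) - window_size + 1)):
            window = set(tokens[start:start + window_size])
            n_windows += 1
            word_count.update(window)
            pair_count.update(combinations(sorted(window), 2))
    edges = {}
    for (w1, w2), n_ij in pair_count.items():
        # PMI(i, j) = log( p(i, j) / (p(i) * p(j)) )
        pmi = math.log((n_ij / n_windows) /
                       ((word_count[w1] / n_windows) * (word_count[w2] / n_windows)))
        if pmi > 0:  # Text-GCN keeps only positive-PMI word pairs as edges
            edges[(w1, w2)] = pmi
    return edges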
# remove meaningless words and punctuation
def filter_tokens(tokens, stopwords):
tokens1 = []
for token in tokens:
if (token not in stopwords) and (token not in [".", ",", ";", "&", "'s", ":", "?", "!", "(", ")", \
"'