GCN图卷积神经网络-中文文本分类 pytorch

本文介绍了使用PyTorch实现的GCN(图卷积网络)在中文文本分类中的应用。内容包括新手入门的常见问题、网络模型代码,并探讨了训练参数、图结构、特征融合等概念。通过300-500次训练,模型能获得超过80%的准确率。此外,还讨论了中文分词的差异以及GCN的基本原理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

代码基于Text-GCN作者论文和基于Bible的实践,保留原作者信息,具体代码请移步原作者仓库
https://github.com/plkmo/Bible_Text_GCN
https://github.com/yao8839836/text_gcn

从无到有学习GCN的一个过程,很多python的用法都不懂,打上一部分注释,希望给我一样的小白一点帮助

新手入门容易遇到的问题

  1. 训练的是什么参数? AXW里面的A是图结构,X是feature,是节点的特征(类比word embedding),W是这一层的权重,也就是随着梯度下降更新的值,在pytorch里面定义一个weight,下文代码用了normal_方法对这个权重做了初始化,但在很多讲座介绍里这个是可有可无的
  2. kipf的论文里面X是取的单位矩阵,在这个文本分类任务里使用了TF-IDF和PMI作为对角矩阵的值,将边的信息融入了学习中,至于融合的怎么样实际上在数学上比较抽象,但我个人训练后,300-500epoch就能取得一个不错的结果,准确率可以达到80%以上还是很顶的,具体细节、更难的东西俺也不懂
  3. 中文和英文分词的区别?中文是词语有含义而不是英语的单词,所以需要先进行分词。下面代码用了jieba cut分词函数,写了一个小jieba函数,jieba真香
  4. GCN怎么来的,什么原理?参考NTU李宏毅老师姜成瀚助教做的影片,看那个视频B站有,讲的挺清楚的
# 下面的代码关键部分打了注释,觉得有帮助的同学不妨给我点个免费的赞让我开心一下
# -*- coding: utf-8 -*-
"""
Created on Thu May  9 10:28:24 2019

@author: WT
"""
import os
import pickle
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
import nltk
import numpy as np
import networkx as nx
from collections import OrderedDict
from itertools import combinations
import math
from tqdm import tqdm
import logging
import jieba

logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', \
                    datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)  # 配置输出日志
logger = logging.getLogger(__file__)

stop = [line.strip() for line in open('stop_words.txt').readlines()]


def cut_words(text):
    text = str(text)
    text = list(jieba.cut(text))
    for word in text:
        if word in stop:
            text.remove(word)
    return text

def load_pickle(filename):
    completeName = os.path.join("./data/", \
                                filename)
    with open(completeName, 'rb') as pkl_file:
        data = pickle.load(pkl_file)
    return data


def save_as_pickle(filename, data):
    completeName = os.path.join("./data/", \
                                filename)
    with open(completeName, 'wb') as output:
        pickle.dump(data, output)


def nCr(n, r):
    f = math.factorial  # 阶乘
    return int(f(n) / (f(r) * f(n - r)))


# 移除无意义词汇和符号
def filter_tokens(tokens, stopwords):
    tokens1 = []
    for token in tokens:
        if (token not in stopwords) and (token not in [".", ",", ";", "&", "'s", ":", "?", "!", "(", ")", \
                                                       "'
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值