Source code (GitHub):
https://github.com/dennybritz/cnn-text-classification-tf
My fork: https://github.com/tddfly/cnn-text-classification-tf
Reference post: https://blog.youkuaiyun.com/github_38414650/article/details/74019595
Dataset: https://github.com/cystanford/text_classification
The dataset contains three folders: a training set (four categories), a test set, and a stopword list.
The project consists of four files:
data_helpers.py - data preprocessing
train.py - the training loop for the neural network
text_cnn.py - the convolutional neural network architecture
eval.py - prediction and evaluation
The data_helpers.py file
import numpy as np
import re
import os
import jieba
def clean_str(string):
    """
    Data preprocessing: cleanup and tokenization-style spacing for English text.
    """
    # Replace every character that is NOT a letter, digit, parenthesis, comma,
    # exclamation mark, question mark, quote or backtick with a space.
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    # Note: without the leading ^ the character class is no longer negated, and the
    # pattern would instead replace all of the listed characters with spaces.
    string = re.sub(r"\'s", " \'s", string)    # "family's" -> "family 's"
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)  # "isn't" -> "is n't"
    string = re.sub(r"\'re", " \'re", string)  # "you're" -> "you 're"
    string = re.sub(r"\'d", " \'d", string)    # "i'd" -> "i 'd"
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)    # \s matches whitespace; collapse runs of 2+ into one space
    return string.strip().lower()
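As a quick sanity check, here is what clean_str does to a raw sentence (the input below is a made-up example, not from the dataset):

# Hypothetical example of clean_str in action:
print(clean_str("The Family's dog isn't here!"))
# -> "the family 's dog is n't here !"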
def load_data_and_labels(positive_data_file, negative_data_file):
    """
    Loads MR polarity data from files, splits the data into words and generates labels.
    Returns split sentences and labels.
    """
    # Read the positive and negative samples from file
    positive_examples = list(open(positive_data_file, "r", encoding='utf-8').readlines())
    positive_examples = [s.strip() for s in positive_examples]  # strip trailing newlines
    negative_examples = list(open(negative_data_file, "r", encoding='utf-8').readlines())
    negative_examples = [s.strip() for s in negative_examples]
    # Split by words
    # Concatenating two lists with + builds a new list (similar to extend,
    # except that extend adds the elements to the original list in place).
    x_text = positive_examples + negative_examples
    x_text = [clean_str(sent) for sent in x_text]
    # Generate labels
    positive_labels = [[0, 1] for _ in positive_examples]
    negative_labels = [[1, 0] for _ in negative_examples]
    # np.concatenate joins arrays; its first argument must be a tuple or list of arrays
    y = np.concatenate([positive_labels, negative_labels], 0)
    return [x_text, y]
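A minimal usage sketch, assuming the rt-polarity.pos / rt-polarity.neg files shipped with the original repository (one sentence per line; adjust the paths to where they live on disk):

# Hypothetical usage with the MR polarity data from the dennybritz repo:
x_text, y = load_data_and_labels('rt-polarity.pos', 'rt-polarity.neg')
print(len(x_text))  # number of sentences (positive + negative)
print(y.shape)      # (len(x_text), 2) -- one one-hot label per sentence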
# Called from train.py as:
# batches = data_helpers.batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
# where FLAGS.batch_size = 64 and FLAGS.num_epochs = 200 by default.
def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """
    Generates a batch iterator over the dataset.
    `data` holds (feature, label) pairs; one batch is yielded per step.
    """
    data = np.array(data)  # [(text0, label0), (text1, label1), ...]
    # Note: data is a list whose elements all have the same type, so it is stored
    # as an np.array, which saves memory.
    # type(data) returns the type of data itself; data.dtype returns the type of
    # the array's elements.
    data_size = len(data)
    # Number of batches needed to pass over the whole dataset once per epoch
    num_batches_per_epoch = int((len(data) - 1) / batch_size) + 1
    for epoch in range(num_epochs):
        # Shuffle the data at each epoch
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))  # shuffled indices
            # Note the difference between shuffle and permutation: both reorder an
            # array, but np.random.shuffle operates on the array in place and returns
            # nothing, while np.random.permutation leaves the original array untouched
            # and returns a new, shuffled copy.
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            # Compute the start and end position of this batch
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index:end_index]
            # yield hands back one batch per iteration of the consuming for loop,
            # so memory use stays constant.
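To make the generator concrete, here is a sketch of the consuming side in train.py (variable names follow the original script):

# Sketch of how train.py consumes batch_iter:
batches = batch_iter(list(zip(x_train, y_train)), 64, 200)
for batch in batches:
    x_batch, y_batch = zip(*batch)  # unzip one batch back into texts and labels
    # ... run one training step on x_batch / y_batch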
# ============ separator ============
# Return format: x_corpus is a list whose elements are text strings;
# y_label is a list of the corresponding one-hot labels,
# e.g. [[1,0,0,0], [0,1,0,0]] holds two labels.
def load_data_and_labels_3_27(corpus_path, stop_list_Path):
    x_corpus = []
    y_label = []
    # Load the stopword list once, instead of re-reading the file for every document
    stopwords = stopwordsList(stop_list_Path)
    cate_dir = os.listdir(corpus_path)  # the category sub-directories:
    # [女性 (women), 体育 (sports), 文学 (literature), 校园 (campus)]
    for cate in cate_dir:
        if cate == '女性':
            label = [1, 0, 0, 0]
        elif cate == '体育':
            label = [0, 1, 0, 0]
        elif cate == '文学':
            label = [0, 0, 1, 0]
        else:
            label = [0, 0, 0, 1]
        cate_complete_dir = os.path.join(corpus_path, cate)  # full path of this category
        file_dir = os.listdir(cate_complete_dir)  # the files under this category
        for file in file_dir:
            file_complete_dir = os.path.join(cate_complete_dir, file)  # full path of one file
            content = readfile(file_complete_dir)  # returns the text of one document
            # Clean the text: drop newlines and surrounding whitespace
            content = content.replace("\n", '').strip()
            content_seg = jieba.cut(content)  # Chinese word segmentation with jieba
            # Keep only words that are not stopwords, separated by single spaces
            outstr = ''
            for word in content_seg:
                if word not in stopwords:
                    if word != '\t':
                        outstr += word
                        outstr += " "
            text = outstr.strip()  # the words are already space-separated
            x_corpus.append(text)
            y_label.append(label)
    return [x_corpus, y_label]
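A usage sketch for the four-category corpus; the directory layout below is an assumption, so point both paths at your local copy of the text_classification data:

# Hypothetical call; adjust the paths to your local dataset layout:
x_corpus, y_label = load_data_and_labels_3_27(
    'text_classification/train',
    'text_classification/stop/stopword.txt')
print(len(x_corpus))  # number of documents across the four categories
print(y_label[0])     # one-hot label of the first document, e.g. [1, 0, 0, 0]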
def stopwordsList(stop_list_Path):
    # Build the stopword list: one stopword per line of the file
    f = open(stop_list_Path, 'r', encoding='utf-8')
    stopwords = [line.strip() for line in f.readlines()]
    f.close()
    return stopwords
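readfile is called inside load_data_and_labels_3_27 but defined later in the file; a minimal sketch, assuming each document is a plain text file (the encoding is a guess and may need adjusting for this corpus):

# Hypothetical sketch of readfile; the real definition appears later in the file.
def readfile(path):
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        return f.read()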