1. 背景说明
在Pytorch 实现情感分类版本基础上进行优化实现,使用到文本文件reviews.txt和标签文件labels.txt两个数据文件。
1.1 数据集预览
1. reviews.txt
此文件是一个包含25001条句子的长文本。
![]()
print(len(text))
# 33678267 个字符
2. labels.txt

3. 清理标点符号
from string import punctuation
# 标点符号
print(punctuation)
# !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~
clean_text = ''.join([char for char in text if char not in punctuation])
1.2 整体架构
1. 读取原始数据集(文本集);
2. 文本预处理
1. 清理无用的标点符号;
2. 根据换行符\n分隔;
3. 单词 --> 索引转换;
4. 标签 --> 0、1转换;
5. 清理文本(过短以及过长的样本);
6. 将单词映射为整型;
7. 设定统一的文本长度(对整个文本数据中的每条评论进行填充或截断)。
3. 特征工程
1. array --> tensor;
2. 将数据集分割为train、val、test三部分;
3. 通过DataLoader分批加载数据。
4. 定义网络模型结构
5. 定义超参数
6. 定义训练函数(训练+验证)
7. 定义测试函数
8. 定义预测函数
2. 数据加载与探索
2.1 读取原始数据集
1. 获取文本
def get_text():
with open('../data/reviews.txt', 'r') as f:
text = f.read()
# 清理标点符号
clean_text = ''.join([char for char in text if char not in punctuation])
# 按行切分句子
clean_text = clean_text.split('\n')
# print(len(clean_text)) # 25001
return clean_text
2. 获取label
def get_label():
with open('../data/labels.txt', 'r') as f:
labels = f.read()
# 按行切分标签
labels = labels.split('\n')
return labels
2.2 文本预处理
1. 单词 --> 索引转换
# 构建"单词:索引"的字典
def get_word2index_dict(text):
# 获取评论中的不同的单词,建议使用列表推导,比for循环快4~10倍
words = [word.lower() for sentence in text for word in sentence.split(

本文档详细介绍了使用PyTorch进行情感分类的进阶过程,包括数据集预览、数据加载与探索(读取、预处理、分割、分批加载)、模型构建、模型训练(训练pipeline、验证、测试)以及模型在线预测的实现。内容涵盖了从数据预处理到模型调用的完整流程。
最低0.47元/天 解锁文章
650

被折叠的 条评论
为什么被折叠?



