总的数据流程:
1.数据集:分为数据和标签
2.处理后送入build_dataset构建,再经过迭代DatasetIterater处理得到批次数据,送入train中将每一个批次索引转化为词向量形式训练
!!送入train之前需要先构建词汇表以及对应的词向量嵌入矩阵,以便前向传播的时候将索引转化为向量处理!!
需要注意的是:数据的流动以及形式转换的整个过程
# coding: UTF-8
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module#动态模块导入工具:import_module 是 Python 标准库 importlib 模块中的一个函数,用于在运行时动态导入模块。与常规的 import 语句相比,import_module 允许你在程序运行时根据需要动态地导入模块。
import argparse#命令行参数解析工具。
parser = argparse.ArgumentParser(description='Chinese Text Classification')#定义命令行参数解析器:用于从命令行读取模型名称、词嵌入类型和分词方式
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')#--model:选择使用的模型,必须提供。
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')#--embedding:选择词嵌入类型,默认为预训练词嵌入。
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')#--word:选择分词方式,默认为按字符分词。
args = parser.parse_args()
if __name__ == '__main__':#主程序入口:检查是否作为主程序运行,确保以下代码仅在直接运行脚本时执行。
dataset = 'THUCNews' # 数据集
# 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
embedding = 'embedding_SougouNews.npz'#设置词嵌入文件路径:根据命令行参数选择词嵌入类型。如果选择随机初始化,则设置 embedding 为 random。
if args.embedding == 'random':
embedding = 'random'
model_name = args.model # 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer 获取模型名称:从命令行参数中获取模型名称。
if model_name == 'FastText':#根据模型名称导入相应的工具函数:如果是 FastText 模型,使用 utils_fasttext 中的工具函数,并强制将词嵌入设置为 random。否则,使用 utils 中的工具函数。
from utils_fasttext import build_dataset, build_iterator, get_time_dif
embedding = 'random'
else:
from utils import build_dataset, build_iterator, get_time_dif
x = import_module('models.' + model_name)#动态导入模型模块:根据模型名称动态导入相应的模型模块。
config = x.Config(dataset, embedding)#初始化模型配置:使用模型模块中的 Config 类初始化配置。
np.random.seed(1)#设置随机种子:确保每次运行的结果一致,设置 NumPy 和 Torch 的随机种子,并使得 CuDNN 后端的计算是确定性的。
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
torch.backends.cudnn.deterministic = True # 保证每次结果一样
start_time = time.time()#记录开始时间:用于计算数据加载时间。
print("Loading data...")
vocab, train_data, dev_data, test_data = build_dataset(config, args.word)#加载数据集:使用 build_dataset 函数加载词汇表、训练集、验证集和测试集。
train_iter = build_iterator(train_data, config)#构建数据迭代器:使用 build_iterator 函数为训练集、验证集和测试集构建数据迭代器。
dev_iter = build_iterator(dev_data, config)
test_iter = build_iterator(test_data, config)
time_dif = get_time_dif(start_time)
print("Time usage:", time_dif)#计算并打印数据加载时间。
# train:初始化模型并训练
config.n_vocab = len(vocab)#设置词汇表大小:将词汇表大小设置到配置中。
model = x.Model(config).to(config.device)#初始化模型:根据配置初始化模型,并将模型移动到指定设备(CPU 或 GPU)。
if model_name != 'Transformer':#初始化模型权重:如果模型不是 Transformer,则使用 init_network 函数初始化模型权重。
init_network(model)
print(model.parameters)#打印模型参数:打印模型的参数信息。
train(config, model, train_iter, dev_iter, test_iter)#训练模型:调用 train 函数进行模型训练。
中文文本分类流程
512

被折叠的 条评论
为什么被折叠?



