文本分类(一) | (9) 项目组织结构

项目Github地址

在学习某个深度学习框架时,掌握其基本知识和接口固然重要,但如何合理组织代码,使得代码具有良好的可读性和可扩展性也必不可少。本文不会深入讲解过多知识性的东西,更多的则是传授一些经验,关于如何使得自己的程序更pythonic,更符合pytorch的设计理念。这些内容可能有些争议,因其受我个人喜好和coding风格影响较大,你可以将这部分当成是一种参考或提议,而不是作为必须遵循的准则。归根到底,都是希望你能以一种更为合理的方式组织自己的程序。

在做深度学习实验或项目时,为了得到最优的模型结果,中间往往需要很多次的尝试和修改(也就是所谓地调参)。根据我的个人经验,在从事大多数深度学习研究时,程序都需要实现以下几个功能:

1)模型定义

2)数据处理和加载

3)训练模型(Train&Validate)

4)训练过程的可视化或相关指标的计算

5)测试/预测(Test/Inference)

另外程序还应该满足以下几个要求:

1)模型需具有高度可配置性,便于修改参数、修改模型,反复实验

2)代码应具有良好的组织结构,使人一目了然

3)代码应具有良好的说明,使其他人能够理解

接下来我将应用这些内容,并结合实际的例子,来讲解如何合理组织我们的文本分类项目。

目录

1. 文件组织结构

2. 数据预处理和加载

3. 模型定义

4. 配置文件

5. main.py

6. 使用方式

7. 实验结果与分析

8. 预测与网页Demo

9. 程序所依赖的环境

10. 总结


1. 文件组织结构

首先来看程序文件的组织结构:

其中:

1)checkpoints/: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练

2)data/:数据相关操作,包括数据预处理、dataset实现等

3)models/:模型定义,可以有多个模型,例如上面的FastText、TextCNN等,一个模型对应一个文件

4)config.py:配置文件,所有可配置的变量都集中在此,并提供默认值

5)main.py:主文件,训练和预测程序的入口,可通过不同的命令来指定不同的操作和参数

6)load_word_vector.py:定义加载预训练词向量的函数

7).idea/、static/、templates/、settings.py、urls.py、views.py、wsgi.py、db.sqlite3、manage.py:网页Demo运行支撑文件。

8)requirements.txt:程序依赖的第三方库

9)README.pdf:项目说明文档

 

2. 数据预处理和加载

数据的相关预处理函数主要保存在data/dataset.py中。关于数据加载的相关操作,其基本原理就是使用Dataset进行数据集的封装,再使用Dataloader实现数据并行加载。

具体的预处理过程和实现细节在文本分类专栏的第(2)篇博客中已经详细介绍了。

使用时,我们可通过dataloader加载数据:

#读取之前预处理过程 保存的处理好的训练集、验证集和测试集
    X_train = torch.load('./data/X_train.pt')
    y_train = torch.load('./data/y_train.pt')
    X_val = torch.load('./data/X_val.pt')
    y_val = torch.load('./data/y_val.pt')
    X_test = torch.load('./data/X_test.pt')
    y_test = torch.load('./data/y_test.pt')
    
    #封装成DataSet
    trainset = Data.TensorDataset(X_train,y_train)
    valset = Data.TensorDataset(X_val,y_val)
    testset = Data.TensorDataset(X_test,y_test)

    #使用DataLoader并行加载数据
    train_iter = Data.DataLoader(trainset,opt.batch_size,shuffle=True,num_workers=opt.num_workers)
    val_iter = Data.DataLoader(valset,opt.batch_size)
    test_iter = Data.DataLoader(testset,opt.batch_size)
  • 加载预训练词向量

load_word_vector.py定义了加载预训练词向量的函数:

def read_word_vector(path): #path为 下载的预训练词向量 解压后的文件所在的路径
    #读取预训练词向量
    with open(path, 'r') as f:
        words = set()  # 定义一个words集合
        word_to_vec_map = {}  # 定义词到向量的映射字典
        for line in f:  #跳过文件的第一行 
            break

        for line in f:  # 遍历f中的每一行
            line = line.strip().split()  # 去掉首尾空格,每一行以空格切分  返回一个列表  第一项为单词 其余为单词的嵌入表示
            curr_word = line[0]  # 取出单词
            words.add(curr_word)  # 加到集合/词典中
            # 定义词到其嵌入表示的映射字典
            word_to_vec_map[curr_word] = np.array(line[1:], dtype=np.float64)

    return words, word_to_vec_map



def load_pretrained_embedding(word2index, word2vector):#word2index是构建的词典(单词到索引的映射),word2vector是预训练词向量(单词到词向量的映射)
  
    embed = torch.zeros(len(word2index), opt.embed_size) # 初始化词嵌入矩阵为0
    oov_count = 0 # 找不到预训练词向量的词典中单词的个数

    for word, index in word2index.items(): #遍历词典中的每个单词 及其在词典中的索引
        try: #如果单词有对应的预训练词向量 则用预训练词向量对词嵌入矩阵的对应行进行赋值
            embed[index, :] = torch.from_numpy(word2vector[word])
        except KeyError:
            oov_count += 1

    if oov_count > 0:
        print("There are %d oov words."%oov_count)
    return embed #返回词嵌入矩阵

在主程序main.py 的train函数中调用:

 

  #加载预训练词向量
    if opt.use_pretrained_word_vector:
        words,word2vec = read_word_vector(opt.word_vector_path) #opt.word_vector_path为下载的预训练词向量 解压后的文件所在的路径
        print("预训练词向量读取完毕!")
        #读取之前预处理过程保存的词典(词到索引的映射)
        with open('./data/word2index.json') as f:
            word2index = json.load(f)

        model.embedding.weight.data.copy_(load_pretrained_embedding(word2index, word2vec)) #使用加载完预训练词向量的词嵌入矩阵 对embdding层的词嵌入矩阵赋值
        print("预训练词向量加载完毕!")
        if opt.frozen: #冻结还是finetuning
            model.embedding.weight.requires_grad = False

3. 模型定义

各个模型的定义主要保存在models/目录下,其中BasicModule是对nn.Module的简易封装,提供快速加载(可以处理GPU训练、CPU加载的情况)和保存模型(提供多GPU训练时的模型保存方法)的接口,其他模型都继承自BasicModule。

class BasicModule(nn.Module):
   '''
   封
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值