Pytorch学习笔记-第九章

Pytorch学习笔记-第九章RNN诗人


记录一下个人学习和使用Pytorch中的一些问题。强烈推荐 《深度学习框架PyTorch:入门与实战》.写的非常好而且作者也十分用心,大家都可以看一看,本文为学习第九章RNN诗人的学习笔记。

主要分析实现代码里面main,data,model,utils这4个代码文件完成整个项目模型结构定义,训练及生成,还有输出展示的整个过程。

utils

这个文件没啥好说的了,就是封装了一个visdom对象,再多加了方便使用的一次显示多个点以及网格显示多个图片的函数(然后这个项目用不到图片显示,应该就是前几个项目里打包过来的)。

data

原始数据是JSON结构化的数据格式,需要读入整理成可以被网络接受的形式。每首诗变成大小为125的数组,不足则补齐,超过则截断。
在这里插入图片描述
当然为了后续的embedding,以及预测分类操作,我们需要给这些字一个序号,所以输入数据的一首诗最后会成为这样的形式。
在这里插入图片描述
data.py文件中有3个函数,get_data是主要函数负责调用_parseRawData从JSON文件出解码以及预处理诗词数据(去掉无关数字,标点等),然后调用pad_sequences把数据处理成统一大小,最后在自身内部完成正方向索引字典和数据的保存与返回。其中填充函数功能比较强大,可以在很多场景下使用。

def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='pre',
                  value=0.):
    """
    code from keras
    Pads each sequence to the same length (length of the longest sequence).
    If maxlen is provided, any sequence longer
    than maxlen is truncated to maxlen.
    Truncation happens off either the beginning (default) or
    the end of the sequence.
    Supports post-padding and pre-padding (default).
    Arguments:
        sequences: list of lists where each element is a sequence
        maxlen: int, maximum length
        dtype: type to cast the resulting sequence.
        padding: 'pre' or 'post', pad either before or after each sequence.
        truncating: 'pre' or 'post', remove values from sequences larger than
            maxlen either in the beginning or in the end of the sequence
        value: float, value to pad the sequences to the desired value.
    Returns:
        x: numpy array with dimensions (number_of_sequences, maxlen)
    Raises:
        ValueError: in case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    lengths = []
    for x in sequences:
        if not hasattr(x, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))
        lengths.append(len(x))

    num_samples = len(sequences)
    #如果没有设置好的填充函数,则统一到数据中最长值
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值