记录一下个人学习和使用Pytorch中的一些问题。强烈推荐 《深度学习框架PyTorch:入门与实战》.写的非常好而且作者也十分用心,大家都可以看一看,本文为学习第九章RNN诗人的学习笔记。
主要分析实现代码里面main,data,model,utils这4个代码文件完成整个项目模型结构定义,训练及生成,还有输出展示的整个过程。
utils
这个文件没啥好说的了,就是封装了一个visdom对象,再多加了方便使用的一次显示多个点以及网格显示多个图片的函数(然后这个项目用不到图片显示,应该就是前几个项目里打包过来的)。
data
原始数据是JSON结构化的数据格式,需要读入整理成可以被网络接受的形式。每首诗变成大小为125的数组,不足则补齐,超过则截断。
当然为了后续的embedding,以及预测分类操作,我们需要给这些字一个序号,所以输入数据的一首诗最后会成为这样的形式。
data.py文件中有3个函数,get_data
是主要函数负责调用_parseRawData
从JSON文件出解码以及预处理诗词数据(去掉无关数字,标点等),然后调用pad_sequences
把数据处理成统一大小,最后在自身内部完成正方向索引字典和数据的保存与返回。其中填充函数功能比较强大,可以在很多场景下使用。
def pad_sequences(sequences,
maxlen=None,
dtype='int32',
padding='pre',
truncating='pre',
value=0.):
"""
code from keras
Pads each sequence to the same length (length of the longest sequence).
If maxlen is provided, any sequence longer
than maxlen is truncated to maxlen.
Truncation happens off either the beginning (default) or
the end of the sequence.
Supports post-padding and pre-padding (default).
Arguments:
sequences: list of lists where each element is a sequence
maxlen: int, maximum length
dtype: type to cast the resulting sequence.
padding: 'pre' or 'post', pad either before or after each sequence.
truncating: 'pre' or 'post', remove values from sequences larger than
maxlen either in the beginning or in the end of the sequence
value: float, value to pad the sequences to the desired value.
Returns:
x: numpy array with dimensions (number_of_sequences, maxlen)
Raises:
ValueError: in case of invalid values for `truncating` or `padding`,
or in case of invalid shape for a `sequences` entry.
"""
if not hasattr(sequences, '__len__'):
raise ValueError('`sequences` must be iterable.')
lengths = []
for x in sequences:
if not hasattr(x, '__len__'):
raise ValueError('`sequences` must be a list of iterables. '
'Found non-iterable: ' + str(x))
lengths.append(len(x))
num_samples = len(sequences)
#如果没有设置好的填充函数,则统一到数据中最长值
if maxlen is None:
maxlen = np.max(lengths)
# take the sample shape from the first non empty sequence
# checking for consistency in the main loop below.
sample_shape = tuple()