2021SC@SDUSC
上周对源码中关键包——torch.utils.data包和torch.nn.utils.rnn.pad_packed_sequence包的功能和MyDataset类的__ init __ ()函数进行了分析,这周接着对__ getitem __ ()和__ len __ ()函数进行分析。
一、my_dataloader
__ getitem __ ()函数
def __getitem__(self, index):
x, trg= self.f[index]
x = x.lower()
#分词处理
x = nltk.tokenize.word_tokenize(x)
代码分段解析:
此段代码将在__ init __ ()函数中存放的silver label数据取出到x中并对其进行分词处理操作。
for i in range(len(x)):
x[i] = self.vocab(x[i])
x.append(self.vocab('<end>'))
x = [self.vocab('<start>')] + x
代码分段解析:
此段代码将silver label在数据集vocab_kp20k.npy中出现过的词存入x。
if len(x)>512:
x = x[:512]
while len(x) < 512:
x.append(self.vocab('<pad>'))
#构造张量
src = torch.Tensor(x)
代码分段解析:
此段对x做进一步处理,将x长度控制在512以内并将x构造为张量赋值给src。
其中出现的torch.tensor()函数,对tensor和此函数做了学习和理解,整理如下。
tensor
Tensor实际上就是一个多维数组(multidimensional array)。其目的是能够创造更高维度的矩阵、向量。如图所示。

将三维的张量用一个正方体来表示。

这样可以进一步生成更高维的张量。

通过图示和Python的实例能够想象Tensor的空间构造以及如何用Tensor的属性来构造Tensor。
torch.tensor()
torch.tensor(data, dtype=None, device=None, requires_grad=False)
其中data可以是list,tuple,NumPy,ndarray等其他类型。torch.tensor会从data中的数据部分做拷贝而不是直接引用,根据原始数据类型生成相应的torch.LongTensor torch.FloatTensor和torch.DoubleTensor。示例如下。

data变成了浮点型,tensor1.type()随之也变成相应的torch.FloatTensor。

由此可见,可以生成指定dtype的tensor。
x = trg
x = ','.join(x)
#分词处理
x = nltk.tokenize.word_tokenize(x.lower())
代码分段解析:
将第一个代码段中x, trg = self.f[index]存入trg的silvel label值再次赋值给x并做分词处理。
for i in range(len(x)):
x[i] = self.vocab(x[i])
x.append(self.vocab('<end>'))
x = [self.vocab('<start>')] + x
代码分段解析:
此段代码将silver label在数据集vocab_kp20k.npy中出现过的词存入x。
while len(x) < 30:
x.append(self.vocab('<pad>'))
trg = torch.Tensor(x)
return src, trg
代码分段解析:
对x做进一步处理,保证x长度<30后给x构造张量并赋值给trg,最后返回张量src、trg。
__ len __ ()函数
def __len__(self):
return len(self.f)
__ len __ ()函数比较简短,直接返回传入数据silver label的大小。
二、总结
本周分析了my_dataloader.py的__ getitem __ ()函数和__ len __ ()函数。在__ getitem __ ()函数中着重学习了tensor的含义以及torch.tensor()函数的用法,对张量有了进一步的理解。
至此,完成了对my_dataloader.py的分析,将数据集vocab_kp20k.npy和silver label载入模型Model.py,下一步将进行模型的训练。故下周起对Train.py展开分析。
本文详细解析了my_dataloader.py中的__getitem__()和__len__()函数,重点介绍了张量的概念、torch.tensor()的使用,以及数据预处理步骤,包括分词和长度控制。通过对vocab_kp20k.npy数据集的处理,为后续模型Model.py的训练做好准备。
546

被折叠的 条评论
为什么被折叠?



