张量基础与重塑操作
在PyTorch中,张量是构建一切神经网络运算的基石。理解张量的形状(Shape)是进行数据操作的第一步。当我们谈及“重塑”(reshape),通常指的是改变张量的维度布局而不改变其包含的数据本身。最常用的方法是使用reshape()方法或view()方法。例如,一个包含12个元素的一维张量可以被重塑为一个3行4列的二维张量,或者一个2x3x2的三维张量。关键在于新形状的总元素数必须与原张量保持一致,否则会引发错误。此外,view()要求张量在内存中是连续的,而reshape()会尽可能返回一个视图,否则将自动进行拷贝操作,使其更安全但可能稍慢。
高级索引与筛选技巧
当我们需要从张量中提取非连续或复杂的子集时,基础索引就显得力不从心,这时需要用到高级索引。PyTorch的高级索引功能强大且灵活,允许我们使用整数或布尔张量进行索引。例如,我们可以使用一个整数索引张量来选取特定下标的元素,如tensor[[0, 2, 3]]。布尔索引则更为强大,它允许我们通过一个布尔条件掩码来筛选数据,例如tensor[tensor > 0.5]会返回所有大于0.5的元素。高级索引返回的结果通常是原始数据的一个拷贝,而非视图,这一点与基础切片操作不同。
使用torch.gather进行定向数据收集
torch.gather是一个非常重要但略显复杂的操作,它允许我们沿着指定的维度,根据索引张量来收集数据。其核心思想是,提供一个与原张量形状基本一致的索引张量,该索引张量中的每个值指明了从输入张量的哪个位置获取数据。这在许多场景下非常有用,例如在强化学习中收集特定动作的Q值,或在自然语言处理中收集词袋模型的特征。gather操作的精髓在于理解其维度的映射关系,即输出在指定维度上的每个位置,其值都是由索引张量在该位置指定的下标所对应的输入值填充。
使用torch.scatter_及其反向操作
torch.scatter_是gather的逆操作,它将源张量中的数据按照索引张量指定的位置,散布到一个目标张量中。带下划线的版本表示原地操作。这个函数常用于将稀疏数据转换为稠密表示,或是在自定义的损失函数中发挥作用。例如,我们可以将一组值散布到一个全零张量的特定索引位置上。与gather类似,理解目标张量、索引张量和源张量在指定维度上的形状对应关系是正确使用该函数的关键。scatter_add_则是其变体,它将散布操作改为相加操作,即如果多个源值指向目标张量的同一位置,这些值会被求和,这在梯度累积等场景中非常实用。
结合einsum进行简洁表达
爱因斯坦求和约定(einsum)虽然不是PyTorch独有的功能,但它为张量操作提供了一种极其简洁和强大的表达方式。通过一个简单的字符串公式,torch.einsum可以表示复杂的张量乘法、转置、对角化、求和等操作,而无需调用多个独立的函数。例如,矩阵乘法可以简单地表示为torch.einsum(‘ij,jk->ik’, A, B)。它能有效减少中间变量的创建,并且在语义上更清晰地表达了计算的本质。熟练掌握einsum可以极大提升编写复杂张量运算代码的效率和可读性。
565

被折叠的 条评论
为什么被折叠?



