张量是 PyTorch 中最基本的数据结构,是进行所有计算的核心。可以将其理解为一个多维数组。这一节我们将深入探索张量的方方面面,从创建、属性到各种高级操作。
2.1 张量的创建和初始化
PyTorch 提供了极其丰富的张量创建方式,以适应不同的需求。
import torch
import numpy as np
# 1. 从数据直接创建
# torch.tensor() 是最主要的工厂函数,它会拷贝数据,因此是安全的
data = [[1, 2], [3, 4]]
tensor_from_data = torch.tensor(data)
print(f"From Python list:\n {tensor_from_data}\n")
# 2. 从 NumPy 数组创建
# torch.from_numpy() 创建的张量与 NumPy 数组共享内存,这意味着修改一方会影响另一方。这在需要高效数据交换时非常有用。
np_array = np.array(data)
tensor_from_numpy = torch.from_numpy(np_array)
print(f"From NumPy array (shares memory):\n {tensor_from_numpy}\n")
# 验证内存共享:修改 NumPy 数组,观察张量的变化
np_array[0, 0] = 100
print(f"Original NumPy array modified. Tensor view:\n {tensor_from_numpy}\n")
# 如果想避免共享内存,可以使用 .clone() 方法创建一个独立的副本
tensor_clone = torch.from_numpy(np_array).clone()
np_array[0, 0] = 200
print(f"NumPy array modified again. Cloned tensor is unaffected:\n {tensor_clone}\n")
# 3. 从其他张量创建 (保留属性)
# 这些方法可以重用现有张量的属性(如 shape, dtype, device),除非显式覆盖,这有助于保持代码的一致性。
x_data = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32)
x_ones = torch.ones_like(x_data) # 创建一个全为1的张量,形状和类型与 x_data 相同
print(f"Ones Tensor like x_data:\n {x_ones} \n")
x_rand = torch.rand_like(x_data, dtype=torch.float) # 覆盖 x_data 的数据类型为 float
print(f"Random Tensor like x_data (dtype overridden):\n {x_rand} \n")
# 4. 创建具有特定形状和值的张量
shape = (2, 3, 4)
zeros_tensor = torch.zeros(shape)
ones_tensor = torch.ones(shape)
rand_tensor = torch.rand(shape) # [0, 1) 上的均匀分布
randn_tensor = torch.randn(shape) # 标准正态分布 (均值为0,方差为1)
eye_tensor = torch.eye(3) # 3x3 的单位矩阵
full_tensor = torch.full(shape, 7) # 填充特定值
print(f"Zeros Tensor (shape {shape}):\n {zeros_tensor} \n")
print(f"Random Tensor (shape {shape}):\n {rand_tensor} \n")
print(f"Identity Matrix:\n {eye_tensor} \n")
# 5. 创建序列
# 这些函数在定义模型权重或生成坐标时非常有用
arange_tensor = torch.arange(0, 10, step=2) # 类似于 Python 的 range
linspace_tensor = torch.linspace(0, 10, steps=5) # 在 [0, 10] 之间生成5个等间距的点
logspace_tensor = torch.logspace(-1, 1, steps=5) # 在 10^-1 到 10^1 之间生成5个对数等间距的点
print(f"Arange Tensor: {arange_tensor}")
print(f"Linspace Tensor: {linspace_tensor}")
print(f"Logspace Tensor: {logspace_tensor}")
2.2 张量的属性
每个张量都有三个关键属性,它们共同定义了张量的状态:shape
, dtype
, 和 device
。
tensor = torch.randn(3, 4)
print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")
# 张量的数据类型非常重要,因为它决定了内存消耗和计算精度
# 常见的 dtype 包括:torch.float32 (默认浮点), torch.float64, torch.int64 (默认整数), torch.bool
# 我们可以使用 .to() 方法来改变张量的 device 或 dtype
# 将张量移动到 GPU (如果可用)
if torch.cuda.is_available():
tensor_gpu = tensor.to('cuda')
print(f"\nTensor moved to: {tensor_gpu.device}")
# 也可以同时改变 dtype
tensor_gpu_fp16 = tensor_gpu.to(dtype=torch.float16)
print(f"Tensor on GPU with new dtype: {tensor_gpu_fp16.dtype}")
# 将张量转换回 CPU
# tensor_cpu = tensor_gpu.cpu()
2.3 张量操作:索引、切片、连接与变形
PyTorch 提供了丰富且强大的操作来处理张量,这些操作是构建复杂模型的基础。
索引与切片
这与 NumPy 的索引非常相似,提供了灵活的数据访问方式。
tensor = torch.arange(12).reshape(3, 4)
print(f"Original Tensor:\n {tensor}\n")
# 获取第一行
print(f"First row: {tensor[0]}\n")
# 获取最后一列
print(f"Last column:\n {tensor[:, -1]}\n")
# 切片:获取第1行到第2行,第1列到第2列的子张量
print(f"Sub-tensor (rows 1-2, cols 1-2):\n {tensor[1:3, 1:3]}\n")
# 使用布尔索引:选取所有大于5的元素
mask = tensor > 5
print(f"Mask (elements > 5):\n {mask}\n")
print(f"Elements > 5: {tensor[mask]}")
连接 (Concatenation)
torch.cat
可以将一系列张量沿指定维度进行连接。
t1 = torch.zeros(2, 3)
t2 = torch.ones(2, 3)
# 沿维度 0 连接 (行方向,增加行数)
cat_dim0 = torch.cat([t1, t2], dim=0)
print(f"Concatenated along dim=0 (shape: {cat_dim0.shape}):\n {cat_dim0}\n")
# 沿维度 1 连接 (列方向,增加列数)
cat_dim1 = torch.cat([t1, t2], dim=1)
print(f"Concatenated along dim=1 (shape: {cat_dim1.shape}):\n {cat_dim1}\n")
# torch.stack 是另一种连接方式,它会创建一个新的维度
t3 = torch.arange(3)
t4 = torch.arange(3, 6)
stacked_tensor = torch.stack([t3, t4], dim=0)
print(f"Stacked tensor (shape: {stacked_tensor.shape}):\n {stacked_tensor}")
变形 (Reshaping)
改变张量的形状而不改变其数据是常见需求。
x = torch.arange(12)
# 使用 reshape
y = x.reshape(3, 4)
print(f"Reshaped to 3x4:\n {y}\n")
# 使用 view (与 reshape 类似,但对内存连续性有要求)
# view 创建的张量与原张量共享数据
z = x.view(2, 6)
print(f"Viewed as 2x6:\n {z}\n")
# -1 是一个强大的占位符,可以自动计算维度大小
w = x.reshape(2, 2, -1)
print(f"Reshaped with -1 (shape: {w.shape}):\n {w}\n")
# 展平 (Flatten): 将多维张量变为一维
flat_tensor = y.flatten()
print(f"Flattened tensor: {flat_tensor}")
# 压缩和解压缩维度
# squeeze() 移除所有大小为1的维度
# unsqueeze() 在指定位置添加一个大小为1的维度
a = torch.zeros(1, 3, 1, 4)
squeezed = a.squeeze()
print(f"Original shape: {a.shape}, Squeezed shape: {squeezed.shape}")
unsqueezed = squeezed.unsqueeze(dim=0).unsqueeze(dim=2)
print(f"Unsqueezed back to: {unsqueezed.shape}")
2.4 广播机制 (Broadcasting)
广播机制允许 PyTorch 在处理不同形状的张量时,自动扩展较小张量的维度以匹配较大张量,从而实现逐元素操作,而无需显式地复制数据。这极大地提高了代码的简洁性和内存效率。
广播的规则如下:
- 如果两个张量的维度数不同,在维度较少的张量的形状前面补 1,直到它们的维度数相同。
- 对于每个维度,如果两个张量在该维度上的大小相同,或者其中一个张量的大小为 1,则它们在该维度上是兼容的。
- 如果所有维度都兼容,则可以进行广播。
- 广播后,每个维度的大小等于两个输入张量在该维度上的最大值。
- 大小为 1 的维度会被扩展(表现得像被复制了)以匹配另一个张量的大小。
广播示例
a = torch.arange(3).reshape(3, 1)
# a (shape: 3x1):
# [[0],
# [1],
# [2]]
b = torch.arange(2).reshape(1, 2)
# b (shape: 1x2):
# [[0, 1]]
# 广播过程:
# a 的 shape (3, 1) -> 扩展为 (3, 2)
# b 的 shape (1, 2) -> 扩展为 (3, 2)
c = a + b
print(f"Tensor a (shape {a.shape}):\n{a}")
print(f"Tensor b (shape {b.shape}):\n{b}")
print(f"After broadcasting, a + b (shape {c.shape}):\n{c}")
广播机制可视化
graph TD
subgraph Before Broadcasting
A[Tensor a (3, 1)<br>[ [0], [1], [2] ] ]
B[Tensor b (1, 2)<br>[ [0, 1] ] ]
end
subgraph Broadcasting Process
A_ext[a is stretched to (3, 2)<br>[ [0, 0], [1, 1], [2, 2] ] ]
B_ext[b is stretched to (3, 2)<br>[ [0, 1], [0, 1], [0, 1] ] ]
end
subgraph Result
C[Result c = a + b (3, 2)<br>[ [0, 1], [1, 2], [2, 3] ] ]
end
A --> A_ext
B --> B_ext
A_ext -- Element-wise Add --> C
B_ext -- Element-wise Add --> C
2.5 本章总结
本章深入探讨了 PyTorch 的核心——张量。我们学习了以下关键内容:
- 张量创建:掌握了从 Python 列表、NumPy 数组(注意内存共享)、其他张量以及使用
zeros
,ones
,rand
等函数创建张量的多种方法。 - 张量属性:理解了
shape
,dtype
,device
三大属性的重要性,并学会了使用.to()
方法在不同设备(CPU/GPU)和数据类型之间进行转换。 - 高级操作:
- 通过索引和切片,我们能够灵活地访问和修改张量的部分数据。
- 使用
torch.cat
和torch.stack
进行张量连接,这是组合模型特征的常用技巧。 - 学习了
reshape
,view
,flatten
,squeeze
,unsqueeze
等变形操作,以满足不同网络层对输入形状的要求。
- 广播机制:理解了 PyTorch 如何自动处理不同形状张量之间的运算,这使得代码更简洁、高效。
对张量的熟练操作是精通 PyTorch 的基石。这些知识将贯穿于后续所有章节的模型构建、数据处理和训练过程中。