torch.randn 函数解析
torch.randn
是 PyTorch 中用于生成服从标准正态分布(均值为 0,标准差为 1)的张量的函数。生成的张量中的元素是从正态分布中随机采样的。
函数签名
torch.randn(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
参数说明
*size
:可变参数,定义输出张量的形状。可以是整数或整数序列。out
:可选参数,指定输出张量。dtype
:可选参数,指定输出张量的数据类型,如torch.float32
、torch.float64
等。layout
:可选参数,指定张量的内存布局,默认为torch.strided
。device
:可选参数,指定张量存储的设备,如'cpu'
或'cuda'
。requires_grad
:可选参数,布尔值,指定张量是否需要梯度计算,默认为False
。
示例代码
生成一个形状为 (2, 3) 的标准正态分布张量:
import torch
# 生成 2x3 的标准正态分布张量
x = torch.randn(2, 3)
print(x)
生成一个形状为 (4, 4) 的张量,并指定数据类型和设备:
# 生成 4x4 的张量,数据类型为 float64,存储在 GPU 上
y = torch.randn(4, 4, dtype=torch.float64, device='cuda')
print(y)
注意事项
- 生成的张量元素服从标准正态分布(均值 0,标准差 1)。
- 如果需要其他均值和标准差的正态分布,可以使用
torch.normal
函数。 - 默认情况下,
dtype
是torch.float32
,设备是 CPU。