torch.block_diag()
是PyTorch中用于将多个矩阵沿对角线拼接成一个大矩阵的函数。这个函数可以用于构建卷积神经网络中的卷积核矩阵,或者构建变分自编码器等需要对多个线性变换进行堆叠的模型。
torch.block_diag()
函数的语法如下:
torch.block_diag(*args)
其中,*args
是要拼接的矩阵,可以是一个或多个Tensor对象。
下面是一个简单的代码示例,演示了如何使用torch.block_diag()
函数将三个矩阵沿对角线拼接成一个大矩阵:
import torch
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5