torch.diagonal(input, offset=0, dim1=0, dim2=1)

部署运行你感兴趣的模型镜像

torch.diagonal(input, offset=0, dim1=0, dim2=1) 解释

官方解释:

返回输入值的局部视图,并将其相对于 dim1 和 dim2 的对角线元素作为一个维度添加到形状的末尾。

参数偏移量 控制着要考虑的对角线:

如果偏移 = 0,则为主对角线。

如果偏移量 > 0,则在主对角线之上。

如果偏移量 < 0,则在主对角线的下方。

使用相同参数对该函数的输出应用 torch.diag_embed(),可以得到一个对角矩阵,其中包含输入的对角条目。不过,torrent.diag_embed() 有不同的默认维数,因此需要明确指定。

参数 input (Tensor) - 输入张量。必须至少为 2 维。

offset(int,可选)- 要考虑的对角线。默认值:0(主对角线)。

dim1(int,可选项)- 取对角线的第一个维度。默认值:0。

dim2 (int,可选项)- 对角线的第二个维度。默认值:1。

举例

二维矩阵

x,维度为(h,w),其offset=0,dim1=0,dim2 =1。
则对角阵为
output=[(0,0),(1,1),……,(i,i)],其中i=min(m,n)−1 output = [(0,0),(1,1),……,(i,i)],其中i=min(m,n)-1 output=[(00),(11),……,(ii],其中i=minm,n1
其offset=1,
output=[(0,1),(1,2),……,(i,i+1)],i=min(m,n)−1−offset output = [(0,1),(1,2),……,(i,i+1)],i=min(m,n)-1-offset output=[(01),(12),……,(ii+1],i=minm,n1offset
其offset=-1,
output=[(1,0),(2,1),……,(i+1,i)],其中i=min(m,n)−1−offset output = [(1,0),(2,1),……,(i+1,i)],其中i=min(m,n)-1-offset output=[(10),(21),……,(i+1i],其中i=minm,n1offset
在这里插入图片描述
其中浅绿色为offset=-1,黄色为offset=0,浅蓝色offset=1

多维矩阵

若x维度为(a,b,c,d,e)

output = torch.diagonal(x, offset=0, dim1=2, dim2=4),

输出结果形状为(a, b, d, min(c, e))
其中
output[:,:,:,0]=x[:,:,0,:,0]output[:,:,:,1]=x[:,:,1,:,1]output[:,:,:,2]=x[:,:,2,:,2]output[:,:,:,i]=x[:,:,i,:,i],i=min(c,e)−1 output[:, :, :, 0] = x[:, :, 0, :, 0]\\ output[:, :, :, 1] = x[:, :, 1, :, 1]\\ output[:, :, :, 2] = x[:, :, 2, :, 2]\\ output[:, :, :, i] = x[:, :, i, :, i], i=min(c, e)-1 output[:,:,:,0]=x[:,:,0,:,0]output[:,:,:,1]=x[:,:,1,:,1]output[:,:,:,2]=x[:,:,2,:,2]output[:,:,:,i]=x[:,:,i,:,i],i=min(c,e)1

output = torch.diagonal(x, offset=1, dim1=2, dim2=4)

输出为
output[:,:,:,0]=x[a,b,0,d,1]output[:,:,:,1]=x[:,:,1,:,2]output[:,:,:,2]=x[:,:,2,:,3]output[:,:,:,i]=x[:,:,i,:,i+1],i=min(c,e)−1−offset output[:, :, :, 0] = x[a, b, 0, d, 1]\\ output[:, :, :, 1] = x[:, :, 1, :, 2]\\ output[:, :, :, 2] = x[:, :, 2, :, 3]\\ output[:, :, :, i] = x[:, :, i, :, i+1], i=min(c, e)-1-offset output[:,:,:,0]=x[a,b,0,d,1]output[:,:,:,1]=x[:,:,1,:,2]output[:,:,:,2]=x[:,,2,,3]output[:,:,:,i]=x[,,i,,i+1],i=min(c,e)1offset

output = torch.diagonal(x, offset=-1, dim1=2, dim2=4)

输出为
output[:,:,:,0]=x[:,:,1,:,0]output[:,:,:,1]=x[:,:,2,:,1]output[:,:,:,2]=x[:,:,3,:,2]output[:,:,:,i]=x[:,:,i+1,:,i],i=min(c,e)−1−offset output[:, :, :, 0] = x[:, :, 1, :, 0]\\ output[:, :, :, 1] = x[:, :, 2, :, 1]\\ output[:, :, :, 2] = x[:, :, 3, :, 2]\\ output[:, :, :, i] = x[:, :, i+1, :, i],i=min(c, e)-1-offset output[:,:,:,0]=x[,,1,,0]output[:,:,:,1]=x[,,2,,1]output[:,:,:,2]=x[,,3,,2]output[:,:,:,i]=x[,,i+1,,i],i=min(c,e)1offset

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

def sfb1d_atrous(lo, hi, g0, g1, mode='periodization', dim=-1, dilation=1, pad1=None, pad=None): """ 1D synthesis filter bank of an image tensor with no upsampling. Used for the stationary wavelet transform. """ C = lo.shape[1] d = dim % 4 # If g0, g1 are not tensors, make them. If they are, then assume that they # are in the right order if not isinstance(g0, torch.Tensor): g0 = torch.tensor(np.copy(np.array(g0).ravel()), dtype=torch.float, device=lo.device) if not isinstance(g1, torch.Tensor): g1 = torch.tensor(np.copy(np.array(g1).ravel()), dtype=torch.float, device=lo.device) L = g0.numel() shape = [1,1,1,1] shape[d] = L # If g aren't in the right shape, make them so if g0.shape != tuple(shape): g0 = g0.reshape(*shape) if g1.shape != tuple(shape): g1 = g1.reshape(*shape) g0 = torch.cat([g0]*C,dim=0) g1 = torch.cat([g1]*C,dim=0) # Calculate the padding size. # With dilation, zeros are inserted between the filter taps but not after. # that means a filter that is [a b c d] becomes [a 0 b 0 c 0 d]. centre = L / 2 fsz = (L-1)*dilation + 1 newcentre = fsz / 2 before = newcentre - dilation*centre # When conv_transpose2d is done, a filter with k taps expands an input with # N samples to be N + k - 1 samples. The 'padding' is really the opposite of # that, and is how many samples on the edges you want to cut out. # In addition to this, we want the input to be extended before convolving. # This means the final output size without the padding option will be # N + k - 1 + k - 1 # The final thing to worry about is making sure that the output is centred. short_offset = dilation - 1 centre_offset = fsz % 2 a = fsz//2 b = fsz//2 + (fsz + 1) % 2 # a = 0 # b = 0 pad = (0, 0, a, b) if d == 2 else (a, b, 0, 0) lo = mypad(lo, pad=pad, mode=mode) hi = mypad(hi, pad=pad, mode=mode) unpad = (fsz - 1, 0) if d == 2 else (0, fsz - 1) unpad = (0, 0) y = F.conv_transpose2d(lo, g0, padding=unpad, groups=C, dilation=dilation) + \ F.conv_transpose2d(hi, g1, padding=unpad, groups=C, dilation=dilation) # pad = (L-1, 0) if d == 2 else (0, L-1) # y = F.conv_transpose2d(lo, g0, padding=pad, groups=C, dilation=dilation) + \ # F.conv_transpose2d(hi, g1, padding=pad, groups=C, dilation=dilation) # # # Calculate the pad size # L2 = (L * dilation)//2 # # pad = (0, 0, L2, L2+dilation) if d == 2 else (L2, L2+dilation, 0, 0) # a = dilation*2 # b = dilation*(L-2) # if pad1 is None: # pad1 = (0, 0, a, b) if d == 2 else (a, b, 0, 0) # print(pad1) # lo = mypad(lo, pad=pad1, mode=mode) # hi = mypad(hi, pad=pad1, mode=mode) # if pad is None: # p = (a + b + (L - 1)*dilation)//2 # pad = (p, 0) if d == 2 else (0, p) # print(pad) return y/(2*dilation) def sfb2d_atrous(ll, lh, hl, hh, filts, mode='zero'): """ Does a single level 2d wavelet reconstruction of wavelet coefficients. Does separate row and column filtering by two calls to :py:func:`pytorch_wavelets.dwt.lowlevel.sfb1d` Inputs: ll (torch.Tensor): lowpass coefficients lh (torch.Tensor): horizontal coefficients hl (torch.Tensor): vertical coefficients hh (torch.Tensor): diagonal coefficients filts (list of ndarray or torch.Tensor): If a list of tensors has been given, this function assumes they are in the right form (the form returned by :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`). Otherwise, this function will prepare the filters to be of the right form by calling :py:func:`~pytorch_wavelets.dwt.lowlevel.prep_filt_sfb2d`. mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which padding to use. If periodization, the output size will be half the input size. Otherwise, the output size will be slightly larger than half. """ tensorize = [not isinstance(x, torch.Tensor) for x in filts] if len(filts) == 2: g0, g1 = filts if True in tensorize: g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1) else: g0_col = g0 g0_row = g0.transpose(2,3) g1_col = g1 g1_row = g1.transpose(2,3) elif len(filts) == 4: if True in tensorize: g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts) else: g0_col, g1_col, g0_row, g1_row = filts else: raise ValueError("Unknown form for input filts") lo = sfb1d_atrous(ll, lh, g0_col, g1_col, mode=mode, dim=2) hi = sfb1d_atrous(hl, hh, g0_col, g1_col, mode=mode, dim=2) y = sfb1d_atrous(lo, hi, g0_row, g1_row, mode=mode, dim=3) return y class SWTInverse(nn.Module): """ Performs a 2d DWT Inverse reconstruction of an image Args: wave (str or pywt.Wavelet): Which wavelet to use C: deprecated, will be removed in future """ def __init__(self, wave='db1', mode='zero', separable=True): super().__init__() if isinstance(wave, str): wave = pywt.Wavelet(wave) if isinstance(wave, pywt.Wavelet): g0_col, g1_col = wave.rec_lo, wave.rec_hi g0_row, g1_row = g0_col, g1_col else: if len(wave) == 2: g0_col, g1_col = wave[0], wave[1] g0_row, g1_row = g0_col, g1_col elif len(wave) == 4: g0_col, g1_col = wave[0], wave[1] g0_row, g1_row = wave[2], wave[3] # Prepare the filters if separable: filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row) self.register_buffer('g0_col', filts[0]) self.register_buffer('g1_col', filts[1]) self.register_buffer('g0_row', filts[2]) self.register_buffer('g1_row', filts[3]) else: filts = lowlevel.prep_filt_sfb2d_nonsep( g0_col, g1_col, g0_row, g1_row) self.register_buffer('h', filts) self.mode = mode self.separable = separable def forward(self, coeffs): """ Args: coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', W_{in}')` and yh is a list of bandpass tensors of shape :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match the format returned by DWTForward Returns: Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` Note: :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly downsampled shapes of the DWT pyramid. Note: Can have None for any of the highpass scales and will treat the values as zeros (not in an efficient way though). """ yl, yh = coeffs ll = yl # Do a multilevel inverse transform for h in yh[::-1]: if h is None: h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2], ll.shape[-1], device=ll.device) # 'Unpad' added dimensions if ll.shape[-2] > h.shape[-2]: ll = ll[...,:-1,:] if ll.shape[-1] > h.shape[-1]: ll = ll[...,:-1] # Do the synthesis filter banks if self.separable: lh, hl, hh = torch.unbind(h, dim=2) filts = (self.g0_col, self.g1_col, self.g0_row, self.g1_row) ll = lowlevel.sfb2d(ll, lh, hl, hh, filts, mode=self.mode) else: c = torch.cat((ll[:,:,None], h), dim=2) ll = lowlevel.sfb2d_nonsep(c, self.h, mode=self.mode) return ll 这个代码又是什么意思
最新发布
09-27
<think>首先,用户的问题是关于PyTorch中的torch.squeeze函数,特别指定了dim=0参数时的行为。用户问:“请问PyTorch torch.squeeze(input, dim=0) 的作用和使用示例”。 从系统级指令中,我需要: - 使用中文回答。 - 回答结构清晰,帮助用户逐步解决问题。 - 保证回答真实可靠。 - 在行内数学表达式中使用$...$格式。 - 独立公式使用$$...$$格式并单独成段。 - 生成相关问题。 - 在引用的段落末尾自然地添加引用标识,如[^1]。 用户提供了几个引用: - 引用[1]:关于torch.unsqueeze和torch.squeeze的示例。 - 引用[2]:详细介绍了torch.squeeze的功能、语法、参数和示例。 - 引用[3]:另一个使用示例。 - 引用[4]:关于torch.topk,但用户的问题是关于squeeze,所以这个可能不相关。 用户的问题聚焦在torch.squeeze(input, dim=0)上,所以我需要解释这个特定用法。 关键点从引用中提取: - 从引用[2]:torch.squeeze(input, dim=None) 移除指定维度dim上尺寸为1的维度。如果dim未指定,移除所有尺寸为1的维度。 - 示例:如果input的形状是(1,3,1,5),那么squeeze(dim=0)移除第0维,结果形状是(3,1,5)。 - 类似地在引用[3]中:如果tensor形状是(1,3,1),squeeze(dim=0)移除第0维,结果形状是(3,1)。 所以,对于torch.squeeze(input, dim=0): - 作用:移除输入张量input的第0维(即索引0的维度),但仅当该维度的尺寸为1时。如果尺寸不为1,则不做任何操作。 - 使用示例:展示一个具体的例子。 结构我的回答: 1. **介绍函数**:简要解释torch.squeeze的作用。 2. **指定dim=0的行为**:详细说明当dim=0时,函数的行为。 3. **使用示例**:提供代码示例来演示。 4. **注意事项**:提到如果dim=0的尺寸不为1时的情况。 5. **生成相关问题**:在结尾添加相关问题。 确保回答真实可靠:基于提供的引用和PyTorch官方文档(虽然引用中已有示例)。 回答结构: - 清晰、分步骤。 - 使用中文。 - 在代码块中使用Python代码示例。 - 结尾添加相关问题。 生成相关问题: - 基于用户的问题,例如: 1. torch.squeeze 和 torch.unsqueeze 有什么区别? 2. 如果不指定dim参数,会发生什么? 3. 在实际应用中,squeeze常用于什么场景? 引用标识:在回答中,如果引用内容,添加引用标识。例如,提到示例时引用[^2]。 现在,构建回答大纲: 1. **开头**:直接回答用户问题。 2. **解释作用**:描述torch.squeeze(input, dim=0)的作用。 3. **示例**:提供代码示例。 4. **总结**:简要总结。 5. **相关问题**:添加3个相关问题。</think>### PyTorch torch.squeeze(input, dim=0) 的作用和使用示例 在 PyTorch 中,`torch.squeeze()` 函数用于移除张量(tensor)中尺寸为 1 的维度,从而简化张量的形状。当指定 `dim=0` 参数时,该函数会**专门移除输入张量第 0 维(索引从 0 开始)的维度,但仅当该维度的尺寸为 1 时才生效**。如果第 0 维的尺寸不为 1,函数不会执行任何操作,返回原始张量。这一功能在数据处理中常用于压缩不必要的单维度,例如在神经网络输入预处理或特征提取时优化张量形状[^2][^3]。 下面我将通过一个清晰的步骤解释其作用,并提供代码示例来演示具体用法。 #### 1. **函数作用详解** - **基本行为**: - `torch.squeeze(input, dim=0)` 会检查输入张量 `input` 的第 0 维(即最外层的维度)。 - 如果第 0 维的尺寸为 1(例如形状为 $(1, 3, 4)$),则移除该维度,输出张量的形状变为 $(3, 4)$。 - 如果第 0 维尺寸不为 1(例如形状为 $(2, 3, 4)$),则函数不执行操作,直接返回原始张量。 - 这一操作是**原地不修改**的(即返回新张量,原张量不变),适用于需要降维的场景,如处理批量大小为 1 的数据[^2][^3]。 - **数学表示**: - 设输入张量为 $A \in \mathbb{R}^{d_0 \times d_1 \times \cdots \times d_n}$,其中 $d_0$ 是第 0 维尺寸。 - 如果 $d_0 = 1$,则输出张量形状变为 $(d_1, d_2, \ldots, d_n)$。 - 否则,输出张量与输入相同。 - 这可以用以下公式表示: $$ \text{output} = \begin{cases} \text{reshape}(A, (d_1, d_2, \ldots, d_n)) & \text{if } d_0 = 1 \\ A & \text{otherwise} \end{cases} $$ #### 2. **使用示例** 以下 Python 代码示例展示了 `torch.squeeze(input, dim=0)` 的具体行为。示例基于 PyTorch 实现,并包含详细的注释。 ```python import torch # 示例 1: 第 0 维尺寸为 1 时,移除该维度 # 创建一个形状为 (1, 3, 2) 的张量(第 0 维尺寸为 1) tensor1 = torch.randn(1, 3, 2) print("原始张量形状 (tensor1):", tensor1.shape) # 输出: torch.Size([1, 3, 2]) # 应用 torch.squeeze(dim=0) — 移除第 0 维 squeezed_tensor1 = torch.squeeze(tensor1, dim=0) print("移除第 0 维后形状:", squeezed_tensor1.shape) # 输出: torch.Size([3, 2]) # 示例 2: 第 0 维尺寸不为 1 时,无变化 # 创建一个形状为 (2, 1, 4) 的张量(第 0 维尺寸为 2,不为 1) tensor2 = torch.randn(2, 1, 4) print("\n原始张量形状 (tensor2):", tensor2.shape) # 输出: torch.Size([2, 1, 4]) # 应用 torch.squeeze(dim=0) — 第 0 维尺寸为 2,不执行移除 squeezed_tensor2 = torch.squeeze(tensor2, dim=0) print("移除第 0 维后形状:", squeezed_tensor2.shape) # 输出: torch.Size([2, 1, 4])(无变化) # 示例 3: 与其他维度对比 # 创建一个形状为 (1, 2, 1, 3) 的张量 tensor3 = torch.randn(1, 2, 1, 3) print("\n原始张量形状 (tensor3):", tensor3.shape) # 输出: torch.Size([1, 2, 1, 3]) # 仅移除第 0 维(尺寸为 1),其他维度不受影响 squeezed_tensor3 = torch.squeeze(tensor3, dim=0) print("移除第 0 维后形状:", squeezed_tensor3.shape) # 输出: torch.Size([2, 1, 3]) ``` **输出解释**: - 在示例 1 中,输入形状为 $(1, 3, 2)$,第 0 维尺寸为 1,移除后形状变为 $(3, 2)$。 - 在示例 2 中,输入形状为 $(2, 1, 4)$,第 0 维尺寸为 2(不为 1),因此形状不变。 - 在示例 3 中,输入形状为 $(1, 2, 1, 3)$,移除第 0 维后,形状变为 $(2, 1, 3)$,而第 2 维的尺寸为 1 未被移除(因为未指定该维度)。 #### 3. **注意事项** - **维度索引**:PyTorch 维度索引从 0 开始,`dim=0` 指代第一个维度。 - **错误处理**:如果指定 `dim` 超出张量维度范围,PyTorch 会抛出 `IndexError`。建议在调用前检查张量形状。 - **性能**:`squeeze` 操作是轻量级的,时间复杂度为 $O(1)$,因为它只修改元数据而不复制数据。 - **实际应用**:常用于处理卷积神经网络(CNN)的输出或数据集加载,例如当批量大小(batch size)为 1 时移除冗余维度[^2][^3]。 如果您需要更深入的示例或相关函数对比,请参考 PyTorch 官方文档或上述引用内容。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值