torch.nonezero()的用法

本文介绍了如何使用PyTorch中的torch.nonzero()函数获取张量中非零元素的索引,并通过实例展示了二维和三维张量的处理。该函数对于数据筛选和矩阵操作至关重要。
部署运行你感兴趣的模型镜像

函数原型:

torch.nonzero(input, out=None) → LongTensor

参数:

input (Tensor) – 源张量
out (LongTensor, optional) – 包含索引值的结果张量

代码示例

返回一个包含输入input中非零元素索引的张量。输出张量中的每行包含输入中非零元素的索引。

x = torch.tensor([0, 0, 1, 5, 8])
y = torch.nonzero(x)
print(y)
print(y.shape)
>>>
tensor([[2],
        [3],
        [4]])
torch.Size([3, 1])

如果输入input有n维,则输出的索引张量output的形状为 z x n, 这里 z 是输入张量input中所有非零元素的个数。

x = torch.tensor([[0, 0, 1, 5],
                  [1, 5, 0, 8],
                  [2, 8, 9, 0]])
y = torch.nonzero(x)
print(y)
print(y.shape)
>>>
tensor([[0, 2],
        [0, 3],
        [1, 0],
        [1, 1],
        [1, 3],
        [2, 0],
        [2, 1],
        [2, 2]])
torch.Size([8, 2])

pytorch文档学习链接:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/

您可能感兴趣的与本文相关的镜像

PyTorch 2.6

PyTorch 2.6

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

以下是使用代码实现 Tensor 常用比较函数 `torch.equal()`、`torch.eq()`、`torch.ne()`、`torch.gt()`、`torch.lt()`、`torch.ge()`、`torch.le()`、`torch.topk()`、`torch.sort()` 的示例: ```python import torch # 创建两个示例 Tensor a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[1, 2], [4, 3]]) # 1. torch.equal():比较两个张量是否完全相同 equal_result = torch.equal(a, b) print("torch.equal() result:", equal_result) # 2. torch.eq():逐元素比较两个张量是否相等 eq_result = torch.eq(a, b) print("torch.eq() result:") print(eq_result) # 3. torch.ne():逐元素比较两个张量是否不相等 ne_result = torch.ne(a, b) print("torch.ne() result:") print(ne_result) # 4. torch.gt():逐元素比较第一个张量是否大于第二个张量 gt_result = torch.gt(a, b) print("torch.gt() result:") print(gt_result) # 5. torch.lt():逐元素比较第一个张量是否小于第二个张量 lt_result = torch.lt(a, b) print("torch.lt() result:") print(lt_result) # 6. torch.ge():逐元素比较第一个张量是否大于或等于第二个张量 ge_result = torch.ge(a, b) print("torch.ge() result:") print(ge_result) # 7. torch.le():逐元素比较第一个张量是否小于或等于第二个张量 le_result = torch.le(a, b) print("torch.le() result:") print(le_result) # 8. torch.topk():返回输入张量中给定维度上的前 k 个最大元素及其索引 topk_values, topk_indices = torch.topk(a.flatten(), k=2) print("torch.topk() values:", topk_values) print("torch.topk() indices:", topk_indices) # 9. torch.sort():对输入张量进行排序 sorted_values, sorted_indices = torch.sort(a.flatten()) print("torch.sort() values:", sorted_values) print("torch.sort() indices:", sorted_indices) ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值