LibTorch: tensor.index_select()


The tensor.index_select() method in LibTorch works much like its PyTorch counterpart: it selects values from a tensor along a given dimension dim according to the indices in index. First, its declaration in LibTorch:

inline Tensor Tensor::index_select(int64_t dim, const Tensor & index)

It takes two arguments. The first is the dimension to select along; for a 2-D tensor, 0 means indexing by row and 1 means indexing by column. The second argument, index, holds the index values, and it deserves close attention: index is itself a tensor, and it must be an int64 (kLong) tensor!

An example:

Say we want to take a 2-D tensor of shape (3, 2), select along the column dimension (dim=1) with index=0 (i.e. column 0) to get a result of shape (3, 1), and print it. A first attempt might look like this (from_blob wraps an existing array as a tensor):

#include "iostream"
#include "torch/script.h"
using namespace torch::indexing;
int main()
{
    float a[3][2] = {1};
    at::Tensor b = at::from_blob(a, {3, 2}, at::kFloat).index_select(1, at::tensor(0));
    std::cout << b << std::endl;
    return 0;
}

Run it as-is, though, and a bug pops up... ha. Let's first see what the bug says:

terminate called after throwing an instance of 'c10::Error'
  what():  index_select(): Expected dtype int64 for index (index_select_out_cpu_ at ../aten/src/ATen/native/TensorAdvancedIndexing.cpp:387)

Spot the problem? As stated above, index is itself a tensor, and it must be an int64 (kLong) tensor! So we still need to convert the dtype of at::tensor(0); with that change, the code below runs correctly:

#include "iostream"
#include "torch/script.h"
using namespace torch::indexing;
int main()
{
    float a[3][2] = {{1,2},{3,4},{5,6}};
    at::Tensor b = at::from_blob(a, {3, 2}, at::kFloat).index_select(1, at::tensor(0).toType(at::kLong));
    std::cout << b << std::endl;
    return 0;
}

/***************** Output *******************/
 1
 3
 5
[ CPUFloatType{3,1} ]
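
For completeness, here is a sketch of the dim=0 (row) case with more than one index. This is a minimal illustration, assuming the torch::tensor factory accepts an initializer list plus a dtype (as it does in recent LibTorch versions); building the index tensor with kLong up front avoids the dtype conversion entirely:

#include <iostream>
#include <torch/script.h>

int main()
{
    float a[3][2] = {{1, 2}, {3, 4}, {5, 6}};
    // Select rows 0 and 2 (dim = 0). The index tensor is created with
    // kLong directly, so no toType() conversion is needed.
    at::Tensor rows = at::from_blob(a, {3, 2}, at::kFloat)
                          .index_select(0, torch::tensor({0, 2}, torch::kLong));
    std::cout << rows << std::endl;  // expected shape: (2, 2), rows {1, 2} and {5, 6}
    return 0;
}

The same idea applies to the column example above: at::tensor(0).toType(at::kLong) and torch::tensor({0}, torch::kLong) should both yield a valid index, the latter in a single step.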

 
