报错indexerror: tensors used as indices must be long, byte or bool tensors

本文博主分享了在使用KGPolicy进行模型训练时遇到的索引错误,通过将one_hot张量从float类型转换为long类型解决了问题,详细解释了原因并提供了相关技术链接作为参考。
部署运行你感兴趣的模型镜像

报错

Traceback (most recent call last):
  File "main.py", line 306, in <module>
    args_config=args_config,
  File "main.py", line 224, in train
    avg_reward,
  File "main.py", line 56, in train_one_epoch
    selected_neg_items_list, _ = sampler(batch_data, adj_matrix, edge_matrix)
  File "/home/zzy/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zzy/KG-Policy/kgpolicy/modules/sampler/kgpolicy.py", line 91, in forward
    one_hop, one_hop_logits = self.kg_step(pos, users, adj_matrix, step=1)
  File "/home/zzy/KG-Policy/kgpolicy/modules/sampler/kgpolicy.py", line 139, in kg_step
    i_e = gcn_embedding[one_hop]
IndexError: tensors used as indices must be long, byte or bool tensors

定位到kgpolicy文件的第139行

    one_hop = adj_matrix[pos]
    i_e = gcn_embedding[one_hop] 

感觉索引indices是one_hop ,检查下one_hop:

print(one_hop)

结果为:

tensor([[ 1295.,  6801.,  6590.,  ..., 43463., 63427., 23697.],
        [ 2940., 23298.,  4720.,  ..., 54026., 45077., 68521.],
        [  136.,   137.,  1033.,  ..., 47605., 59133., 51274.],
        ...,
        [ 1629.,  6126.,   118.,  ..., 56004., 23626., 41774.],
        [12689.,  6415., 21290.,  ..., 28091., 24405., 37709.],
        [  739.,   818.,  1202.,  ..., 54062., 29978., 25789.]])

因为默认的是float类型,所以不对,就直接类型转换为long就可以了.
关于究竟转化为哪种类型,参考链接:
https://blog.youkuaiyun.com/junqing_wu/article/details/99692296
https://blog.youkuaiyun.com/jacke121/article/details/82703640
https://blog.youkuaiyun.com/weixin_38314865/article/details/105949825

        one_hop = adj_matrix[pos].type(torch.long)#已修改
        i_e = gcn_embedding[one_hop] 

ok,这样就可以啦…这么个错误弄了一上午,崩溃…希望能帮到小可爱们~

您可能感兴趣的与本文相关的镜像

PyTorch 2.7

PyTorch 2.7

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论 6
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值