运行出现报错。修改数据格式 输出sample_ids的值,可以看到数据类型是 torch.int32 解决 需要将sample_ids类型转为long,修改方式: idx= idx.type(torch.long) 或 idx= self.tensor(idx, dtype=torch.long) 参考: IndexError: tensors used as indices must be long, byte or bool tensors 知乎:https://zhuanlan.zhihu.com/p/565931659