在jupyter notebook中运行代码时出现此问题,参数类型的错误
传的参数应该是torch.longtensor类型
原来是直接的input_ids = torch.tensor(…)
使用input_ids = torch.LongTensor() 定义即可
在jupyter notebook中运行代码时出现此问题,参数类型的错误
传的参数应该是torch.longtensor类型
原来是直接的input_ids = torch.tensor(…)
使用input_ids = torch.LongTensor() 定义即可