在pytorch代码编写过程中,经常容易碰到一类错误,比如:
- RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'
-
Expected object of backend CPU but got backend CUDA for sequence element 1 in sequence
-
RuntimeError: Expected object of scalar type Double but got scalar type Float
-
RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'mask'
- RuntimeError: Expected object of scalar type Byte but got scalar type Long for argument #2 'mask
这些错误,都与我们数据的类型有关系,比如数据是float还是double;tensor是cuda类型的还是cpu类型的。
因为有时候我们写的代码规模很大,所以很难直接找到错误所在,调试起来很麻烦,可能我们需要插入一大堆print语句才可以找到错误。
TorchSnooper 是一个设计了用来解决这类问题的工具。因此以后再pytorch中碰到类似的问题,可以用这个利器来辅助我们进行pytorch debug 工作。
TorchSnooper 的具体使用方式可以参照如下链接&#