pytorch索引加一出错
为了加速运算,写了一行对应索引加一的代码,大概长这个样子:
#注意data是pytorch里的tensor
data[index] += 1
少量数据的测试中,上面的代码运行正常,结果也正确。
当index长度大于data的长度时候,运行也正常,但是加和总数对不上。
后来发现,在pytorch中上述的索引加一操作在一次调用中只会对单个索引执行一次。具体讲个例子:
比如index=[1, 1, 1]
预想的结果是data[1] = data[1]+3, 但是实际上只加了一次也就是data[1]=data[1]+1。
猜测这可能是pytorch的速度优化导致的。
numpy与pytorch的负数步长
原意是想找个翻转tensor的操作,但是transformer里面要求是PIL的Image格式。所以想自己写了一个。
第一个版本长这样
# 注意data是tensor
data = data[::-1]
这很pythonic,遗憾的是pytorch不支持负数步长,程序报错
pytorch不行,numpy可以啊
所以第二个版本长这样
# 注意data是numpy
data = data[::-1]
# 或者这样
data = np.flip(data, 0)
整挺好,确实翻转了。
但是翻转后的数据作为pytorch的Dataset输出,竟然告诉我pytorch不支持负数步长 ?_ ?!
哪来的负数步长,不都给翻转了吗?
看来numpy也跟pytorch一样都是做表面功夫,表面上告诉我转好了,实际上就是设置了下读取规则(参考pytorch中的数据reshape等操作,只操作数据表头),数据还是没动(可能出于加速考虑)。
所以就有第三个版本了
data = np.flip(data, 0).copy()
成功解决