import keras.backend as K import numpy as np w = K.variable(np.random.randint(10,size=(64, 128, 4, 8))) k = K.variable(np.random.randint(10,size=(64, 128, 8, 16))) z = K.batch_dot(w,k) print(z.shape) #(64, 128, 4, 16)
使用K.batch_dot(qw, kw)时,我们想要z的shape是(64,128,4,16),但是获取到的shape(64,128,4,128,16)
这是由于keras的版本造成的,keras版本不要高于2.2.4,pip install Keras==2.2.4

本文探讨了在使用Keras的K.batch_dot进行批量点乘运算时遇到的问题,即输出维度不符合预期的情况,并指出了该问题与Keras版本有关,建议使用版本2.2.4以确保正确的输出形状。
656

被折叠的 条评论
为什么被折叠?



