import keras.backend as K import numpy as np w = K.variable(np.random.randint(10,size=(64, 128, 4, 8))) k = K.variable(np.random.randint(10,size=(64, 128, 8, 16))) z = K.batch_dot(w,k) print(z.shape) #(64, 128, 4, 16)
使用K.batch_dot(qw, kw)时,我们想要z的shape是(64,128,4,16),但是获取到的shape(64,128,4,128,16)
这是由于keras的版本造成的,keras版本不要高于2.2.4,pip install Keras==2.2.4