函数原型:
Lambda(function, output_shape=None, mask=None, arguments=None)
参数说明:
function:要实现的函数,该函数仅接受一个变量,即上一层的输出
output_shape:函数应该返回的值的shape,可以是一个tuple,也可以是一个根据输入shape计算输出shape的函数
mask: 掩膜
arguments:可选,字典,用来记录向函数中传递的其他关键字参数
它的实际使用方法如下,这是一个切片的例子:
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
def slice(x,index): #定义的切片函数
return x[:,:,index+1]
x_test = tf.conv = np.array([[[1,2],[2,3],[3,4],[4,5]]])
x_test = tf.convert_to_tensor(x_test)
print(x_test.numpy())
x1 = Lambda(slice,output_shape=(4,1),arguments={'index':0})(x_test)
print(x1.numpy())
输出:
可以看到,在该lambda实现了进行一维的切片,其中x_test代表输入,仅接受一个参数,替换函数中的X,其他参数通过arguments以字典的形式进行传递。