tf.cast()的介绍
这是tensorflow里的一个函数,
它的标准格式
tf.cast(x,dtype,name=None) 此函数作用是:数据类型转换
x:输入的张量
dtype:转换数据的类型
name:名字
这里举一个例子:
def wMSE(y_true, y_pred, binary=False):
if binary:
weights = tf.cast(y_true>0, tf.float32)
else:
weights = y_true
return tf.reduce_mean(weights*tf.square(y_true-y_pred))
这里的意思就是如果值为binary 那么weights中y_true的值大于0的时候,就把数据格式转换为 tf.float32
465

被折叠的 条评论
为什么被折叠?



