-
tf.cast()可附加条件,例如:
#tensor大于0的元素的位置返回True,其余返回False(在dtype为bool的情况下) #其他情况同理 tf.cast(tensor>0,dtype)
-
tf.where
#在predict张量和tartget张量元素在0.5两边时 #theta矩阵响应元素置1,否则置0 x=tf.ones_like(log_yi, dtype=tf.float32) y=tf.zeros_like(log_yi, dtype=tf.float32) condition = ((target - 0.5) * (predict - 0.5)<0.) theta = tf.where(condition,x,y)
-
evaluation_hook可以使用tf.train里面的hook
evaluation_hook = tf.train.LoggingTensorHook({"total_loss:": total_loss}, every_n_iter=10) output_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=total_loss, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn, evaluation_hooks=[evaluation_hook])
#12345是传入的seed,用来选定第一次random的开始点 random.Random(12345) rng.shuffle(train_examples)