tf.metrics.accuracy用于计算模型输出的准确率
tf.metrics.accuracy(
labels,
predictions,
weights=None,
metrics_collections=None,
updates_collections=None,
name=None
)
return accuracy, update_op
参数:
labels 标签的真实值
predictions 标签的预测值
weights 每个值的权重
metrics_collections accuracy的集合
updates_collections update_op的集合
输出:
accuracy 上一个batch的准确率
update_op 加上本次训练数据后的准确率
例子:
import numpy as np
import pandas as pd
import tensorflow as tf
x = tf.placeholder(tf.float64, [5])
y = tf.placeholder(tf.int32, [5])
acc, acc_op = tf.metrics.accuracy(labels=y, predictions=tf.greater_equal(x,0.5))
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(tf