@创建于:2022.04.17
@修改于:2022.04.17
predict() 在tf.keras.Sequential 和 tf.keras.Model模块都有效;
predict_classes()、predict_proba()方法 在tf.keras.Sequential 模块下有效,在tf.keras.Model模块下无效。
1、方法介绍
predict()方法预测时,返回值是数值,表示样本属于每一个类别的概率。
predict_proba() 方法预测时,返回值是数值,表示样本属于每一个类别的概率。与predict输出结果一致。
predict_classes() 方法预测时,返回的是类别的索引,即该样本所属的类别标签。
predict_classes() 和 predict_proba()方法将逐渐被弃用了,不建议尝试了。虽然在tensorflow2.5.0中还存在,如果使用会报出如下警告。
UserWarning: `model.predict_proba()` is deprecated and will be removed after 2021-01-01. Please use `model.predict()` instead.
warnings.warn('`model.predict_proba()` is deprecated and '
2、建议使用predict(),进而改写
predict_classes() 已被弃用并且将在 2021-01-01 之后删除,请改用下面做法。这样能够与tf.keras.Model模块保持一致,实现统一。
- 如果你的模型进行多类分类(例如,如果它使用
softmax最后一层激活),使用np.argmax获得最大值索引。因为深度学习的多分类模型里面需要使用独热向量编码,或参考keras中to_categorical函数解析
np.argmax(model.predict(x), axis=-1)
- 如果你的模型进行二分类(例如,如果它使用
sigmoid最后一层激活),使用下面代码获得分类
(model.predict(x) > 0.5).astype("int32")

本文介绍了在TensorFlow中,predict()、predict_classes()和predict_proba()等模型预测方法的使用和变化。predict()方法适用于多类和二分类任务,返回样本属于每个类别的概率。predict_classes()和predict_proba()已逐渐被弃用,推荐使用predict()结合np.argmax或条件判断来获取类别。此外,详细解释了predict()函数的参数及其在处理大规模输入时的效率优势。
最低0.47元/天 解锁文章





