tensorflow.keras.models.Sequential——predict()、predict_classes()、predict_proba()方法的区别

本文介绍了在TensorFlow中,predict()、predict_classes()和predict_proba()等模型预测方法的使用和变化。predict()方法适用于多类和二分类任务,返回样本属于每个类别的概率。predict_classes()和predict_proba()已逐渐被弃用,推荐使用predict()结合np.argmax或条件判断来获取类别。此外,详细解释了predict()函数的参数及其在处理大规模输入时的效率优势。

@创建于:2022.04.17
@修改于:2022.04.17

predict() 在tf.keras.Sequential 和 tf.keras.Model模块都有效;
predict_classes()、predict_proba()方法 在tf.keras.Sequential 模块下有效,在tf.keras.Model模块下无效。

1、方法介绍

predict()方法预测时,返回值是数值,表示样本属于每一个类别的概率。

predict_proba() 方法预测时,返回值是数值,表示样本属于每一个类别的概率。与predict输出结果一致。

predict_classes() 方法预测时,返回的是类别的索引,即该样本所属的类别标签。

predict_classes() 和 predict_proba()方法将逐渐被弃用了,不建议尝试了。虽然在tensorflow2.5.0中还存在,如果使用会报出如下警告。

UserWarning: `model.predict_proba()` is deprecated and will be removed after 2021-01-01. Please use `model.predict()` instead.
  warnings.warn('`model.predict_proba()` is deprecated and '

2、建议使用predict(),进而改写

predict_classes() 已被弃用并且将在 2021-01-01 之后删除,请改用下面做法。这样能够与tf.keras.Model模块保持一致,实现统一。

np.argmax(model.predict(x), axis=-1)
  • 如果你的模型进行二分类(例如,如果它使用 sigmoid 最后一层激活),使用下面代码获得分类
(model.predict(x) > 0.5).astype("int32")

3、predict()参数介绍</

from tensorflow.keras.models import Sequential from tensorflow.keras.layers import LSTM, Dense, Dropout from tensorflow.keras.metrics import Precision from tensorflow.keras.metrics import Recall import tensorflow as tf def build_lstm(): model = Sequential() model.add(LSTM(128, input_shape=(5, 39), return_sequences=True)) model.add(Dropout(0.3)) model.add(LSTM(64, return_sequences=True)) model.add(Dropout(0.3)) model.add(LSTM(32)) model.add(Dropout(0.3)) model.add(Dense(1, activation=‘sigmoid’)) model.compile(optimizer=tf.keras.optimizers.Adam( clipnorm=1.0, # 梯度范数裁剪 clipvalue=0.5 # 逐元素梯度裁剪 ), loss=‘binary_crossentropy’, metrics=[‘accuracy’, Precision(name=‘precision’),Recall(name=‘recall’)]) return model model = build_lstm() from tensorflow.keras.callbacks import Callback, EarlyStopping 自定义回调记录指标 class MetricsLogger(Callback): def on_epoch_end(self, epoch, logs=None): print(f"Epoch {epoch+1}:“) print(f” Train Loss: {logs[‘loss’]:.4f} | Val Loss: {logs[‘val_loss’]:.4f}“) print(f” Train Precision: {logs[‘precision’]:.4f} | Val Precision: {logs[‘val_precision’]:.4f}“) print(f” Train Recall: {logs[‘recall’]:.4f} | Val Recall: {logs[‘val_recall’]:.4f}") 早停策略(监控验证集Precision) early_stop = EarlyStopping(monitor=‘val_precision’, patience=10, mode=‘max’, restore_best_weights=True) history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=100, batch_size=32, class_weight=class_weights, callbacks=[MetricsLogger(), early_stop], verbose=1) 直接帮我改进整个代码
03-29
评论 6
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值