深入探索Keras:高级使用技巧
1. 自定义回调函数记录损失值
在训练过程中,有时我们需要详细记录每一批次的损失值,并将其可视化。以下是一个简单的自定义回调函数示例,它可以记录每一批次的损失值,并在每个epoch结束时绘制损失值的图表。
from matplotlib import pyplot as plt
import keras
class LossHistory(keras.callbacks.Callback):
def on_train_begin(self, logs):
self.per_batch_losses = []
def on_batch_end(self, batch, logs):
self.per_batch_losses.append(logs.get("loss"))
def on_epoch_end(self, epoch, logs):
plt.clf()
plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
label="Training loss for each batch")
plt.xlabel(f"Batch (epoch {epoch})")
plt.ylabel("Loss")
plt.legend()
plt.savefig(f"plot_at_epoch_{epoch}")
self.per_b
超级会员免费看
订阅专栏 解锁全文
283

被折叠的 条评论
为什么被折叠?



