编写自定义的训练循环和评估循环
fit()工作流程在易用性和灵活性之间实现了很好的平衡。在大多数情况下会用到它。然而,即使有了自定义指标、自定义损失函数和自定义回调函数,它也无法实现深度学习研究人员想做的一切事情。毕竟,内置的fit()工作流程只针对于监督学习(supervised learning)。监督学习是指,已知与输入数据相关联的目标(也叫标签或注释),将损失计算为这些目标和模型预测值的函数。然而,并非所有机器学习任务都属于这个类别。还有一些机器学习任务没有明确的目标,比如生成式学习(generative learning)、自监督学习(self-supervised learning,目标是从输入中得到的)和强化学习(reinforcement learning,学习由偶尔的“奖励”驱动,就像训练狗一样)。即使是常规的监督学习,研究人员也可能想添加一些新奇的附加功能,需要用到低阶灵活性。
如果发现内置的fit()不够用,那么就需要编写自定义的训练逻辑。我们已经遇到过低阶训练循环的简单示例。提醒一下,典型的训练循环包含以下内容:(1)在梯度带中运行前向传播(计算模型输出),得到当前数据批量的损失值;(2)检索损失相对于模型权重的梯度;(3)更新模型权重,以降低当前数据批量的损失值。这些步骤需要对多个批量重复进行。这基本上就是fit()在后台所做的工作。本节将从头开始重新实现fit(),你将了解编写任意训练算法所需的全部知识。我们来看一下具体细节。
训练与推断
在前面的低阶训练循环示例中,步骤1(前向传播)是通过predictions =model(inputs)完成的,步骤2(检索梯度带计算的梯度)是通过gradients =tape.gradient(loss, model.weights)完成的。在一般情况下,还有两个细节需要考虑。某些Keras层(如Dropout层),在训练过程和推断过程(将其用于预测时)中具有不同的行为。这些层的call()方法中有一个名为training的布尔参数。调用dropout(inputs, training=True)将舍弃一些激活单元,而调用dropout(inputs,training=False)则不会舍弃。推而广之,函数式模型和序贯模型的call()方法也有这个training参数。在前向传播中调用Keras模型时,一定要记得传入training=True。也就是说,前向传播应该变成predictions = model(inputs, training=True)。
此外请注意,检索模型权重的梯度时,不应使用tape.gradients(loss,model.weights),而应使用tape.gradients(loss, model.trainable_weights)。层和模型具有以下两种权重。可训练权重(trainable weight):通过反向传播对这些权重进行更新,以便将模型损失最小化。比如,Dense层的核和偏置就是可训练权重。不可训练权重(non-trainable weight):在前向传播过程中,这些权重所在的层对它们进行更新。如果你想自定义一层,用于记录该层处理了多少个批量,那么这一信息需要存储在一个不可训练权重中。每处理一个批量,该层将计数器加1。在Keras的所有内置层中,唯一具有不可训练权重的层是BatchNormalization层。BatchNormalization层需要使用不可训练权重,以便跟踪关于传入数据的均值和标准差的信息,从而实时进行特征规范化。将这两个细节考虑在内,监督学习的训练步骤如下所示。
def train_step(inputs, targets):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
loss = loss_fn(targets, predictions)
gradients = tape.gradients(loss, model.trainable_weights)
optimizer.apply_gradients(zip(model.trainable_weights, gradients))
指标的低阶用法
在低阶训练循环中,可能会用到Keras指标(无论是自定义指标还是内置指标)。你已经了解了指标API:只需对每一个目标和预测值组成的批量调用update_state(y_true, y_pred),然后使用result()查询当前指标值。
metric = keras.metrics.SparseCategoricalAccuracy()
targets = [0, 1, 2]
predictions = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
metric.update_state(targets, predictions)
current_result = metric.result()
print(f"result: {current_result:.2f}")
可能还需要跟踪某个标量值(比如模型损失)的均值。这可以通过keras.metrics.Mean指标来实现。
values = [0, 1, 2, 3, 4]
mean_tracker = keras.metrics.Mean()
for value in values:
mean_tracker.update_state(value)
print(f"Mean of values: {mean_tracker.result():.2f}")
如果想重置当前结果(在一轮训练开始时或评估开始时),记得使用metric.reset_state()。
完整的训练循环和评估循环
我们将前向传播、反向传播和指标跟踪组合成一个类似于fit()的训练步骤函数。这个函数接收数据和目标组成的批量,并返回由fit()进度条显示的日志。
代码清单 逐步编写训练循环:训练步骤函数
model = get_mnist_model()
loss_fn = keras.losses.SparseCategoricalCrossentropy() #准备损失函数
optimizer = keras.optimizers.RMSprop() #准备优化器
metrics = [keras.metrics.SparseCategoricalAccuracy()] #准备需监控的指标列表
loss_tracking_metric = keras.metrics.Mean() ←----准备Mean指标跟踪器来跟踪损失均值
def train_step(inputs, targets):
with tf.GradientTape() as tape: # (本行及以下2行)运行前向传播。注意,这里传入了training=True
predictions = model(inputs, training=True)
loss = loss_fn(targets, predictions)
gradients = tape.gradient(loss, model.trainable_weights) # (本行及以下1行)运行反向传播。注意,这里使用了model.trainable_weights
optimizer.apply_gradients(zip(gradients, model.trainable_weights))
logs = {} # (本行及以下3行)跟踪指标
for metric in metrics:
metric.update_state(targets, predictions)
logs[metric.name] = metric.result()
loss_tracking_metric.update_state(loss) # (本行及以下1行)跟踪损失均值
logs["loss"] = loss_tracking_metric.result()
return logs #返回当前的指标和损失值
在每轮开始时和进行评估之前,需要重置指标的状态。有一个实用函数可以实现这项操作。
代码清单 逐步编写训练循环:重置指标
def reset_metrics():
for metric in metrics:
metric.reset_state()
loss_tracking_metric.reset_state()
现在可以编写完整的训练循环了。请注意,我们使用tf.data.Dataset对象将NumPy数据转换为一个迭代器,以大小为32的批量来迭代数据。
代码清单 逐步编写训练循环:循环本身
training_dataset = tf.data.Dataset.from_tensor_slices(
(train_images, train_labels))
training_dataset = training_dataset.batch(32)
epochs = 3
for epoch in range(epochs):
reset_metrics()
for inputs_batch, targets_batch in training_dataset:
logs = train_step(inputs_batch, targets_batch)
print(f"Results at the end of epoch {epoch}")
for key, value in logs.items():
print(f"...{key}: {value:.4f}")
接下来是评估循环:一个简单的for循环,重复调用test_step()函数。该函数用于处理一批数据。test_step()函数只是train_step()逻辑的子集。它省略了处理更新模型权重的代码,即所有涉及GradientTape和优化器的代码。
代码清单 逐步编写评估循环
def test_step(inputs, targets):
predictions = model(inputs, training=False) ←----注意,这里传入了training=False
loss = loss_fn(targets, predictions)
logs = {}
for metric in metrics:
metric.update_state(targets, predictions)
logs["val_" + metric.name] = metric.result()
loss_tracking_metric.update_state(loss)
logs["val_loss"] = loss_tracking_metric.result()
return logs
val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()
for inputs_batch, targets_batch in val_dataset:
logs = test_step(inputs_batch, targets_batch)
print("Evaluation results:")
for key, value in logs.items():
print(f"...{key}: {value:.4f}")
刚刚重新实现了fit()和evaluate(),或者说几乎重新实现,因为fit()和evaluate()还支持更多功能,包括大规模分布式计算,这需要更多的代码。fit()和evaluate()还包含几项关键的性能优化措施。