编程深度学习模型并不容易(说实话),但测试一个模型更难。这就是为什么大多数TensorFlow和PyTorch代码没有包含单元测试。但当你的代码将运行在生产环境中时,确保它确实按预期工作应该是优先事项。毕竟,机器学习与其他任何软件并无不同。
请注意,本文是《生产中的深度学习》课程的第三部分,在该课程中,我们将探讨如何将笔记本代码转换为可服务于数百万用户的生产就绪代码。
在本文中,我们将重点讨论如何正确测试机器学习代码,分析编写单元测试时的一些最佳实践,并展示一些测试几乎是必需的示例场景。我们将从为什么需要在代码中引入测试开始,然后快速回顾Python测试的基础知识,接着探讨一些实际场景。
为何需要单元测试
在开发神经网络时,大多数人并不关心捕捉所有可能的异常、发现所有边缘情况或调试每个函数。我们只需要看到模型开始训练(fit)。然后我们只需要提高其准确性,直到达到可接受的点。这都很好,但是当模型被部署到服务器并在实际的面向公众的应用程序中使用时会发生什么?很可能会因为某些用户发送了错误的数据,或者因为某些静默的错误破坏了我们的数据预处理流程而崩溃。我们甚至可能发现我们的模型其实一直都有问题。
这就是单元测试发挥作用的地方。在问题发生之前就预防它们。单元测试非常有用,因为它们:
- 早期发现软件错误
- 调试我们的代码
- 确保代码按预期工作
- 简化重构过程
- 加快集成过程
- 充当文档
别告诉我你至少不想要以上的一些好处。当然,测试可能会占用你宝贵的时间,但这100%是值得的。稍后你就会明白。
但究竟什么是单元测试?
单元测试基础
简单来说,单元测试就是一个函数调用另一个函数(或一个单元),并检查返回的值是否与预期输出匹配。让我们看一个使用我们的UNet模型的例子,让它更清楚。
如果你没有跟随这个系列,可以在我们的GitHub仓库中找到代码。
简而言之,我们采用了一个官方TensorFlow谷歌Colab进行图像分割,并尝试将其转换为高度优化的、生产就绪的代码。查看前两部分([此处]和[此处])。
我们有一个简单的函数,通过将所有像素除以255来归一化图像。
def _normalize(self, input_image, input_mask):
""" Normalise input image
Args:
input_image (tf.image): The input image
input_mask (int): The image mask
Returns:
input_image (tf.image): The normalized input image
input_mask (int): The new image mask
"""
input_image = tf.cast(input_image, tf.float32) / 255.0
input_mask -= 1
return input_image, input_mask
为了确保它完全按照预期工作,我们可以编写另一个使用“_normalize”并检查其结果的函数。它看起来会像这样。
def test_normalize(self):
input_image = np.array([[1., 1.], [1., 1.]])
input_mask = 1
expected_image = np.array([[0.00392157, 0.00392157], [0.00392157, 0.00392157]])
result = self.unet._normalize(input_image, input_mask)
self.assertEquals(expected_image, result[0])
“test_normalize”函数创建一个假的输入图像,使用该图像作为参数调用函数,然后确保结果等于预期的图像。“assertEquals”是一个特殊的函数,来自Python的unittest包(稍后会详细介绍),它的作用正如其名。它断言两个值相等。请注意,你也可以使用下面这样的方式,但使用内置函数有其优势。
assert expected_image == result[0]
就是这样。这就是单元测试。测试既可用于非常小的函数,也可用于跨不同模块的更大、更复杂的功能。
Python中的单元测试
在我们看更多例子之前,我想快速回顾一下Python如何支持单元测试。
Python标准库中主要的测试框架/运行器是unittest。Unittest使用起来非常简单,只有两个要求:将测试放在一个类中,并使用其特殊的assert函数。一个简单的例子如下:
import unittest
class UnetTest(unittest.TestCase):
def test_normalize(self):
. . .
if __name__ == '__main__':
unittest.main()
需要注意的一些事项:
- 我们有一个测试类,其中包含一个作为方法的“testnormalize”函数。通常,测试函数以“test”作为前缀命名,后跟它们要测试的函数名。(这是一个约定,但也启用了unittest的自动发现功能,即库能够自动检测项目或模块中的所有单元测试,这样你就不必一个一个地运行它们。)
- 要运行单元测试,我们调用“unittest.main()”函数,该函数发现模块中的所有测试,运行它们并打印输出。
- 我们的UnetTest类继承“unittest.TestCase”类。这个类帮助我们设置具有不同输入的唯一测试用例,因为它带有“setUp()”和“tearDown()”方法。在setUp()中,我们可以定义可以被所有测试访问的输入,在tearDown()中我们可以释放它们(参见下一章的代码片段)。这很有用,因为所有测试都应该独立运行,通常它们不能共享信息。嗯,现在它们可以了。
另外两个强大的框架是pytest和nose,它们遵循几乎相同的原则。我建议在决定哪种最适合你之前稍微尝试一下。我个人大多数时候使用pytest,因为它感觉更简单一些,并且支持一些很好的功能,如fixtures和测试参数化(这里我不详细说明,你可以查看官方文档了解更多)。但老实说,差别不大,所以使用其中任何一个都可以。
TensorFlow中的测试:tf.test
但在这里我要讨论另一个不太为人所知的。由于我们使用TensorFlow编写模型,我们可以利用“tf.test”,它是unittest的扩展,但包含了针对TensorFlow代码定制的断言(是的,当我发现这一点时也很震惊)。在这种情况下,我们的代码变成了这样:
import tensorflow as tf
class UnetTest(tf.test.TestCase):
def setUp(self):
super(UnetTest, self).setUp()
. . .
def tearDown(self):
pass
def test_normalize(self):
. . .
if __name__ == '__main__':
tf.test.main()
它的基本规则完全相同,需要注意的是我们需要在“setUp”内部调用“super()”函数,这使“tf.test”能够施展它的魔法。很酷吧?
模拟(Mocking)
你应该了解的另一个超级重要的主题是模拟(Mocking)和模拟对象(mock objects)。模拟类和函数在编写Java时非常常见,但在Python中却很少被使用。模拟使得在测试代码时,使用虚拟对象替换复杂逻辑或繁重的依赖项变得非常容易。虚拟对象指的是结构与我们真实对象相同但包含虚假或无意义数据的简单易编码的对象。在我们的例子中,一个虚拟对象可能是一个全为1的2D张量,模拟一个实际的图像(就像第一个代码片段中的“input_image”)。
模拟还有助于我们控制代码的行为并模拟昂贵的调用。让我们再看一个使用我们的UNet的例子。
假设我们想确保数据预处理步骤是正确的,并且我们的代码按预期分割数据并创建训练和测试数据集(一个非常常见的测试用例)。这是我们想要测试的代码:
def load_data(self):
""" Loads and Preprocess data """
self.dataset, self.info = DataLoader().load_data(self.config.data)
self._preprocess_data()
def _preprocess_data(self):
""" Splits into training and test and set training parameters"""
train = self.dataset['train'].map(self._load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = self.dataset['test'].map(self._load_image_test)
self.train_dataset = train.cache().shuffle(self.buffer_size).batch(self.batch_size).repeat()
self.train_dataset = self.train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
self.test_dataset = test.batch(self.batch_size)
def _load_image_train(self, datapoint):
""" Loads and preprocess a single training image """
input_image = tf.image.resize(datapoint['image'], (self.image_size, self.image_size))
input_mask = tf.image.resize(datapoint['segmentation_mask'], (self.image_size, self.image_size))
if tf.random.uniform(()) > 0.5:
input_image = tf.image.flip_left_right(input_image)
input_mask = tf.image.flip_left_right(input_mask)
input_image, input_mask = self._normalize(input_image, input_mask)
return input_image, input_mask
def _load_image_test(self, datapoint):
""" Loads and preprocess a single test image"""
input_image = tf.image.resize(datapoint['image'], (self.image_size, self.image_size))
input_mask = tf.image.resize(datapoint['segmentation_mask'], (self.image_size, self.image_size))
input_image, input_mask = self._normalize(input_image, input_mask)
return input_image, input_mask
无需深入代码细节,它实际上所做的是数据分割、一些随机打乱、调整大小和批处理。所以我们想测试这段代码。一切都很好,除了那个该死的加载函数。
self.dataset, self.info = DataLoader().load_data(self.config.data)
难道我们每次运行一个单元测试都要加载整个数据集吗?绝对不是。因此,我们可以模拟那个函数,让它返回一个虚拟数据集,而不是调用真实的函数。模拟来拯救。
我们可以使用unittest的模拟对象包来完成这个。它提供了一个模拟类“Mock()”来直接创建模拟对象,以及一个“patch()”装饰器,用于在我们测试的模块内用一个模拟对象替换导入的模块。由于理解其区别并不简单,我将在末尾为那些想了解更多细节的人留下一篇精彩文章的链接。
对于那些不熟悉的人,装饰器只是一个包装另一个函数以扩展其功能的函数。一旦我们声明了包装函数,我们就可以注释其他函数来增强它们。看到下面的@patch了吗?这是一个用“patch”函数包装“test_load_data”的装饰器。更多信息请参阅文章末尾的链接。
通过使用“patch()”装饰器,我们得到:
@patch('model.unet.DataLoader.load_data')
def test_load_data(self, mock_data_loader):
mock_data_loader.side_effect = dummy_load_data
shape = tf.TensorShape([None, self.unet.image_size, self.unet.image_size, 3])
self.unet.load_data()
mock_data_loader.assert_called()
self.assertItemsEqual(self.unet.train_dataset.element_spec[0].shape, shape)
self.assertItemsEqual(self.unet.test_dataset.element_spec[0].shape, shape)
我能看出你对这个感到惊讶。别试图隐藏。
测试覆盖率
在我们看一些机器学习中特定的测试用例之前,我想提另一个重要方面:覆盖率。覆盖率是指我们的代码有多少实际上被单元测试覆盖了。
覆盖率是一个宝贵的指标,可以帮助我们编写更好的单元测试,发现测试未覆盖的领域,找到新的测试用例,并确保测试的质量。你可以像这样简单地检查你的覆盖率:
- 安装coverage包
$ conda install coverage - 在你的测试文件中运行该包
$ coverage run -m unittest /home/aisummer/PycharmProjects/Deep-Learning-Production-Course/model/tests/unet_test.py - 打印结果
$ coverage report -m /home/aisummer/PycharmProjects/Deep-Learning-Production-Course/model/tests/unet_test.py Name Stmts Miss Cover Missing ------------------------------------------------------------- model/tests/unet_test.py 35 1 97% 52
这表示我们覆盖了97%的代码。总共有35条语句,我们只漏掉了1条。缺失信息告诉我们哪些代码行还需要覆盖(多么方便!)。
测试示例场景
我认为是时候探索一些深度学习的不同场景和代码库中单元测试非常有用的部分了。好吧,我不会为每一个都编写代码,但我认为概述几个用例非常重要。
我们已经讨论过其中之一。确保我们的数据格式正确是至关重要的。我能想到的其他一些是:
数据
- 确保我们的数据格式正确(是的,我为了完整性再次把它放在这里)
- 确保训练标签正确
- 测试我们的复杂处理步骤,如图像操作
- 断言数据的完整性、质量和错误
- 测试特征的分布
训练
- 运行一个训练步骤,并比较权重前后,确保它们已更新
- 检查我们的损失函数是否确实可以用于我们的数据
评估:
- 在迭代不同架构时,使用测试确保你的指标(例如准确率、精确率和召回率)高于某个阈值
- 你可以在训练上运行速度/基准测试以捕捉可能的过拟合
- 当然,交叉验证可以以单元测试的形式进行
模型架构:
- 模型的层确实在堆叠
- 模型的输出形状正确
实际上,让我们编写最后一个来证明给你看有多简单:
def test_ouput_size(self):
shape = (1, self.unet.image_size, self.unet.image_size, 3)
image = tf.ones(shape)
self.unet.build()
self.assertEqual(self.unet.model.predict(image).shape, shape)
就是这样。定义预期的形状,构建一个虚拟输入,构建模型,并进行预测,这就是全部。对于一个如此有用的测试来说,还不错,对吧?你看单元测试不必很复杂。有时几行代码可以让我们免去很多麻烦。相信我。但另一方面,我们也不应该测试每一个可以想象到的东西。这是一个巨大的时间消耗。我们需要找到一个平衡点。
我相信,在开发你自己的模型时,你可以想出更多更多的测试场景。这只是为了大致了解你可以关注的不同领域。
集成/验收测试
我故意避免提及的是集成测试和验收测试。这些类型的测试是非常强大的工具,旨在测试我们的系统与其他系统的集成程度。如果你有一个包含许多服务或客户端/服务器交互的应用程序,验收测试是确保一切在更高层次上按预期工作的首选功能。
在课程后期,当我们将模型部署到服务器时,我们绝对需要编写一些验收测试,因为我们希望确保模型以用户/客户端期望的形式返回他们期望的结果。当我们的应用程序是实时的并为用户提供服务时,我们对其进行迭代,我们不能因为一些愚蠢的错误而导致失败(还记得第一篇文章中的可靠性原则吗?)验收测试帮助我们避免这类事情。
所以让我们暂时搁置它们,等到时机成熟时再处理。为了确保你会在本课程的下一部分发布时收到通知,你可以订阅我们的新闻通讯。
结论
单元测试确实是我们武器库中非常宝贵的工具,特别是在构建复杂的深度学习模型时。我的意思是,我能想到一百万件可能在机器学习应用程序上出错的事情。尽管编写好的测试可能很困难且耗时,但这是你不应该忽视的事情。我的懒惰已经不止一次地反噬我,所以我决定从现在开始总是编写它们。但同样,我们总是需要找到一个平衡点。
然而,单元测试只是使我们的代码生产就绪的方法之一。为了确保我们原始的笔记本代码可以在部署环境中可靠地使用,我们还必须做更多的事情。到目前为止,我们讨论了深度学习的系统设计和编写Python深度学习代码的最佳实践。我们清单上的下一步是为我们的代码库添加日志记录,并学习如何调试我们的TensorFlow代码。等不及了。
到时候见…
参考文献:
-
programiz.com,Python装饰器
-
wikipedia.org,代码覆盖率
-
toptal.com,Python模拟入门
-
realpython.com,Python测试入门
-
披露:请注意,以上部分链接可能是联盟链接,在你点击后决定购买时,我们将在不增加你额外成本的情况下获得佣金。
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)
90

被折叠的 条评论
为什么被折叠?



