如何测试你的机器学习项目?初学者指南

原文:towardsdatascience.com/how-should-you-test-your-machine-learning-project-a-beginners-guide-2e22da5a9bfc

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/09f7be0e7b67a521f8bbbba62e10844f.png

代码测试,图片由作者提供

引言

测试是软件开发的一个关键组成部分,但根据我的经验,在机器学习项目中却被广泛忽视。很多人都知道他们应该测试代码,但知道如何做并真正去做的人却不多。

本指南旨在向您介绍测试机器学习管道各个部分的基本知识。我们将重点关注对 IMDb 数据集进行文本分类的 BERT 的微调,并使用行业标准库如pytestpytest-cov进行测试。

我强烈建议您遵循这个GitHub 仓库上的代码:

GitHub – FrancoisPorcher/awesome-ai-tutorials:让你成为 AI 教程大师的最佳集合

项目概述

这里是项目的简要概述。

bert-text-classification/
├── src/
│   ├── data_loader.py
│   ├── evaluation.py
│   ├── main.py
│   ├── trainer.py
│   └── utils.py
├── tests/
│   ├── conftest.py
│   ├── test_data_loader.py
│   ├── test_evaluation.py
│   ├── test_main.py
│   ├── test_trainer.py
│   └── test_utils.py
├── models/
│   └── imdb_bert_finetuned.pth
├── environment.yml
├── requirements.txt
├── README.md
└── setup.py

一种常见的做法是将代码分成几个部分:

  • src: 包含我们用来加载数据集、训练和评估模型的主体文件。

  • tests: 它包含不同的 Python 脚本。大多数情况下,每个脚本都有一个测试文件。我个人使用以下约定:如果您想测试的脚本名为XXX.py,则相应的测试脚本名为test_XXX.py,并位于tests文件夹中。

例如,如果您想测试evaluation.py文件,我使用的是test_evaluation.py文件。

NB:在测试文件夹中,您会注意到一个conftest.py文件。这个文件并不是按照常规说法进行测试函数,但它包含一些关于测试的配置信息,特别是fixtures,我们将在稍后解释。

如何开始

您可以阅读这篇文章,但我强烈建议您克隆仓库并开始与代码互动,因为我们通过积极参与学习得更好。为此,您需要克隆 GitHub 仓库,创建一个环境,并获取一个模型。

# clone github repo
git clone https://github.com/FrancoisPorcher/awesome-ai-tutorials/tree/main

# enter corresponding folder
cd MLOps/how_to_test/

# create environment
conda env create -f environment.yml
conda activate how_to_test

您还需要一个模型来运行评估。为了重现我的结果,您可以运行主文件。训练时间应在 2 到 20 分钟之间(取决于您是否有 CUDA、MPS 或 CPU)。

python src/main.py

如果您不想微调 BERT(但我强烈建议您自己微调 BERT),您可以使用 BERT 的库存版本,并添加一个线性层以获得 2 个类别,以下命令:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=2
        )

现在您已经准备就绪!

让我们编写一些测试:

但首先,让我们快速介绍一下 Pytest。

什么是 Pytest 以及如何使用它?

pytest 是一个行业标准的成熟测试框架,它使得编写测试变得容易。

pytest 的一个很棒之处在于,你可以以不同的粒度级别进行测试:单个函数、脚本或整个项目。让我们学习如何进行这三种选项。

测试看起来是什么样子?

测试是一个测试其他函数行为的函数。惯例是,如果你想测试名为 foo 的函数,你将调用你的测试函数 test_foo

然后,我们定义了几个测试,以检查我们正在测试的函数是否按预期工作。

让我们用一个例子来澄清这些想法:

data_loader.py 脚本中,我们使用了一个非常标准的函数 clean_text,该函数用于删除大写字母和空白字符,定义如下:

def clean_text(text: str) -> str:
    """
    Clean the input text by converting it to lowercase and stripping whitespace.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
    """
    return text.lower().strip()

我们想确保这个函数表现良好,所以我们在 test_data_loader.py 文件中可以写一个名为 test_clean_text 的函数。

from src.data_loader import clean_text

def test_clean_text():
    # test capital letters
    assert clean_text("HeLlo, WoRlD!") == "hello, world!" 
    # test spaces removed
    assert clean_text("  Spaces  ") == "spaces"
    # test empty string
    assert clean_text("") == ""

注意,我们在这里使用了 assert 函数。如果断言为 True,则不会发生任何操作;如果为 False,则会引发 AssertionError

现在让我们调用测试。在终端中运行以下命令。

pytest tests/test_data_loader.py::test_clean_text

这个终端命令意味着你正在使用 pytest 运行测试,最具体的是位于 tests 文件夹中的 test_data_loader.py 脚本,并且你只想运行一个名为 test_clean_text 的测试。

如果测试通过,你应该得到以下结果:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/688cda3d0e1e3fc6eb987871cd4ffeb5.png

Pytest 测试通过,图片由作者提供

如果测试未通过会发生什么?

为了这个例子,让我们假设我修改了 test_clean_text 函数如下:

def clean_text(text: str) -> str:
    # return text.lower().strip()
    return text.lower()

现在这个函数不再删除空格,将会在测试中失败。这是再次运行测试时我们得到的结果:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/664287db51c039544034f2382094642e.png

失败的测试示例,图片由作者提供

这次我们知道为什么测试失败了。太棒了!

为什么我们甚至想要测试单个函数?

好吧,测试可能需要花费很多时间。对于这样一个小型项目,评估整个 IMDb 数据集可能已经需要几分钟。有时我们只想测试单个行为,而不必每次都重新测试整个代码库。

现在,让我们转到下一个粒度级别:测试脚本。

如何测试整个脚本?

现在让我们使 data_loader.py 脚本更复杂,并添加一个 tokenize_text 函数,该函数接受一个 字符串 或一个 字符串列表 作为输入,并输出输入的标记化版本。

# src/data_loader.py
import torch
from transformers import BertTokenizer

def clean_text(text: str) -> str:
    """
    Clean the input text by converting it to lowercase and stripping whitespace.

    Args:
        text (str): The text to clean.

    Returns:
        str: The cleaned text.
    """
    return text.lower().strip()

def tokenize_text(
    text: str, tokenizer: BertTokenizer, max_length: int
) -> Dict[str, torch.Tensor]:
    """
    Tokenize a single text using the BERT tokenizer.

    Args:
        text (str): The text to tokenize.
        tokenizer (BertTokenizer): The tokenizer to use.
        max_length (int): The maximum length of the tokenized sequence.

    Returns:
        Dict[str, torch.Tensor]: A dictionary containing the tokenized data.
    """
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

为了让你更好地理解这个函数的功能,让我们用一个例子来试试:

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
txt = ["Hello, @! World! qwefqwef"]
tokenize_text(txt, tokenizer=tokenizer, max_length=16)

这将输出以下结果:

{'input_ids': tensor([[ 101, 7592, 1010, 1030,  999, 2088,  999, 1053, 8545, 2546, 4160, 8545,2546,  102,    0,    0]]), 
'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 
'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])}
  • max_length:是序列可以拥有的最大长度。在这种情况下,我们选择了 16,但我们可以看到序列的长度是 14,因此我们可以看到最后两个标记被填充了。

  • input_ids:每个标记被转换为其关联的 id,这些 id 是词汇表中的单词。NB:标记 101 是CLS标记,标记 id 102 是SEP标记。这两个标记标志着句子的开始和结束。阅读《注意力就是你需要的一切》论文以获取更多详细信息。

  • token_type_ids:这并不很重要。如果你输入两个序列作为输入,第二个句子将只有一个值。

  • attention_mask:这告诉模型在自注意力机制中需要关注哪些标记。因为句子是填充的,所以注意力机制不需要关注最后两个标记,所以那里是 0。

现在让我们编写我们的test_tokenize_text函数,该函数将检查tokenize_text函数是否表现正常:

def test_tokenize_text():
    """
    Test the tokenize_text function to ensure it correctly tokenizes text using BERT tokenizer.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Example input texts
    txt = ["Hello, @! World!", 
           "Spaces    "]

    # Tokenize the text
    max_length = 128
    res = tokenize_text(text=txt, tokenizer=tokenizer, max_length=max_length)

    # let's test that the output is a dictionary and that the keys are correct
    assert all(key in res for key in ["input_ids", "token_type_ids", "attention_mask"]), "Missing keys in the output dictionary."

    # let's check the dimensions of the output tensors
    assert res["input_ids"].shape[0] == len(txt), "Incorrect number of input_ids."
    assert res['input_ids'].shape[1] == max_length, "Incorrect number of tokens."

    # let's check that all the associated tensors are pytorch tensors
    assert all(isinstance(res[key], torch.Tensor) for key in res), "Not all values are PyTorch tensors."

现在让我们运行test_data_loader.py文件的完整测试,该文件现在有两个函数:

  • test_tokenize_text

  • test_clean_text

你可以从终端使用以下命令运行完整测试

pytest tests/test_data_loader.py

你应该得到以下结果:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/7591b734eb43a048dcb57ece87bc2a4a.png

test_data_loader.py脚本的测试成功,图片由作者提供

恭喜!你现在知道如何测试整个脚本了。让我们继续到最后一步,测试整个代码库。

如何测试整个代码库?

继续同样的推理,我们可以为每个脚本编写其他测试,你应该有一个类似的结构:

├── tests/
│   ├── conftest.py
│   ├── test_data_loader.py
│   ├── test_evaluation.py
│   ├── test_main.py
│   ├── test_trainer.py
│   └── test_utils.py

现在请注意,在这些所有测试函数中,一些变量是恒定的。例如,我们使用的tokenizer在所有脚本中都是相同的。Pytest有一个很好的方式通过Fixtures来处理这个问题。

固定装置是在运行测试之前设置某些上下文或状态以及之后清理的一种方式。它们提供了一个管理测试依赖项并将可重用代码注入测试的机制。

固定装置是通过使用@pytest.fixture装饰器来定义的。

分词器是我们可以使用的良好固定装置的例子。为此,让我们将其添加到位于tests文件夹中的conftest.py文件中:

import pytest
from transformers import BertTokenizer

@pytest.fixture()
def bert_tokenizer():
    """Fixture to initialize the BERT tokenizer."""
    return BertTokenizer.from_pretrained("bert-base-uncased")

现在在test_data_loader.py文件中,我们可以在test_tokenize_text的参数中调用固定装置bert_tokenizer

def test_tokenize_text(bert_tokenizer):
    """
    Test the tokenize_text function to ensure it correctly tokenizes text using BERT tokenizer.
    """
    tokenizer = bert_tokenizer

    # Example input texts
    txt = ["Hello, @! World!", 
           "Spaces    "]

    # Tokenize the text
    max_length = 128
    res = tokenize_text(text=txt, tokenizer=tokenizer, max_length=max_length)

    # let's test that the output is a dictionary and that the keys are correct
    assert all(key in res for key in ["input_ids", "token_type_ids", "attention_mask"]), "Missing keys in the output dictionary."

    # let's check the dimensions of the output tensors
    assert res["input_ids"].shape[0] == len(txt), "Incorrect number of input_ids."
    assert res['input_ids'].shape[1] == max_length, "Incorrect number of tokens."

    # let's check that all the associated tensors are pytorch tensors
    assert all(isinstance(res[key], torch.Tensor) for key in res), "Not all values are PyTorch tensors."

固定装置是一个非常强大且多功能的工具。如果你想了解更多关于它们的信息,官方的文档是你的首选资源。但至少现在,你已经有工具在手,可以覆盖大多数机器学习测试。

让我们从终端使用以下命令运行整个代码库:

pytest tests

你应该得到以下消息:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b9c688bbf259a52fc097c376b8a43501.png

使用 Pytest 测试整个代码库,图片由作者提供

恭喜!

如何使用 Pytest-cov 测量测试覆盖率?

在前面的章节中,我们学习了如何测试代码。在大型项目中,测量测试的覆盖率非常重要。换句话说,你的代码有多少是被测试的。

pytest-covpytest的一个插件,它生成测试覆盖率报告。

话虽如此,不要被覆盖率百分比所迷惑。并不是因为您有 100%的覆盖率,您的代码就没有 bug。它只是您用来识别代码哪些部分需要更多测试的工具。

您可以通过以下命令从终端生成覆盖率报告:

pytest --cov=src --cov-report=html tests/

您应该得到这个:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/94622d48080822312058745ebbbb267d.png

Coverage with pytest-cov,图片由作者提供

让我们看看如何阅读它:

  1. 语句:代码中可执行语句的总数。它计算所有可以执行的代码行,包括条件、循环和函数调用。

  2. 缺失:这表示在测试运行期间未执行的语句数。这些是没有被任何测试覆盖的代码行。

  3. 覆盖率:测试期间执行的总语句的百分比。它是通过将执行语句数除以总语句数来计算的。

  4. 排除:这指的是被明确排除在覆盖率测量之外的代码行。这对于忽略与测试覆盖率不相关的代码很有用,例如调试语句。

我们可以看到main.py文件的覆盖率是 0%,这是正常的,我们没有编写test_main.py文件。

我们还可以看到只有 19%的evaluation代码被测试,这给我们一个关于我们应该首先关注哪里的想法。

恭喜,您已经做到了!

感谢阅读!在您离开之前:

想要更多精彩教程,请查看我在 GitHub 上的AI 教程汇编

GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…

您应该在我的收件箱中收到我的文章。**在此订阅。*

如果您想访问 Medium 上的优质文章,您只需每月支付 5 美元的会员费。如果您通过我的链接*注册**,您只需支付部分费用,无需额外费用。*


如果您觉得这篇文章有见地且有益,请考虑关注我并为我点赞,以获取更多深入的内容!您的支持帮助我继续创作有助于我们共同理解的内容。

参考文献

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值