在本篇文章中,我们将使用推理服务器,并通过一个问答(Question Answering
)NLP 任务观察实时推理的过程。
Triton Inference Server
通过HTTP
和gRPC
接口提供服务。Triton没有使用自己的协议标准,而是遵循了KFServing Predict v2 协议,这样可以与更多当前和未来的服务工具兼容。Triton提供了多种客户端 API:Python API、C++ API和Protobuf API,在本示例中,我们学习如何通过Python API使用问答模型服务。
1 API 基础
首先,导入所需的库:
# 导入必要的模块
import os
import json
import argparse
import numpy as np
import tritonhttpclient
接下来,初始化客户端并连接到服务器:
# 初始化客户端连接 Triton 服务器
try:
triton_client = tritonhttpclient.InferenceServerClient(url="triton:8000", verbose=True)
except Exception as e:
print("channel creation failed: " + str(e))
检查服务器状态和模型是否就绪:
# 检查服务器和模型状态
modelName = "bertQA-torchscript"
print(triton_client.is_server_live())
print(triton_client.is_server_ready())
print(triton_client.is_model_ready(modelName,"1"))
查看服务器的元数据信息:
# 查看服务器元数据
triton_client.get_server_metadata()
除了基础的健康检查与模型推理,API 还支持更细粒度的控制,例如加载/卸载模型。详情请参考官方文档和示例代码。
2 推理 API 总览
由于我们使用的是问答模型,我们将向服务器发送一个示例问题。首先查看服务器期望的输入和输出格式:
# 获取模型元数据,查看输入输出张量
triton_client.get_model_metadata(modelName)
你应该会看到如下内容:
服务器期望接收三个输入张量,分别是 input0(input_ids
)、input_1(sequence_ids
)和 input_2(mask_ids
)。推理完成后,服务器将返回两个输出张量:output0表示起始位置的logits
,output__1 表示结束位置的logits
。
logits
是模型预测结果的原始得分,用来决定哪个位置或哪个类别最有可能是正确答案。
3 构造推理请求
首先,定义问题和上下文:
# 定义问题和上下文段落
question = "Most antibiotics target bacteria and don't affect what class of organisms? "
context = "Within the genitourinary and gastrointestinal tracts, commensal flora serve as biological barriers by " +\
"competing with pathogenic bacteria for food and space and, in some cases, by changing the conditions in " +\
"their environment, such as pH or available iron. This reduces the probability that pathogens will " +\
"reach sufficient numbers to cause illness. However, since most antibiotics non-specifically target bacteria" +\
"and do not affect fungi, oral antibiotics can lead to an overgrowth of fungi and cause conditions such as a" +\
"vaginal candidiasis (a yeast infection). There is good evidence that re-introduction of probiotic flora, such " +\
"as pure cultures of the lactobacilli normally found in unpasteurized yogurt, helps restore a healthy balance of" +\
"microbial populations in intestinal infections in children and encouraging preliminary data in studies on bacterial " +\
"gastroenteritis, inflammatory bowel diseases, urinary tract infection and post-surgical infections. "
然后引入辅助模块进行预处理:
# 添加路径并导入处理工具
import sys
sys.path.insert(0,'/dli/task/client')
from tokenization import BertTokenizer
from inference import preprocess_tokenized_text,parse_answer
使用 tokenizer 对文本进行编码:
# 创建分词器并进行预处理
tokenizer = BertTokenizer("/dli/task/vocab", do_lower_case=True, max_len=512)
doc_tokens = context.split()
query_tokens = tokenizer.tokenize(question)
tensors_for_inference, tokens_for_postprocessing = preprocess_tokenized_text(doc_tokens,
query_tokens,
tokenizer,
max_seq_length=384,
max_query_length=64)
# 构建模型输入张量 (batch size = 1)
dtype = np.int64
input_ids = np.array(tensors_for_inference.input_ids, dtype=dtype)[None,...]
segment_ids = np.array(tensors_for_inference.segment_ids, dtype=dtype)[None,...]
input_mask = np.array(tensors_for_inference.input_mask, dtype=dtype)[None,...]
封装成 Triton 所需的数据格式:
# 构建 Triton 输入对象
inputs = []
inputs.append(tritonhttpclient.InferInput('input__0', [1, len(input_ids[0])], "INT64"))
inputs.append(tritonhttpclient.InferInput('input__1', [1, len(segment_ids[0])], "INT64"))
inputs.append(tritonhttpclient.InferInput('input__2', [1, len(input_mask[0])], "INT64"))
# 将 numpy 数据加载进 InferInput 对象中
inputs[0].set_data_from_numpy(input_ids, binary_data=False)
inputs[1].set_data_from_numpy(segment_ids, binary_data=False)
inputs[2].set_data_from_numpy(input_mask, binary_data=False)
可选地查看某个输入张量:
# 查看张量数据
inputs[0]._get_tensor()
构造输出请求,只取必要字段:
# 指定要获取的输出张量
outputs = []
outputs.append(tritonhttpclient.InferRequestedOutput('output__0', binary_data=False))
outputs.append(tritonhttpclient.InferRequestedOutput('output__1', binary_data=False))
4 发送请求到服务器
将构造好的请求发送到 Triton:
# 发送推理请求到服务器
results = triton_client.infer(modelName,
inputs,
outputs=outputs)
查看结果对象类型:
# 打印结果和输出结构
results
outputs
5 处理服务器响应
提取输出张量并进一步处理:
# 将输出转换为 numpy 格式以便后处理
output0_data = results.as_numpy('output__0')
output1_data = results.as_numpy('output__1')
查看结果:
# 查看起始 logits
output0_data
处理为可读答案:
# 解析 logits 得到最终文本答案
start_logits = output0_data[0].tolist()
end_logits = output1_data[0].tolist()
answer, answers = parse_answer(doc_tokens, tokens_for_postprocessing,
start_logits, end_logits)
# 打印最终结果
print()
print(answer)
print()
print(json.dumps(answers, indent=4))
6 总结
本节内容展示了如何通过Triton Inference Server
使用问答模型进行实时推理。我们首先介绍了Triton的API接口结构,并通过Python客户端完成了从问题构造、文本预处理、构造模型输入、发送请求到服务器、接收输出结果,再到解析logits
并获取最终答案的完整推理流程。通过这一实践,我们掌握了如何将一个深度学习模型以服务形式部署并进行实时调用。