生产环境部署:PyTorch 模型服务与导出
1. 初始服务器实现
我们首先有一个简单的服务器实现,用于加载模型并进行推理。以下是相关代码:
import sys
import torch
import numpy as np
from flask import Flask, request, jsonify
app = Flask(__name__)
model = ... # 这里应该是模型的定义
model.load_state_dict(torch.load(sys.argv[1], map_location='cpu')['model_state'])
model.eval()
def run_inference(in_tensor):
with torch.no_grad():
# LunaModel takes a batch and outputs a tuple (scores, probs)
out_tensor = model(in_tensor.unsqueeze(0))[1].squeeze(0)
probs = out_tensor.tolist()
out = {'prob_malignant': probs[1]}
return out
@app.route("/predict", methods=["POST"])
def predict():
meta = json.load(request.files['meta'])
blob = request.files['blob'].read()
超级会员免费看
订阅专栏 解锁全文
587

被折叠的 条评论
为什么被折叠?



