手把手教你使用pytorch+flask搭建草图检索系统(二)
一、提要与预告
- 准备工作 -> 链接
- 后端搭建 -> 本篇内容
- 前端搭建
- 前后端交互
- demo
上篇《手把手教你使用pytorch+flask搭建草图检索系统(一)》介绍了搭建系统的准备工作,以及本系统的检索原理,还有孪生网络SketchTriplet提取特征的过程,预训练好的模型、模型所使用的代码、数据集已经贴在了上篇节尾,大家可以去google drive或者百度云上下载。本篇将介绍检索系统的后端搭建过程,包括使用flask搭建服务器,以及使用pytorch载入模型。已经咕咕咕到了现在,一是学业二是工作,近期我会努力逐步更新的,在此向那些等了半年的同学们说声抱歉。
二. flask搭建服务器
网上教程极多,我看的是这个FLASK的官方中文文档,写得极其详细且易懂,从零到一应有尽有,为了快速上手,我这里就简单讲下我的理解,flask有点像MVC(Module, View, Controller)那一套,它拥有两个固定的文件夹:static
, templates
,static
里面存放着静态网页所加载的资源(如css、image、json等),所以针对某个静态端点static
所使用的图像img_0.png
,它的路径应该是:
url_for('static', filename='img_0.png')
而它在文件系统的位置应该是static/img_0.png
,我们如果要访问这张图像,应该在浏览器中键入127.0.0.1/img_0.png
。静态端点中可以有多个网页,保存在templates
中,因此这个templates
就有点像MVC中的V,而static
则是MVC中的M,根目录下的python文件则是MVC中的C。也就是说,在一个APP中,应该有如下的文件结构:
- base_folder
- static
- img
- css
...
- templates
- 0.html
- 1.html
...
- controller.py
注意,static
, templates
不一定绝对叫这个名字,在flask初始化的时候,你可以随意修改,修改方式如下:
app = Flask(__name__, template_folder='templates', static_folder='static')
2.1 Controller实现
首先,对于手绘检索系统,按照流程来看:首先得有个界面绘制草图,这是前端工作;然后将绘制的草图保存下来,上传到后端,这是前后端交互;接着,后端依据上传的草图进行检索,得到检索结果,这是后端工作;然后将检索结果返回给前端,着是前后端交互;最后将检索结果展示出来,这是前端工作。也就是说,我们需要两次前端、两次交互、一次后端。
按上面的思路,我们先给绘图界面留个坑,在根目录下新建controller.py
,按下面的代码,新建一个flask路由,使服务器能够正常运行、访问:
from flask import Flask, render_template
from datetime import timedelta
# 新建APP
app = Flask(__name__, template_folder='templates', static_folder='static')
# 设置静态文件缓存过期时间
app.send_file_max_age_default = timedelta(seconds=1)
# 新建路由,为绘图界面留个页面
@app.route('/')
def hello():
return render_template('canva.html')
if __name__ == '__main__':
# app.debug = True
app.run(debug=True)
接着在./templates
文件夹内,新建canva.html
静态网页,如下所示:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<h1>HELLO WORLD! I'm canva</h1>
</head>
<body>
</body>
</html>
好,在这里运行下,试一试效果:直接在pycharm里面运行,或者你也可以python controller.py
,按照输出* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
访问,就可以看到绘图界面了,如下图所示:
2.2 添加模型载入
继续在controller.py
内部添加代码,加入upload
函数,如下所示:
from flask import Flask, render_template, request
from datetime import timedelta
from scipy.misc import imsave
import json, os, time, base64
# 加载对应的包
from SketchTriplet.SketchTriplet_half_sharing import BranchNet
from SketchTriplet.SketchTriplet_half_sharing import SketchTriplet as SketchTriplet_hs
from SketchTriplet.flickr15k_dataset import flickr15k_dataset_lite
from SketchTriplet.retrieval import retrieval
# 定义载入函数
def load_model_retrieval():
# 模型相对路径
net_dict_path = '../SketchTriplet/500.pth'
branch_net = BranchNet() # for photography edge
net = SketchTriplet_hs(branch_net)
net.load_state_dict(torch.load(net_dict_path))
net = net.cuda()
net.eval()
return net
#-----------------------------------------
# 加载flickr15k数据集
flickr15k_dataset = flickr15k_dataset_lite()
# 加载检索模型
retrieval_net = load_model_retrieval()
#-----------------------------------------
# 新建APP
app = Flask(__name__, template_folder='templates', static_folder='static')
# 设置静态文件缓存过期时间
app.send_file_max_age_default = timedelta(seconds=1)
# 新建路由,为绘图界面留个页面
@app.route('/')
def hello():
return render_template('canva.html')
# 新建上传路由,因为有传输行为,所有添加POST, GET方法
@app.route('/upload', methods=['POST', 'GET'])
def upload():
if request.method == 'POST':
# 获取上传的草图
sketch_src = request.form.get("sketchUpload")
# 获取上传成功的flag,要么是鼠标绘制,要么是本地上传
upload_flag = request.form.get("uploadFlag")
# 如果上传失败,返回上传页面
sketch_src_2 = None
if upload_flag:
sketch_src_2 = request.files["uploadSketch"]
if sketch_src:
flag = 1
elif sketch_src_2:
flag = 2
else:
return render_template('upload.html')
# 处理上传的草图
basepath = os.path.dirname(__file__)
upload_path = os.path.join(basepath, 'static/sketch_tmp', 'upload.png')
if flag == 1:
# 鼠标绘制
sketch = base64.b64decode(sketch_src[22:])
user_input = request.form.get("name")
file = open(upload_path,"wb")
file.write(sketch)
file.close()
elif flag == 2:
# 本地上传
sketch_src_2.save(upload_path)
user_input = request.form.get("name")
# 开始检索
retrieval_list, real_path = retrieval(retrieval_net, upload_path, flickr15k_dataset)
# 以json形式包装返回的路径
real_path = json.dumps(real_path)
# 检索成功后,将结果渲染出来
return render_template('retrieval.html', userinput=user_input, val1=time.time(), upload_src=sketch_src, retrieval_list = retrieval_list, json_info = real_path)
# 其它操作,返回上传页面
return render_template('upload.html')
首先,在controller.py
中,我们需要先将模型和数据集载入到系统内存中,这就是flickr15k_dataset_lite()
与load_model_retrieval()
的工作。
接着,我希望在canva.html
绘图界面中,有两种上传草图的方式,一是直接用鼠标手绘,二是通过本地上传草图,因此在代码中,有两个上传行为。
然后,我们需要将上传的图像进行保存,对于第一种鼠标手绘的方式,绘图界面会将绘制结果以Base64的形式编码,于是我使用了base64.b64decode
对回传结果进行解码,然后保存到static/sketch_tmp/upload.png
处,对于第二种本地上传的方式,这里我就偷懒直接拷贝本地了,这实际上是错误的。细看代码,这里就有前后端交互的行为了。
然后,后端拿到了上传的图像,使用retrieval
进行检索,返回检索路径,这里我使用了json对路径进行包装。
最后,将包装好的检索结果、以及上传的手绘图像,一并交给retrieval.html
检索结果页面进行渲染,将检索结果与输入展现在检索结果页面上。这就是整个后端的操作流程。
2.3 后端检索
controller.py
中提到的retrieval
方法,代码如下所示:
from PIL import Image
import numpy as np
def retrieval(net, sketch_path, dataset):
# 用PIL打开输入图像
sketch_src = Image.open(sketch_path).convert('RGB')
# 提取手绘草图的特征
feat_s = extract_feat_sketch(net, sketch_src)
# 读入自然图像flickr15k的特征集
feat_photo_path = '../SketchTriplet/feat.npz'
feat_photo = np.load(feat_photo_path)
# 解析文件内容
feat_p = feat_photo['feat']
cls_name_p = feat_photo['cls_name']
cls_num_p = feat_photo['cls_num']
path_p = feat_photo['path']
name_p = feat_photo['name']
# L2计算特征距离
dist_l2 = np.sqrt(np.sum(np.square(feat_s - feat_p), 1))
# 排序
order = np.argsort(dist_l2)
# 按顺序返回相对路径
order_path_p = path_p[order]
# 返回绝对路径
return get_real_path(order_path_p)
三. 总结
本篇讲解了如何使用flask进行后端搭建,简单说明了检索系统从上传手绘草图到计算检索结果的流程。以上就是本篇的全部内容,后续将会介绍前端的搭建过程,以及前后端的交互过程,demo展示,敬请期待。