手把手教你使用pytorch+flask搭建草图检索系统(二)

手把手教你使用pytorch+flask搭建草图检索系统(二)



一、提要与预告

  • 准备工作 -> 链接
  • 后端搭建 -> 本篇内容
  • 前端搭建
  • 前后端交互
  • demo

上篇《手把手教你使用pytorch+flask搭建草图检索系统(一)》介绍了搭建系统的准备工作,以及本系统的检索原理,还有孪生网络SketchTriplet提取特征的过程,预训练好的模型、模型所使用的代码、数据集已经贴在了上篇节尾,大家可以去google drive或者百度云上下载。本篇将介绍检索系统的后端搭建过程,包括使用flask搭建服务器,以及使用pytorch载入模型。已经咕咕咕到了现在,一是学业二是工作,近期我会努力逐步更新的,在此向那些等了半年的同学们说声抱歉。

二. flask搭建服务器

网上教程极多,我看的是这个FLASK的官方中文文档,写得极其详细且易懂,从零到一应有尽有,为了快速上手,我这里就简单讲下我的理解,flask有点像MVC(Module, View, Controller)那一套,它拥有两个固定的文件夹:static, templatesstatic里面存放着静态网页所加载的资源(如css、image、json等),所以针对某个静态端点static所使用的图像img_0.png,它的路径应该是:

url_for('static', filename='img_0.png')

而它在文件系统的位置应该是static/img_0.png,我们如果要访问这张图像,应该在浏览器中键入127.0.0.1/img_0.png。静态端点中可以有多个网页,保存在templates中,因此这个templates就有点像MVC中的V,而static则是MVC中的M,根目录下的python文件则是MVC中的C。也就是说,在一个APP中,应该有如下的文件结构:

- base_folder
	- static
		- img
		- css
		...
	- templates
		- 0.html
		- 1.html
		...
	- controller.py

注意static, templates不一定绝对叫这个名字,在flask初始化的时候,你可以随意修改,修改方式如下:

app = Flask(__name__, template_folder='templates', static_folder='static')

2.1 Controller实现

首先,对于手绘检索系统,按照流程来看:首先得有个界面绘制草图,这是前端工作;然后将绘制的草图保存下来,上传到后端,这是前后端交互;接着,后端依据上传的草图进行检索,得到检索结果,这是后端工作;然后将检索结果返回给前端,着是前后端交互;最后将检索结果展示出来,这是前端工作。也就是说,我们需要两次前端、两次交互、一次后端。

按上面的思路,我们先给绘图界面留个坑,在根目录下新建controller.py,按下面的代码,新建一个flask路由,使服务器能够正常运行、访问:

from flask import Flask, render_template
from datetime import timedelta

# 新建APP
app = Flask(__name__, template_folder='templates', static_folder='static')

# 设置静态文件缓存过期时间
app.send_file_max_age_default = timedelta(seconds=1)

# 新建路由,为绘图界面留个页面
@app.route('/')
def hello():
    return render_template('canva.html')


if __name__ == '__main__':
    # app.debug = True
    app.run(debug=True)

接着在./templates文件夹内,新建canva.html静态网页,如下所示:

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <h1>HELLO WORLD! I'm canva</h1>
</head>
<body>
</body>
</html>

好,在这里运行下,试一试效果:直接在pycharm里面运行,或者你也可以python controller.py,按照输出* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)访问,就可以看到绘图界面了,如下图所示:

在这里插入图片描述
在这里插入图片描述

2.2 添加模型载入

继续在controller.py内部添加代码,加入upload函数,如下所示:

from flask      import Flask, render_template, request
from datetime   import timedelta
from scipy.misc import imsave
import json, os, time, base64

# 加载对应的包
from SketchTriplet.SketchTriplet_half_sharing import BranchNet
from SketchTriplet.SketchTriplet_half_sharing import SketchTriplet          as SketchTriplet_hs
from SketchTriplet.flickr15k_dataset          import flickr15k_dataset_lite
from SketchTriplet.retrieval                  import retrieval

# 定义载入函数
def load_model_retrieval():
	# 模型相对路径
    net_dict_path = '../SketchTriplet/500.pth'
    branch_net = BranchNet()  # for photography edge
    net = SketchTriplet_hs(branch_net)
    net.load_state_dict(torch.load(net_dict_path))
    net = net.cuda()
    net.eval()
    return net

#-----------------------------------------

# 加载flickr15k数据集
flickr15k_dataset = flickr15k_dataset_lite()

# 加载检索模型
retrieval_net = load_model_retrieval()

#-----------------------------------------

# 新建APP
app = Flask(__name__, template_folder='templates', static_folder='static')

# 设置静态文件缓存过期时间
app.send_file_max_age_default = timedelta(seconds=1)

# 新建路由,为绘图界面留个页面
@app.route('/')
def hello():
    return render_template('canva.html')

# 新建上传路由,因为有传输行为,所有添加POST, GET方法
@app.route('/upload', methods=['POST', 'GET'])
def upload():
    if request.method == 'POST':
    	# 获取上传的草图
        sketch_src = request.form.get("sketchUpload")
        # 获取上传成功的flag,要么是鼠标绘制,要么是本地上传
        upload_flag = request.form.get("uploadFlag")
        # 如果上传失败,返回上传页面
        sketch_src_2 = None
        if upload_flag:
            sketch_src_2 = request.files["uploadSketch"]
        if sketch_src:
            flag = 1
        elif sketch_src_2:
            flag = 2
        else:
            return render_template('upload.html')

        # 处理上传的草图
        basepath = os.path.dirname(__file__)
        upload_path = os.path.join(basepath, 'static/sketch_tmp', 'upload.png')
        if flag == 1:
            # 鼠标绘制
            sketch = base64.b64decode(sketch_src[22:])
            user_input = request.form.get("name")
            file = open(upload_path,"wb")
            file.write(sketch)
            file.close()

        elif flag == 2:
            # 本地上传
            sketch_src_2.save(upload_path)
            user_input = request.form.get("name")

        # 开始检索
        retrieval_list, real_path = retrieval(retrieval_net, upload_path, flickr15k_dataset)
        # 以json形式包装返回的路径
        real_path = json.dumps(real_path)
        # 检索成功后,将结果渲染出来
        return render_template('retrieval.html', userinput=user_input, val1=time.time(), upload_src=sketch_src, retrieval_list = retrieval_list, json_info = real_path)

    # 其它操作,返回上传页面
    return render_template('upload.html')

首先,在controller.py中,我们需要先将模型和数据集载入到系统内存中,这就是flickr15k_dataset_lite()load_model_retrieval()的工作。
接着,我希望在canva.html绘图界面中,有两种上传草图的方式,一是直接用鼠标手绘,二是通过本地上传草图,因此在代码中,有两个上传行为。
然后,我们需要将上传的图像进行保存,对于第一种鼠标手绘的方式,绘图界面会将绘制结果以Base64的形式编码,于是我使用了base64.b64decode对回传结果进行解码,然后保存到static/sketch_tmp/upload.png处,对于第二种本地上传的方式,这里我就偷懒直接拷贝本地了,这实际上是错误的。细看代码,这里就有前后端交互的行为了。
然后,后端拿到了上传的图像,使用retrieval进行检索,返回检索路径,这里我使用了json对路径进行包装。
最后,将包装好的检索结果、以及上传的手绘图像,一并交给retrieval.html检索结果页面进行渲染,将检索结果与输入展现在检索结果页面上。这就是整个后端的操作流程。

2.3 后端检索

controller.py中提到的retrieval方法,代码如下所示:

from PIL import Image
import numpy as np

def retrieval(net, sketch_path, dataset):
	# 用PIL打开输入图像
    sketch_src = Image.open(sketch_path).convert('RGB')
    # 提取手绘草图的特征
    feat_s = extract_feat_sketch(net, sketch_src)

    # 读入自然图像flickr15k的特征集
    feat_photo_path = '../SketchTriplet/feat.npz'
    feat_photo = np.load(feat_photo_path)
    # 解析文件内容
    feat_p     = feat_photo['feat']
    cls_name_p = feat_photo['cls_name']
    cls_num_p  = feat_photo['cls_num']
    path_p     = feat_photo['path']
    name_p     = feat_photo['name']

    # L2计算特征距离
    dist_l2 = np.sqrt(np.sum(np.square(feat_s - feat_p), 1))
    # 排序
    order   = np.argsort(dist_l2)
    # 按顺序返回相对路径
    order_path_p = path_p[order]

    # 返回绝对路径
    return get_real_path(order_path_p)

三. 总结

本篇讲解了如何使用flask进行后端搭建,简单说明了检索系统从上传手绘草图到计算检索结果的流程。以上就是本篇的全部内容,后续将会介绍前端的搭建过程,以及前后端的交互过程,demo展示,敬请期待。

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值