Torch、Java、Milvus快速搭建以图搜图系统
1 原理概述
以图搜图大致原理(口水话版)
以图搜图,即通过一张图片去匹配数据库中的图片,找到最相似的N张图。在我们普通的搜索系统中,文字匹配的搜索单纯的MySQL数据库就能实现简单的搜索,但是图片就存在很多难点。
1、首先要解决的是图片怎么表达的问题,肯定不会是每个像素点去匹配,而是对图像提取特征。在传统的数字图像处理中,图像的特征有很多:颜色特征、纹理特征、关键点特征、几何特征,可以将具有代表性的特征提取处理归一化后形成一个多维向量去表示图片。在深度学习如火如荼的时代,卷积神经网络能更好的做到特征提取这个工作。
2、特征提取到了,自然而然的就是将每个图片的特征(即一个向量)存入数据库,要搜索一张图片时就去数据库匹配。第二个问题就是如何去匹配图片,两个向量相等?当然不是。我们用距离来表达两个向量的相似程度,距离越近就越相似。距离用得最多的就是欧式距离和余弦距离(简单来说区别就是欧氏距离体现数值上的差异、余弦距离体现方向上的相对差异)。
3、怎么判断两个图片是否相似解决了,通过距离!第三个问题:来一张图时去数据库查询怎么查?一个一个匹配,最后排个序?当然不是!MySQL可以建索引,这个好像建索引也无从下手。这里就需要借助向量搜索引擎了。目前开源的向量搜索引擎还是有很多的,这里采用Milvus这个开源项目实现向量搜索引擎,详细了解的去自行百度。
2、ResNet提取深度特征向量
环境:Pytorch1.1 python3.6 cuda9.0 采用pretrainedmodels库快速搭建ResNet(pip安装即可)
几行代码搭建出一个特征提取网络
from torch.autograd import Variable
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pretrainedmodels
from PIL import Image
TARGET_IMG_SIZE = 224
img_to_tensor = transforms.ToTensor()
def get_seresnet50():
encoder = pretrainedmodels.se_resnet50()
model = nn.Sequential(encoder.layer0,
encoder.layer1,
encoder.layer2,
encoder.layer3,
encoder.layer4,
encoder.avg_pool # 平均池化,张成一个[batchSize,2048]的特征向量
)
for param in model.parameters():
param.requires_grad = False
model.cuda() # 使用GPU,CPU版去掉
model.eval()
return model
# 特征提取
def extract_feature(model, imgpath):
img = Image.open(imgpath) # 读取图片
img = img.resize((TARGET_IMG_SIZE, TARGET_IMG_SIZE))
tensor = img_to_tensor(img) # 将图片矩阵转化成tensor
tensor = tensor.cuda() # GPU
tensor = torch.unsqueeze(tensor, 0)
result = model(Variable(tensor))
result_npy = result.data.cpu().numpy()[0].ravel().tolist()
return result_npy
利用serverSocket搭建服务器端,Java端通信,调用python的特征提取。
import socket
import threading
import json
from model import extract_feature, get_seresnet50
def main():
# 创建服务器套接字
serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# 获取本地主机名称
host = socket.gethostname()
# 设置一个端口
port = 12345
# 将套接字与本地主机和端口绑定
serversocket.bind((host, port))
# 设置监听最大连接数
serversocket.listen(10)
# 模型创建
model = get_seresnet50()
print("等待连接")
while True:
# 获取一个客户端连接
clientsocket, addr = serversocket.accept()
print("连接地址:%s" % str(addr))
try:
t = ServerThreading(model, clientsocket) # 为每一个请求开启一个处理线程
t.start()
except Exception as identifier:
print(identifier)
pass
serversocket.close()
pass
class ServerThreading(threading.Thread):
def