检索式聊天机器人,客服系统
操作流程:
原始数据(原始的正确问题对数据):
question1, answer1
question2, answer2
question3, answer3
question4, answer4
question5, answer5
question6, answer6
question7, answer7
…
操作步骤如下 :
0. 使用question数据构建训练文本相似度度量的训练数据,并训练、部署模型。
1. 使用训练好的文本相似度度量模型,对所有的question提取对应的文本特征向量,然后将其保存到数据库中。----> 初始化过程,相当于将所有的问题向量保存
2. 对于当前用户所给的问题Q,使用训练好的文本相似度度量模型将问题Q转换为向量V,然后计算向量V和所有数据库中的问题向量之间的相似度,然后选择相似度最高的问答对(Q&A),如果该最高相似度是高于给定的阈值,那么认为当前问答对中的回复答案A就是当前用户所问问题Q的最佳回复;否则,回复一些提示信息或者采用其它方式产生回复结果。
其它方式:
a. 直接告诉用户,能不能将问题根据具体化,我不能回复现在,学习能力不够。
b. 使用其它模型产生回复结果。
代码如下
import sys
import requests
import pymysql
import numpy as np
import tensorflow as tf
from flask import Flask, request, jsonify
class DBConfig(object):
host = "127.0.0.1"
user = "root"
password = "root"
port = 3306
database = "chat_robot"
charset = "utf8"
class SearchChatroBotParser(object):
def __int__(self,text_embedding_url,config):
self, text_embedding_url=text_embedding_url
self.config = config
tf.logging.info("开始构建数据库连接对象.....")
self.conn = pymysql.connect(host=self.config.host, user=self.config.user,
password=self.config.password, database=self.config.database,
port=self.config.port, charset=self.config.charset,
autocommit=False)
self.select_sql = "SELECT answer,question_vectors FROM tb_question_answer"
self.insert_sql = "INSERT INTO tb_question_answer(question,answer,question_vectors) VALUES(%s,%s,%s)"
tf.logging.info("构建数据库连接对象完成!!!!")
def convert_embedding(self, question):
"""
通过url获取embedding的数据
:param question:
:return:
"""
data = {
"text": question
}
result = requests.post(self.text_embedding_url, data=data)
if result.status_code == 200:
result.encoding = 'utf-8'
result = result.json() # 转换为json格式
if result['code'] == 200:
embedding = result['data'][0]['embedding']
return embedding
else:
raise Exception("获取embedding http请求失败,请检查服务器!!!")
else:
raise Exception("获取embedding http请求失败,请检查服务器!!!")
def _internal_insert_question_and_answer(self, question, answer, embedding, cursor):
'''
插入Q@A问答对到数据库中
:param question: 问题字符串
:param answer: 问题对应答案的字符串
:param embedding: 问题字符串对应的向量
:param cursor: 数据库操作对应的游标
:return:
'''
cursor.execute(self.insert_sql(question, answer, ','.join(map(str,embedding))))
def insrt_question_and_answer(self, data_file, encoding='utf-8-sig'):
'''
将文件中所有问答对,全部保存到数据库中
:param data_file:
:param encoding:
:return:
'''
with open(data_file ,'r', encoding=encoding ) as reader:
with self.conn.cursor() as cursor:
count = 0
for line in reader:
line = line.split()
question , answer = line.split('\t')
#获取问题对应的向量
embedding = self.convert_embedding(question)
#数据填充到数据库中
self._internal_insert_question_and_answer(question, answer, embedding, cursor)
#累加数据
count +=1
if count % 100==0:
self.conn.commit()
print('已经插入100条数据了!!!')
self.conn.commit() #数据提交到数据库
print('所有问答对数据全部插入到数据库中!!!')
@classmethod
def _calc_similarity(cls, all_embeddings, embedding):
'''
计算embedding和所有all_embedding之间的相似度,并将以结果返回
:param allembedding: [N, 128]
:param embedding: [1, 128]
:return:
'''
assert np.shape(embedding)[0]==1 #'维度必须为1'
assert np.shape(embedding)[1]==np.shape(all_embeddings)[1] #'维度必须一致'
#相似度的实现方式1,遍历计算
def _similarity(x, y):
'''
计算x和y之间的夹角余弦相似度
:param x:
:param y:
:return:
'''
x = np.reshape(x, -1)
y = np.reshape(y, -1)
#分子
a = np.sum(x * y)
#分母
b = np.sqrt(np.sum(np.square(x))) * np.sqrt(np.sum(np.square(y)))
return 1.0 * a / b
similarity = []
for other_embedding in all_embeddings:
similarity.append(_similarity(other_embedding,embedding))
return np.asarray(similarity)
#相似度的计算方式2,直接numpy计算
#分子 [N,-1]
a = np.dot(all_embeddings,np.transpose(embedding))
# 分母
b1 = np.sqrt(np.sum(np.square(embedding)))#一个数字
b2 = np.sqrt(np.sum(np.square(all_embeddings),axis=1,keepdims=True)) #[N,1]
return np.reshape( 1.0 * a / (b1*b2), -1)
def fetch_answer(self, question, threshold=0.9):
'''
根据给定的问题,从数据库中获取相似度最高的问答对对应的回复答案(要求相似度大于等于阈值)
:param question: 问题字符串
:param threshold: 对应的阈值
:return:
'''
#1.将问题转换成向量 [128,] ---> [1,128]
question_embedding = np.reshape(self.convert_embedding(question=question), (1, -1))
# 2. 从数据库中加载所有数据,得到向量以及答案
with self.conn.cursor() as cursor:
cursor.execute(self.insert_sql)
embeddings = []
answers = []
for record in cursor.fetchall():
answers.append(record[0])
embeddings.append(list(map(float,record[1].split(','))))
embeddings = np.asarray(embeddings) #[N,128]
# 3. 计算之间的相似度, 得到一个相似度列表([1,128], [N,128]) --> [N,]
similarity = self._calc_similarity(all_embeddings=embeddings,embedding=question_embedding)
# 4. 获取相似度最大索引以及对应的相似度、回复
max_similarity_index = np.argmax(similarity, 0)
max_similarity = similarity[max_similarity_index]
answer = answers[max_similarity_index]
# 5. 根据最大相似度和阈值之间的关系,决定返回值
if max_similarity >= threshold:
return True, max_similarity, answer
else:
return False, max_similarity, None
if __name__=='__main__':
# 获取参数
port = 8082
text_embedding_url = "http://172.16.101.233:8088/fetch/embedding"
if len(sys.argv) > 1:
args = sys.argv[1:]
args_size = len(args)
idx = 0
while idx < args_size:
cur_arg = args[idx]
if cur_arg == '--port':
port = int(args[idx + 1])
elif cur_arg == '--url':
text_embedding_url = args[idx + 1].strip()
idx += 1
print("监听端口号为:{}".format(port))
print("获取Embedding词向量的URL路径为:{}".format(text_embedding_url))
parser = SearchChatBotParser(
text_embedding_url=text_embedding_url,
config=DBConfig() # TODO: 注意,需要给定自己的数据库的连接信息。
)
# 将所有数据插入到数据库中
# parser.insert_question_and_answer(data_file="./data/question_answer.txt")
# result = parser.fetch_answer(question="怎么更改花呗手机号码")
# print(result)
# 一、构建Flask的应用APP
app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False # 返回JSON格式的时候,数据就不会进行编码
@app.route("/")
@app.route("/index")
def index():
return "欢迎您进入检索式聊天机器人的服务端!!!"
@app.route("/fetch/answer", methods=['GET', 'POST'])
@app.route("/fetch/answer/<float:threshold>", methods=['GET', 'POST'])
def fetch_answer(threshold=0.9):
try:
# print("阈值为:{}".format(threshold))
# 1. 获取参数(待预测的文本数据); 如果没有传入,那么默认为None
if request.method == 'GET':
text = request.args.get("text", None)
else:
text = request.form.get("text", None)
# 2. 对text数据进行预测,得到预测结果
pred = parser.fetch_answer(text, threshold=threshold)
# 3. 结果数据处理并返回
if pred[0]:
result = {
'code': 200, # 一般是要给数字用于表示调用返回的结果情况
'msg': '成功!!!',
'data': [{
"answer": pred[2],
"similarity": float(pred[1])
}]
}
else:
result = {
'code': 301,
'msg': '数据库中没有匹配的问题,最高匹配问题的相似度为:{}'.format(pred[1])
}
# 4. 以json的格式返回
return jsonify(result)
except Exception as e:
return jsonify({
'code': 501,
'msg': '服务器出现异常,异常信息为:{}'.format(e)
})
# 二、启动Flask应用
app.run(host="0.0.0.0", port=port)在这里插入代码片
这里是应用到了相似度的计算,将相似问题转成向量,计算相似度,若相似度大于给定的阈值,则返回其Q对应的A。
涉及到将数据导入到数据库中,需要自己给定数据,我这里有自己的一些数据,对于数据的摘取也是一件麻烦事情,需要慢慢获取。
data包里面就是相应的问题答案数据。
你们可以自己对着写写,欢迎留言哦!