from transformers import BertTokenizer, BertModel
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import util
import time
def bert_base_cos_sim(sentence1, sentence2, top_k, bert_version='base'):
    """For each sentence in sentence1, return the top_k most similar sentences
    in sentence2 as (corpus_id, score) pairs, using BERT embeddings and cosine similarity."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    bert_path = '/code/brs-sug-app/bert_model/bert-base-chinese'
    padding = True
    if bert_version == 'small':
        bert_path = '/code/brs-sug-app/bert_model/chinese_roberta_L12H128'
        padding = 'max_length'
    # Note: loading the tokenizer and model on every call is expensive;
    # hoist these out if the function is called repeatedly.
    tokenizer = BertTokenizer.from_pretrained(bert_path)
    model = BertModel.from_pretrained(bert_path).to(device)
    model.eval()
    s = time.time()
    # Tokenize both sentence lists into padded tensor batches
    inputs_sentence1 = tokenizer(sentence1, return_tensors="pt", padding=padding, truncation=True)
    inputs_sentence2 = tokenizer(sentence2, return_tensors="pt", padding=padding, truncation=True)
    # Move the input tensors to the GPU if one is available
    inputs_sentence1 = {key: inputs_sentence1[key].to(device) for key in inputs_sentence1}
    inputs_sentence2 = {key: inputs_sentence2[key].to(device) for key in inputs_sentence2}
    # Encode with BERT; no gradients are needed for inference
    with torch.no_grad():
        outputs_sentence1 = model(**inputs_sentence1)
        outputs_sentence2 = model(**inputs_sentence2)
    # Pool each sentence into a single vector: masked mean of the token vectors.
    # The attention mask zeroes out [PAD] positions so that padding does not
    # skew the average; a plain mean over last_hidden_state would count the
    # [PAD] tokens too, which matters especially with padding='max_length'.
    mask1 = inputs_sentence1["attention_mask"].unsqueeze(-1).float()
    mask2 = inputs_sentence2["attention_mask"].unsqueeze(-1).float()
    sentence1_embedding = (outputs_sentence1.last_hidden_state * mask1).sum(dim=1) / mask1.sum(dim=1)
    sentence2_embedding = (outputs_sentence2.last_hidden_state * mask2).sum(dim=1) / mask2.sum(dim=1)
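    # (Pooling choice: the [CLS] vector, outputs.last_hidden_state[:, 0], is a
    # common alternative, but for vanilla BERT models mean pooling tends to
    # give better sentence-similarity scores.)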
    # Compute cosine similarity between every query and every corpus sentence,
    # keeping the top_k hits per query.
    # Alternative (sklearn/NumPy): build the full similarity matrix, then take
    # the top k per row with argpartition:
    # similarity = cosine_similarity(sentence1_embedding.cpu().numpy(), sentence2_embedding.cpu().numpy())
    # topk_indices = np.argpartition(-similarity, top_k, axis=1)[:, :top_k]  # indices of each row's top k
    # topk_data = [[sentence2[j] for j in i] for i in topk_indices]
    # topk_values = np.take_along_axis(similarity, topk_indices, axis=1)  # values at those indices
    hits = util.semantic_search(sentence1_embedding, sentence2_embedding, score_function=util.cos_sim, top_k=top_k)  # util.dot_score also works
    # Flatten the hits into one [(corpus_id, score), ...] list per query sentence
    result = []
    for i in range(len(hits)):
        res = []
        for h in hits[i]:
            corpus_id = h["corpus_id"]  # index into sentence2
            score = round(h["score"], 4)
            res.append((corpus_id, score))
        result.append(res)
    e = time.time()
    print(e - s)  # elapsed time (seconds) for encoding + search
    return result
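# For reference, util.semantic_search with util.cos_sim boils down to
# L2-normalizing both embedding matrices and taking a row-wise top-k of their
# dot products. A minimal plain-PyTorch sketch of that step (my own
# illustration; topk_cos_sim is not part of the original code):
def topk_cos_sim(query_emb, corpus_emb, top_k):
    query_emb = torch.nn.functional.normalize(query_emb, dim=1)
    corpus_emb = torch.nn.functional.normalize(corpus_emb, dim=1)
    scores = query_emb @ corpus_emb.T  # (n_query, n_corpus) cosine matrix
    values, indices = torch.topk(scores, k=min(top_k, corpus_emb.size(0)), dim=1)
    return [[(int(j), round(float(v), 4)) for j, v in zip(idx_row, val_row)]
            for idx_row, val_row in zip(indices, values)]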
sentence1 = ['跳舞', '吃饭', '找对象', '王者']
sentence2 = ['抖音', '快手', '外卖', '饿了么', '美团', '王者荣耀']
bert_version = 'small'
top_k = 3
res = bert_base_cos_sim(sentence1, sentence2, top_k, bert_version)
print(res)
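# res holds one list per query sentence, each containing top_k (corpus_id,
# cosine_score) tuples where corpus_id indexes into sentence2; for the query
# '王者', the top hit is most likely '王者荣耀' (index 5).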