import logging
import os
import time

import numpy as np
import requests
from flask import Flask, request, jsonify
from pymilvus import (
    AnnSearchRequest,
    Collection,
    Function,
    FunctionType,
    WeightedRanker,
    connections,
    utility,
)

from config.config1 import MILVUS_CONFIG, MODEL_CONFIG
# Initialise logging: every record goes both to logs/classfication_query.log
# and to a StreamHandler (stderr by default).
# FileHandler does not create missing directories, so make sure logs/
# (relative to the current working directory) exists first.
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        # Explicit UTF-8: query texts logged by this service may contain
        # non-ASCII characters, and the platform default encoding is not
        # guaranteed to be able to represent them.
        logging.FileHandler("logs/classfication_query.log", encoding="utf-8"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger("ClassficationQuery")
app = Flask(__name__)
class VectorServiceClient:
    """HTTP client for the model service that encodes texts into vectors.

    Dense and sparse encodings are requested from the same URL
    (``MODEL_CONFIG["model_service_url"]``); the ``"type"`` field of the JSON
    payload selects which kind of vector the service should return.
    """

    def __init__(self):
        self.service_url = MODEL_CONFIG["model_service_url"]
        # Per-request timeout, in seconds, for calls to the model service.
        self.timeout = 120
        logger.info(f"Using vector service: {self.service_url}")

    def batch_encode_dense(self, texts):
        """Encode ``texts`` into dense vectors, one per input text."""
        return self._call_vector_service(texts, "dense")

    def batch_encode_sparse(self, texts):
        """Encode ``texts`` into sparse vectors, one per input text."""
        return self._call_vector_service(texts, "sparse")

    @staticmethod
    def _error_detail(response):
        """Return a short description of a failed HTTP response.

        Prefers the ``"error"`` field of a JSON-object body; otherwise falls
        back to the first 500 characters of the raw body text.
        """
        fallback = response.text[:500]
        try:
            body = response.json()
        except ValueError:
            # Body is not valid JSON (requests' JSON decode errors are
            # ValueError subclasses).
            return fallback
        if isinstance(body, dict):
            return str(body.get("error", fallback))
        return fallback

    def _call_vector_service(self, texts, vector_type):
        """POST ``texts`` to the model service and return the encoded vectors.

        Args:
            texts: sequence of strings to encode. An empty/falsy value returns
                ``[]`` without any network call.
            vector_type: ``"dense"`` or ``"sparse"``; sent as ``"type"``.

        Returns:
            The ``"vectors"`` list from the service's JSON reply, with exactly
            one entry per input text.

        Raises:
            requests.exceptions.RequestException: transport failure or an
                HTTP status >= 400.
            ValueError: the service reported an error, or the reply is
                malformed (not JSON, no ``"vectors"``, wrong vector count).
        """
        if not texts:
            return []

        payload = {"texts": texts, "type": vector_type}
        try:
            response = requests.post(
                self.service_url,
                headers={"Content-Type": "application/json"},
                json=payload,
                timeout=self.timeout,
            )
            if response.status_code >= 400:
                error_detail = f"HTTP {response.status_code} Error: {self._error_detail(response)}"
                logger.error(f"{vector_type} service error: {error_detail}")
                response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Request to {vector_type} service failed: {str(e)}")
            raise
        except ValueError as e:
            # Non-JSON success body not already reported as a RequestException.
            logger.error(f"Encoding via {vector_type} service failed: {str(e)}")
            raise

        if not isinstance(result, dict):
            logger.error(f"Invalid response from {vector_type} service: vectors not found")
            raise ValueError(f"Invalid response from {vector_type} service")
        if "error" in result:
            logger.error(f"Vector service error ({vector_type}): {result['error']}")
            raise ValueError(result["error"])
        if "vectors" not in result:
            logger.error(f"Invalid response from {vector_type} service: vectors not found")
            raise ValueError(f"Invalid response from {vector_type} service")

        vectors = result["vectors"]
        # Each input text must get exactly one vector so results can be
        # matched back to inputs by position.
        if not isinstance(vectors, list) or len(vectors) != len(texts):
            logger.error(
                f"Invalid response from {vector_type} service: "
                f"expected {len(texts)} vectors"
            )
            raise ValueError(f"Invalid response from {vector_type} service")
        return vectors
class ClassficationSearchHandler:
    """Hybrid (dense + sparse) similarity search over material classification text.

    Construction connects to Milvus and loads the collection named by
    ``MILVUS_CONFIG["collection_name"]``; both steps log and re-raise on
    failure.
    """

    # Entity fields requested from Milvus and copied into every result dict.
    OUTPUT_FIELDS = ("matnr", "matkl", "maktx", "classfication")

    def __init__(self):
        self.vector_service = VectorServiceClient()
        self.collection = None
        self.connect()
        self.load_collection()

    def connect(self):
        """Connect to Milvus at ``MILVUS_CONFIG["host"]``/``["port"]``; re-raise on failure."""
        try:
            connections.connect(
                host=MILVUS_CONFIG["host"],
                port=MILVUS_CONFIG["port"]
            )
            logger.info(f"Connected to Milvus: {MILVUS_CONFIG['host']}")
        except Exception as e:
            logger.error(f"Milvus connection failed: {str(e)}")
            raise

    def load_collection(self):
        """Load the configured collection; raise ValueError if it does not exist."""
        collection_name = MILVUS_CONFIG["collection_name"]
        try:
            if not utility.has_collection(collection_name):
                logger.error(f"Collection {collection_name} does not exist")
                raise ValueError(f"Collection {collection_name} not found")
            self.collection = Collection(collection_name)
            self.collection.load()
            logger.info(f"Loaded collection: {collection_name}")
        except Exception as e:
            logger.error(f"Failed to load collection: {str(e)}")
            raise

    def hybrid_search(self, query_text, top_k=5, dense_weight=0.5, sparse_weight=0.5):
        """Run a dense + sparse hybrid search and fuse the two rankings.

        Args:
            query_text: classification text to search for.
            top_k: number of fused results to return.
            dense_weight: weight of the dense-vector ranking (WeightedRanker).
            sparse_weight: weight of the sparse-vector ranking (WeightedRanker).

        Returns:
            List of dicts with keys ``rank``, ``matnr``, ``matkl``, ``maktx``,
            ``classfication`` and ``score``, in the order Milvus returns the
            fused hits. Any encoding or search failure is logged and yields
            ``[]``.
        """
        start_time = time.time()

        # 1. Encode the query text.
        encoded = self._encode_query(query_text)
        if encoded is None:
            return []
        dense_vector, sparse_vector = encoded

        # 2-3. Build the per-field search requests and run the hybrid search.
        try:
            top_k = int(top_k)
            # Over-fetch candidates per field so the fusion step has more to rank.
            candidate_limit = top_k * 3
            # AnnSearchRequest has no per-request weight parameter; the
            # dense/sparse weights are given to WeightedRanker below, in the
            # same order as the request list.
            dense_search_req = AnnSearchRequest(
                data=[dense_vector],  # one query -> list holding one vector
                anns_field="classfication_dense_vector",
                param={"metric_type": "IP", "params": {"nprobe": 16}},
                limit=candidate_limit,
            )
            sparse_search_req = AnnSearchRequest(
                data=[sparse_vector],  # one query -> list holding one vector
                anns_field="classfication_sparse_vector",
                param={"metric_type": "IP", "params": {}},  # no nprobe for sparse
                limit=candidate_limit,
            )
            ranker = WeightedRanker(float(dense_weight), float(sparse_weight))

            logger.info("Executing hybrid search...")
            start_search = time.time()
            results = self.collection.hybrid_search(
                [dense_search_req, sparse_search_req],
                rerank=ranker,
                limit=top_k,  # final number of fused results
                output_fields=list(self.OUTPUT_FIELDS),
            )
            # The result holds one hit list per query vector; we sent one.
            hits = results[0] if len(results) > 0 else []
            search_time = time.time() - start_search
            logger.info(f"Hybrid search completed in {search_time:.2f}s, found {len(hits)} results")
        except Exception as e:
            logger.error(f"Hybrid search failed: {str(e)}", exc_info=True)
            return []

        # 4. Format the hits and log the first five.
        formatted_results = self._format_hits(hits)
        for item in formatted_results[:5]:
            logger.info(f"Result #{item['rank']}: MATNR={item['matnr']}, Score={item['score']:.4f}")

        total_time = time.time() - start_time
        logger.info(f"Total search time: {total_time:.2f}s")
        return formatted_results

    def _encode_query(self, query_text):
        """Encode ``query_text`` as ``(dense_vector, sparse_vector)``; log and return None on failure."""
        try:
            logger.info(f"Encoding query text: '{query_text}'")
            dense_vectors = self.vector_service.batch_encode_dense([query_text])
            if not dense_vectors:
                logger.error("Dense vector encoding returned empty result")
                return None
            dense_vector = dense_vectors[0]
            logger.info(f"Dense vector generated, length: {len(dense_vector)}")

            sparse_vectors = self.vector_service.batch_encode_sparse([query_text])
            if not sparse_vectors:
                logger.error("Sparse vector encoding returned empty result")
                return None
            sparse_vector = sparse_vectors[0]
            logger.info(f"Sparse vector generated, length: {len(sparse_vector)}")
        except Exception as e:
            logger.error(f"Vector encoding failed: {str(e)}")
            return None

        # Full vectors are large; only emit them at DEBUG, formatted lazily.
        logger.debug("dense_vector %s", dense_vector)
        logger.debug("sparse_vector %s", sparse_vector)
        return dense_vector, sparse_vector

    @classmethod
    def _format_hits(cls, hits):
        """Convert search hits into API result dicts, ranked from 1 in hit order.

        Missing (None) field values become ``""``; ``score`` is the hit's
        fused distance.
        """
        formatted = []
        for rank, hit in enumerate(hits, start=1):
            entity = hit.entity
            record = {"rank": rank}
            for field in cls.OUTPUT_FIELDS:
                # Single-argument get(): the form pymilvus hit entities document.
                value = entity.get(field)
                record[field] = "" if value is None else value
            record["score"] = hit.distance
            formatted.append(record)
        return formatted
# Shared search handler used by every request. Constructing it connects to
# Milvus and loads the collection; both steps re-raise on failure, so
# importing this module fails if Milvus or the collection is unavailable.
search_handler: ClassficationSearchHandler = ClassficationSearchHandler()
@app.route('/query_similar_by_classfication', methods=['POST'])
def query_similar_by_classfication():
    """Classification-text similarity query endpoint.

    Request JSON:
        query_text (str, required): text to search for; must be non-empty.
        top_k (int, optional, default 5): number of results; must be > 0.
        dense_weight / sparse_weight (number, optional, default 0.5 each):
            forwarded to ``search_handler.hybrid_search``.

    Responses:
        200 ``{"results": [...]}``. This is possibly empty: ``hybrid_search``
            returns ``[]`` both when nothing matches and when encoding or the
            search failed (those failures are only logged).
        400 ``{"error": ...}`` for a missing/non-JSON body, a missing or
            empty ``query_text``, or parameters of the wrong type.
        500 ``{"error": "Internal server error"}`` for anything unexpected.
    """
    try:
        # silent=True turns a missing/non-JSON body into None (-> 400 below)
        # instead of raising, which the generic handler would turn into a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "query_text" not in data:
            return jsonify({"error": "Missing 'query_text' parameter"}), 400

        query_text = data["query_text"]
        if not isinstance(query_text, str) or not query_text.strip():
            return jsonify({"error": "'query_text' must be a non-empty string"}), 400

        # Clients may send numbers as strings; normalise them so the search
        # receives real numbers.
        try:
            top_k = int(data.get("top_k", 5))
            dense_weight = float(data.get("dense_weight", 0.5))
            sparse_weight = float(data.get("sparse_weight", 0.5))
        except (TypeError, ValueError):
            return jsonify({"error": "'top_k' must be an integer and weights must be numbers"}), 400
        if top_k <= 0:
            return jsonify({"error": "'top_k' must be a positive integer"}), 400

        logger.info(f"New query: text='{query_text}', top_k={top_k}, "
                    f"dense_weight={dense_weight}, sparse_weight={sparse_weight}")

        results = search_handler.hybrid_search(
            query_text,
            top_k=top_k,
            dense_weight=dense_weight,
            sparse_weight=sparse_weight
        )
        if not results:
            logger.info("No results found")
        return jsonify({"results": results})
    except Exception as e:
        logger.error(f"API error: {str(e)}", exc_info=True)
        return jsonify({"error": "Internal server error"}), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Report whether Milvus and the model service look healthy.

    Returns ``{"status": "up"}`` with HTTP 200 when the collection is loaded
    and the model service's ``/health`` answers 200; otherwise
    ``{"status": "down", "reason": ...}`` with HTTP 500.
    """
    def _down(reason):
        return jsonify({"status": "down", "reason": reason}), 500

    try:
        # The handler only sets ``collection`` after a successful load.
        if not search_handler.collection:
            return _down("Milvus not connected")

        # Probe the model service with a short timeout.
        try:
            probe = requests.get(f"{MODEL_CONFIG['model_service_url']}/health", timeout=5)
            if probe.status_code != 200:
                return _down("Model service unavailable")
        except Exception as exc:
            return _down(f"Model service error: {str(exc)}")

        return jsonify({"status": "up"}), 200
    except Exception as exc:
        return _down(str(exc))
if __name__ == '__main__':
    # Single source for the port, so the startup log always reports the port
    # that app.run actually binds.
    # NOTE(review): 2379 is also etcd's default client port; if etcd runs on
    # this host the bind will fail. Confirm the intended port (the startup
    # message used to say 8081).
    port = 2379
    try:
        logger.info(f"Starting Classfication Query Service on 0.0.0.0:{port}")
        app.run(host='0.0.0.0', port=port, debug=False)
    except Exception as e:
        logger.error(f"Failed to start service: {str(e)}", exc_info=True)
# Reported error for this code:
# Traceback (most recent call last):
#   File "/data/material_sync/api/similarity_query1.py", line 173, in hybrid_search
#     dense_search_req = AnnSearchRequest(
# TypeError: __init__() got an unexpected keyword argument 'weight'
# 2025-07-07 16:32:53,300 - INFO - No results found