import time
from flask import Flask, request, render_template, jsonify
from typing import List, Dict, Optional
from SPARQLWrapper import SPARQLWrapper, JSON
import json
import copy
import os
import logging
from dotenv import load_dotenv
# Neo4j 驱动
from neo4j import GraphDatabase
# Load environment variables from the .env file
load_dotenv()
# Create the Flask application instance
app = Flask(__name__)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Endpoint health cache: uri -> { status: bool, timestamp: float from time.time() }
endpoint_health_cache = {}
CACHE_TTL = 300 # cache lifetime in seconds (5 minutes)
# Neo4j connection settings, read from the environment (None when unset)
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
# NOTE(review): this module-level driver is not referenced by the code visible
# in this file (Neo4jService opens its own driver) — confirm whether it is
# still needed before removing it.
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
# One step (edge) of a knowledge-graph path
class Relation:
    """A single hop of a path.

    Holds the two node identifiers joined by the hop, the relation that joins
    them, and the SPARQL endpoint the hop is associated with.
    """

    # Attribute names in the order they are exported by ``to_dict``.
    _FIELDS = ("start", "end", "relation", "sparql_uri")

    def __init__(self, start: str, end: str, relation: str, sparql_uri: str):
        self.start, self.end = start, end
        self.relation, self.sparql_uri = relation, sparql_uri

    def to_dict(self):
        """Return the four attributes as a plain, JSON-serialisable dict."""
        return {field: getattr(self, field) for field in self._FIELDS}

    def __str__(self):
        text = f"Start: {self.start}\nEnd: {self.end}\nRelation: {self.relation}\nSPARQL Endpoint: {self.sparql_uri}"
        return text
# Bundles everything known about one knowledge-graph path
class PathInfo:
    """A path between two IRIs plus the optional class IRI of its end node."""

    def __init__(self, start_iri: str, end_iri: str, path: List['Relation'], end_class_iri: Optional[str] = None):
        # Deep copy, so later changes to the caller's list or its steps do
        # not leak into this object.
        self.path = copy.deepcopy(path)
        self.start_iri, self.end_iri = start_iri, end_iri
        self.end_class_iri = end_class_iri

    def __str__(self):
        """Multi-line, human-readable dump; steps are comma-separated."""
        path_str = ', '.join(map(str, self.path))
        return f"Start: {self.start_iri}\nEnd: {self.end_iri}\nPath: {path_str}\nEnd Class IRI: {self.end_class_iri}"

    def __repr__(self):
        """Compact debug form that reports the path length instead of its steps."""
        summary = f"<PathInfo(start_iri='{self.start_iri}', end_iri='{self.end_iri}', path_length={len(self.path)}, end_class_iri='{self.end_class_iri}')>"
        return summary
# SPARQL query service
class SPARQLQueryService:
    """Look up entities on SPARQL endpoints by their ``rdfs:label``."""

    # Language tags tried for the label, in addition to the untagged literal.
    _LABEL_LANGS = ("zh", "en", "la", "ja")

    @staticmethod
    def _escape_literal(value) -> str:
        """Escape *value* for use inside a double-quoted SPARQL string literal.

        Backslash, double quote, line feed and carriage return are the
        characters that may not appear raw inside ``"..."``; escaping them
        keeps caller-supplied text from terminating the literal and injecting
        query syntax.
        """
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
            .replace("\r", "\\r")
        )

    @staticmethod
    def _build_entity_query(entity_name) -> str:
        """Build the SELECT query that matches *entity_name* as a label."""
        label = SPARQLQueryService._escape_literal(entity_name)
        candidates = [f'"{label}"']
        candidates.extend(
            f'"{label}"@{lang}' for lang in SPARQLQueryService._LABEL_LANGS
        )
        values = " ".join(candidates)
        # The rdfs prefix is declared explicitly: only some endpoints
        # predefine it, and redeclaring it is harmless where they do.
        return (
            "PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n"
            "SELECT DISTINCT ?s ?type WHERE {\n"
            "  ?s a ?type .\n"
            "  ?s rdfs:label ?label .\n"
            f"  VALUES ?label {{ {values} }}\n"
            "}"
        )

    @staticmethod
    def get_entity_info(entity_name: str, endpoints: List[str]) -> List[Dict]:
        """Query every endpoint for subjects labelled *entity_name*.

        Args:
            entity_name: Label text to match, as a plain literal or tagged
                with one of ``_LABEL_LANGS``.
            endpoints: SPARQL endpoint URLs, queried in order.

        Returns:
            One ``{"entityIri": ..., "type": ...}`` dict per result binding,
            accumulated across all endpoints. An endpoint that fails is
            logged and skipped, so the list may be empty.
        """
        # The query does not depend on the endpoint, so build it once.
        query = SPARQLQueryService._build_entity_query(entity_name)
        results = []
        for endpoint in endpoints:
            sparql = SPARQLWrapper(endpoint)
            sparql.setQuery(query)
            sparql.setReturnFormat(JSON)
            try:
                response = sparql.query().convert()
                if isinstance(response, bytes):
                    # A raw bytes payload cannot be indexed as a result set;
                    # the lookup below then raises and is reported by the
                    # handler at the bottom.
                    logger.info(f"SPARQL 查询结果({endpoint}): {response.decode('utf-8')}")
                else:
                    # ensure_ascii=False keeps non-ASCII labels readable in logs.
                    logger.info(f"SPARQL 查询结果({endpoint}): {json.dumps(response, indent=2, ensure_ascii=False)}")
                for result in response["results"]["bindings"]:
                    results.append({
                        "entityIri": result["s"]["value"],
                        "type": result["type"]["value"]
                    })
            except Exception as e:
                # Best effort: one broken endpoint must not fail the whole lookup.
                logger.error(f"SPARQL 查询失败({endpoint}): {e}")
        return results
class Neo4jService:
    """Thin wrapper around a Neo4j driver for path lookups between ``Point`` nodes."""

    def __init__(self, uri: str, username: str, password: str):
        """Open a driver for *uri* using basic authentication."""
        self._driver = GraphDatabase.driver(uri, auth=(username, password))

    def close(self):
        """Close the underlying driver."""
        self._driver.close()

    def find_by_start_and_end(self, start_node_name: str, end_node_name: str, depth: int, sparql_uri_list: List[str]) -> List[List[Dict]]:
        """Find paths between two named ``Point`` nodes.

        Args:
            start_node_name: ``name`` property of the start node.
            end_node_name: ``name`` property of the end node.
            depth: Maximum number of relationships in a path; must convert to
                an integer >= 1.
            sparql_uri_list: Allowed endpoints; every relationship on a path
                must have its ``sparqlURI`` in this list.

        Returns:
            One list per path, each holding one dict per relationship with
            the keys ``start``, ``end``, ``relation`` and ``sparql_uri``.
            ``start``/``end`` follow the stored direction of the relationship,
            not the direction in which the path was traversed.

        Raises:
            ValueError: If *sparql_uri_list* is empty or *depth* is not a
                positive integer.
        """
        if not sparql_uri_list:
            raise ValueError(f"端点数组不能为空: {json.dumps(sparql_uri_list)}")

        # Variable-length bounds cannot be passed as a Cypher parameter, so the
        # depth is interpolated — but only after it is forced to a plain int.
        try:
            max_depth = int(depth)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"depth must be a positive integer, got {depth!r}") from exc
        if max_depth < 1:
            raise ValueError(f"depth must be a positive integer, got {depth!r}")

        # Names and endpoint URIs go in as query parameters, never spliced
        # into the query text, so quotes in the input cannot alter the query.
        query = (
            f"MATCH p=(s:Point)-[r*1..{max_depth}]-(e:Point)\n"
            "WHERE s.name = $start_name AND e.name = $end_name\n"
            "AND ALL(node IN nodes(p)[1..-1] WHERE node <> s AND node <> e)\n"
            "AND ALL(rel IN relationships(p) WHERE rel.sparqlURI IN $sparql_uris)\n"
            "RETURN nodes(p) AS nodes, relationships(p) AS relations"
        )
        params = {
            "start_name": start_node_name,
            "end_name": end_node_name,
            "sparql_uris": list(sparql_uri_list),
        }
        logger.debug("Neo4j path query: %s | params: %s", query, params)

        result = []
        with self._driver.session() as session:
            for record in session.run(query, params):
                node_list = record["nodes"]
                path = []
                for i, rel in enumerate(record["relations"]):
                    start_node = node_list[i]
                    end_node = node_list[i + 1]
                    # The pattern is undirected, so a relationship may be
                    # traversed against its stored direction; report it the
                    # way it is stored.
                    if rel.start_node.element_id == start_node.element_id:
                        s, o = start_node["name"], end_node["name"]
                    else:
                        s, o = end_node["name"], start_node["name"]
                    path.append({
                        "start": s,
                        "end": o,
                        "relation": rel["name"],
                        "sparql_uri": rel["sparqlURI"]
                    })
                result.append(path)
        return result
# SPARQL generator
class SparqlGenerator:
    """Turn path steps into a federated CONSTRUCT query plus graph data."""

    @staticmethod
    def generate_from_paths(real_paths: List[List[Dict]]) -> Dict:
        """Build a CONSTRUCT query and node/link data from path steps.

        Args:
            real_paths: Paths, each a list of step dicts with the keys
                ``start``, ``end``, ``relation`` and optionally
                ``sparql_uri`` (defaults to ``http://default/sparql``).

        Returns:
            ``{"sparql": str, "graph_data": {"nodes": [...], "links": [...]}}``,
            or ``{"sparql": None, "graph_data": None}`` when *real_paths* is
            empty. Each distinct start/end value becomes one node and one
            query variable (``?s1``, ``?s2``, ... in first-seen order). Every
            step produces a link; identical triples appear only once in the
            query.
        """
        if not real_paths:
            return {"sparql": None, "graph_data": None}

        node_ids = {}
        graph_data = {"nodes": [], "links": []}
        # Dicts are used as ordered sets: they drop duplicates but keep
        # first-seen order, so the generated query text is deterministic.
        construct = {}
        where = {}

        for path in real_paths:
            for step in path:
                s = step["start"]
                o = step["end"]
                p = step["relation"]
                endpoint = step.get("sparql_uri", "http://default/sparql")

                # Register a node the first time its value is seen.
                for uri in (s, o):
                    if uri not in node_ids:
                        node_id = len(node_ids)
                        node_ids[uri] = node_id
                        graph_data["nodes"].append({
                            "id": node_id,
                            "name": uri.split("/")[-1],
                            "uri": uri,
                            "isClass": False
                        })

                # A node's variable is derived from its id: ?s1 for node 0, etc.
                triple = f"?s{node_ids[s] + 1} <{p}> ?s{node_ids[o] + 1} ."
                construct[triple] = None
                where[f"SERVICE SILENT <{endpoint}> {{\n{triple}\n}}"] = None

                graph_data["links"].append({
                    "source": node_ids[s],
                    "target": node_ids[o],
                    "relation": p,
                    "sparqlEndpoint": endpoint
                })

        construct_str = "\n".join(construct)
        where_str = "\n".join(where)
        # Assembled in one step rather than by chained placeholder
        # replacement, so placeholder-like text inside the data (e.g. a
        # relation containing "$where") is left untouched.
        sparql = f"CONSTRUCT {{\n{construct_str}\n}}\nWHERE {{\n{where_str}\n}} LIMIT 20"
        return {"sparql": sparql, "graph_data": graph_data}
def check_sparql_endpoint(uri):
    """Check whether a SPARQL endpoint is usable, with caching.

    A cached result younger than ``CACHE_TTL`` seconds is returned as is.
    Otherwise a one-row SELECT is sent with a 5-second timeout, and the
    endpoint counts as healthy only if the query succeeds and returns at
    least one binding. The outcome, healthy or not, is stored in
    ``endpoint_health_cache``.

    Args:
        uri: URL of the SPARQL endpoint.

    Returns:
        True if the endpoint is considered healthy, False otherwise.
    """
    current_time = time.time()

    # Serve from the cache while the entry is still fresh.
    cached = endpoint_health_cache.get(uri)
    if cached is not None and current_time - cached['timestamp'] < CACHE_TTL:
        return cached['status']

    # Lightweight probe query.
    sparql = SPARQLWrapper(uri)
    sparql.setQuery("SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 1")
    sparql.setReturnFormat(JSON)
    sparql.setTimeout(5)

    try:
        results = sparql.query().convert()
        flag = bool(results["results"]["bindings"])
    except Exception as e:
        flag = False
        # Log when the endpoint was healthy before OR has never been checked.
        # Repeated failures of a known-bad endpoint stay quiet to avoid log
        # spam, but a first-time failure is no longer swallowed silently.
        if cached is None or cached['status']:
            logger.error(f"端点 {uri} 不可用: {e}")
    else:
        if flag:
            logger.info(f"Endpoint {uri} is accessible.")
        else:
            logger.warning(f"Endpoint {uri} returned no results.")

    # Record the outcome either way so failures are cached too.
    endpoint_health_cache[uri] = {
        'status': flag,
        'timestamp': current_time
    }
    return flag
# Root route
@app.route('/')
def index():
    """Render and return the ``index.html`` template."""
    page = render_template('index.html')
    return page
@app.route('/health-check', methods=['POST'])
def health_check():
    """Report the health of each SPARQL endpoint listed in the request.

    Expects a JSON object ``{"endpoints": [<url>, ...]}``.

    Returns:
        ``{"status": {<url>: bool, ...}}`` on success, or HTTP 400 with
        ``{"error": ...}`` when the body is not a JSON object or does not
        contain a non-empty list of endpoint strings.
    """
    # silent=True yields None for a missing or malformed JSON body instead of
    # raising, so every bad request ends up in the same 400 response below.
    data = request.get_json(silent=True)
    endpoints = data.get('endpoints') if isinstance(data, dict) else None
    if (
        not endpoints
        or not isinstance(endpoints, list)
        or not all(isinstance(uri, str) for uri in endpoints)
    ):
        return jsonify({"error": "缺少有效的端点列表"}), 400

    # Look up each endpoint's health (served from the cache when fresh).
    status = {uri: check_sparql_endpoint(uri) for uri in endpoints}
    return jsonify({"status": status})
# Path generation and query endpoint
@app.route('/generate', methods=['POST'])
def generate():
    """Resolve two entity names, find paths between their types, build SPARQL.

    Expects a JSON object with ``startName``, ``endName``, ``sparqlEndpoints``
    (non-empty list) and an optional ``depth`` (default 3).

    Returns:
        JSON with ``success``. On success it also carries ``sparql`` and
        ``graph_data``, plus ``startIri`` and ``endIri`` when at least one
        path was found. On failure it carries ``error``. Failures are
        reported in the body with the default status code, as before.
    """
    try:
        # Step 1: extract and validate parameters.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"success": False, "error": "请求体必须是 JSON 对象"})
        start_name = data.get('startName')
        end_name = data.get('endName')
        if not start_name or not end_name:
            return jsonify({"success": False, "error": "缺少参数 startName 或 endName"})
        depth = data.get('depth', 3)
        sparql_endpoints = data.get('sparqlEndpoints', [])
        if not sparql_endpoints:
            return jsonify({"success": False, "error": "必须指定至少一个 SPARQL 端点"})

        # Step 2: look up both entities on the SPARQL endpoints.
        start_entity_info = SPARQLQueryService.get_entity_info(start_name, sparql_endpoints)
        end_entity_info = SPARQLQueryService.get_entity_info(end_name, sparql_endpoints)
        if not start_entity_info:
            return jsonify({"success": False, "error": f"无法获取实体{start_name}信息"})
        if not end_entity_info:
            return jsonify({"success": False, "error": f"无法获取实体{end_name}信息"})
        # Only the first match of each lookup is used.
        start_iri = start_entity_info[0]["entityIri"]
        end_iri = end_entity_info[0]["entityIri"]
        start_type = start_entity_info[0]["type"]
        end_type = end_entity_info[0]["type"]

        # Step 3: query Neo4j for paths between the two types. The driver is
        # closed in ``finally`` so a failing query cannot leak it.
        neo4j_service = Neo4jService(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
        try:
            real_paths = neo4j_service.find_by_start_and_end(start_type, end_type, depth, sparql_endpoints)
        finally:
            neo4j_service.close()
        if not real_paths:
            return jsonify({"success": True, "graph_data": {"nodes": [], "links": []}, "sparql": ""})

        # Step 4: build the SPARQL query and graph data from the paths.
        result = SparqlGenerator.generate_from_paths(real_paths)

        # Step 5: attach the resolved IRIs to the response.
        result["startIri"] = start_iri
        result["endIri"] = end_iri
        return jsonify({"success": True, **result})
    except Exception as e:
        # Top-level boundary: keep the JSON error contract, but record the
        # traceback instead of discarding it.
        logger.exception("生成路径失败")
        return jsonify({"success": False, "error": str(e)})
# Script entry point: start Flask's built-in server when run directly.
# NOTE(review): debug=True enables the interactive debugger and the reloader;
# it should not be left on for any deployment reachable by untrusted clients.
if __name__ == '__main__':
    app.run(debug=True)
# TODO: 目前代码返回的图仅由 URI 构成。最终想要的效果是:根据 URI 查询到各节点的具体信息,图中显示节点的名称(例如“地不容”),鼠标悬停在节点上时显示该节点的全部信息。