添加GraphRAG后端
This commit is contained in:
129
backend/llm_router.py
Normal file
129
backend/llm_router.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Text-to-Cypher + 自然语言回答生成。
|
||||
使用 Claude API(claude-3-5-haiku 生成 Cypher,claude-3-5-sonnet 生成回答)。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import anthropic
|
||||
from dotenv import load_dotenv
|
||||
from graph_query import run_query
|
||||
from prompts import CYPHER_SYSTEM_PROMPT, ANSWER_SYSTEM_PROMPT
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_client = None
|
||||
|
||||
# 禁止写操作的关键字
|
||||
_WRITE_PATTERN = re.compile(
|
||||
r'\b(CREATE|DELETE|SET|MERGE|REMOVE|DROP|DETACH|CALL)\b',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _get_client() -> anthropic.Anthropic:
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
||||
return _client
|
||||
|
||||
|
||||
def _generate_cypher(question: str, error_hint: str = "") -> str:
|
||||
user_msg = question
|
||||
if error_hint:
|
||||
user_msg += f"\n\n上次生成的 Cypher 执行出错:{error_hint}\n请修正后重新生成。"
|
||||
|
||||
resp = _get_client().messages.create(
|
||||
model="claude-haiku-4-5",
|
||||
max_tokens=512,
|
||||
system=CYPHER_SYSTEM_PROMPT,
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
)
|
||||
raw = resp.content[0].text.strip()
|
||||
|
||||
# 去掉可能的 markdown 代码块
|
||||
raw = re.sub(r"^```(?:cypher)?\s*", "", raw, flags=re.IGNORECASE)
|
||||
raw = re.sub(r"\s*```$", "", raw)
|
||||
return raw.strip()
|
||||
|
||||
|
||||
def _format_results(results: list[dict]) -> str:
|
||||
if not results:
|
||||
return "(无查询结果)"
|
||||
lines = []
|
||||
for i, row in enumerate(results[:30], 1):
|
||||
parts = [f"{k}: {v}" for k, v in row.items()]
|
||||
lines.append(f"{i}. {', '.join(parts)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def answer_question(question: str) -> dict:
|
||||
"""
|
||||
Returns:
|
||||
{
|
||||
"answer": str, # 自然语言回答
|
||||
"cypher": str | None, # 执行的 Cypher
|
||||
"results": list[dict], # 原始查询结果(最多20条)
|
||||
}
|
||||
"""
|
||||
# ── Step 1: 生成 Cypher ──────────────────────────────
|
||||
cypher = _generate_cypher(question)
|
||||
|
||||
if cypher.upper() == "UNSUPPORTED":
|
||||
return {
|
||||
"answer": "抱歉,这个问题超出了知识图谱的范围。"
|
||||
"你可以询问人物行踪、势力控制区域、地点事件等相关问题。",
|
||||
"cypher": None,
|
||||
"results": [],
|
||||
}
|
||||
|
||||
# 必须以 MATCH 开头
|
||||
if not re.match(r"^\s*MATCH\b", cypher, re.IGNORECASE):
|
||||
return {
|
||||
"answer": "生成的查询语句格式有误,请换一种方式提问。",
|
||||
"cypher": cypher,
|
||||
"results": [],
|
||||
}
|
||||
|
||||
# 安全检查:禁止写操作
|
||||
if _WRITE_PATTERN.search(cypher):
|
||||
return {
|
||||
"answer": "生成的查询包含不允许的写操作,已拒绝执行。",
|
||||
"cypher": cypher,
|
||||
"results": [],
|
||||
}
|
||||
|
||||
# ── Step 2: 执行查询 ─────────────────────────────────
|
||||
try:
|
||||
results = run_query(cypher)
|
||||
except Exception as e:
|
||||
# 出错后重试一次,附上错误提示
|
||||
cypher = _generate_cypher(question, error_hint=str(e))
|
||||
try:
|
||||
results = run_query(cypher)
|
||||
except Exception as e2:
|
||||
return {
|
||||
"answer": f"查询执行失败,请尝试换一种方式提问。(错误:{e2})",
|
||||
"cypher": cypher,
|
||||
"results": [],
|
||||
}
|
||||
|
||||
# ── Step 3: 生成自然语言回答 ─────────────────────────
|
||||
results_text = _format_results(results)
|
||||
|
||||
resp = _get_client().messages.create(
|
||||
model="claude-sonnet-4-5",
|
||||
max_tokens=1024,
|
||||
system=ANSWER_SYSTEM_PROMPT,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"用户问题:{question}\n\n知识图谱查询结果:\n{results_text}",
|
||||
}],
|
||||
)
|
||||
answer = resp.content[0].text.strip()
|
||||
|
||||
return {
|
||||
"answer": answer,
|
||||
"cypher": cypher,
|
||||
"results": results[:20],
|
||||
}
|
||||
Reference in New Issue
Block a user