2026-03-31 17:18:30 +08:00
|
|
|
|
"""
|
|
|
|
|
|
Text-to-Cypher + natural-language answer generation.

Uses the Claude API: claude-haiku-4-5 generates the Cypher query and
claude-sonnet-4-5 writes the final natural-language answer.
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
import re
|
|
|
|
|
|
import anthropic
|
|
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
from graph_query import run_query
|
|
|
|
|
|
from prompts import CYPHER_SYSTEM_PROMPT, ANSWER_SYSTEM_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
# Populate os.environ from a .env file (if one is found) before any
# os.getenv() lookup in this module, e.g. ANTHROPIC_BASE_URL in _get_client().
load_dotenv()

# Module-level cache for the Anthropic client; stays None until
# _get_client() builds it on first use.
_client: "anthropic.Anthropic | None" = None
|
|
|
|
|
|
|
|
|
|
|
|
# 禁止写操作的关键字
|
|
|
|
|
|
_WRITE_PATTERN = re.compile(
|
|
|
|
|
|
r'\b(CREATE|DELETE|SET|MERGE|REMOVE|DROP|DETACH|CALL)\b',
|
|
|
|
|
|
re.IGNORECASE,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_client() -> anthropic.Anthropic:
    """Return the shared Anthropic client, building it on the first call.

    The instance is cached in the module global ``_client`` and reused by every
    later call. ``base_url`` is taken from the ``ANTHROPIC_BASE_URL`` environment
    variable (``None`` when the variable is unset).
    """
    global _client
    if _client is not None:
        return _client
    endpoint = os.getenv("ANTHROPIC_BASE_URL")
    _client = anthropic.Anthropic(base_url=endpoint)
    return _client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_cypher(question: str, error_hint: str = "") -> str:
    """Ask the fast model to translate *question* into a Cypher query.

    Args:
        question: The user's natural-language question, sent as the user message.
        error_hint: Error text from a previous failed execution. When non-empty
            it is appended to the prompt so the model can correct its query.

    Returns:
        The model's reply with any surrounding markdown code fence removed and
        whitespace stripped. This is either a Cypher statement or whatever
        sentinel CYPHER_SYSTEM_PROMPT asks for (the caller checks for
        ``"UNSUPPORTED"``). The query is NOT validated here.
    """
    user_msg = question
    if error_hint:
        user_msg += f"\n\n上次生成的 Cypher 执行出错:{error_hint}\n请修正后重新生成。"

    resp = _get_client().messages.create(
        model="claude-haiku-4-5",
        max_tokens=512,
        system=CYPHER_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": user_msg}],
    )
    return _strip_code_fence(resp.content[0].text)


def _strip_code_fence(text: str) -> str:
    """Return the query inside a markdown code fence, or *text* with stray fences removed."""
    text = text.strip()
    # Prefer the body of the first fenced block, whatever its language tag
    # (```cypher, ```Cypher, ```sql, none ...), even when the model wrapped it
    # in explanatory prose. Previously such replies failed the MATCH check.
    fenced = re.search(r"```[^\n`]*\n(.*?)```", text, re.DOTALL)
    if fenced:
        return fenced.group(1).strip()
    # Fallback for one-line or unterminated fences (e.g. reply cut off at max_tokens).
    text = re.sub(r"^```(?:cypher)?\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\s*```$", "", text)
    return text.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _format_results(results: list[dict]) -> str:
|
|
|
|
|
|
if not results:
|
|
|
|
|
|
return "(无查询结果)"
|
|
|
|
|
|
lines = []
|
|
|
|
|
|
for i, row in enumerate(results[:30], 1):
|
|
|
|
|
|
parts = [f"{k}: {v}" for k, v in row.items()]
|
|
|
|
|
|
lines.append(f"{i}. {', '.join(parts)}")
|
|
|
|
|
|
return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _rejection(cypher: str) -> "dict | None":
    """Return the response dict for an unusable generated query, or None if it may run.

    A query is rejected when it is the model's ``UNSUPPORTED`` sentinel, when it
    does not start with MATCH, or when it matches ``_WRITE_PATTERN``.
    """
    if cypher.upper() == "UNSUPPORTED":
        return {
            "answer": "抱歉,这个问题超出了知识图谱的范围。"
            "你可以询问人物行踪、势力控制区域、地点事件等相关问题。",
            "cypher": None,
            "results": [],
        }
    # Only read queries are allowed: the statement must start with MATCH ...
    if not re.match(r"^\s*MATCH\b", cypher, re.IGNORECASE):
        return {
            "answer": "生成的查询语句格式有误,请换一种方式提问。",
            "cypher": cypher,
            "results": [],
        }
    # ... and must not contain any write / procedure keyword.
    if _WRITE_PATTERN.search(cypher):
        return {
            "answer": "生成的查询包含不允许的写操作,已拒绝执行。",
            "cypher": cypher,
            "results": [],
        }
    return None


def answer_question(question: str) -> dict:
    """Answer a natural-language question from the knowledge graph.

    Pipeline:
      1. generate Cypher with the fast model and validate it (see ``_rejection``);
      2. run it; if execution raises, generate ONE new query with the error
         text as a hint, validate it the same way, and run it;
      3. have the larger model phrase an answer from the formatted rows.

    Returns:
        {
            "answer": str,           # natural-language answer, or a refusal/error message
            "cypher": str | None,    # the Cypher that was (or would have been) run; None if unsupported
            "results": list[dict],   # raw query rows (at most 20)
        }
    """
    # ── Step 1: generate and validate Cypher ─────────────
    cypher = _generate_cypher(question)
    rejected = _rejection(cypher)
    if rejected is not None:
        return rejected

    # ── Step 2: run the query (one retry on failure) ─────
    try:
        results = run_query(cypher)
    except Exception as e:
        cypher = _generate_cypher(question, error_hint=str(e))
        # The regenerated query must pass the same checks. Without this, the
        # retry path would execute a query that never went through the
        # read-only guard above.
        rejected = _rejection(cypher)
        if rejected is not None:
            return rejected
        try:
            results = run_query(cypher)
        except Exception as e2:
            return {
                "answer": f"查询执行失败,请尝试换一种方式提问。(错误:{e2})",
                "cypher": cypher,
                "results": [],
            }

    # ── Step 3: generate the natural-language answer ─────
    results_text = _format_results(results)
    resp = _get_client().messages.create(
        model="claude-sonnet-4-5",
        max_tokens=1024,
        system=ANSWER_SYSTEM_PROMPT,
        messages=[{
            "role": "user",
            "content": f"用户问题:{question}\n\n知识图谱查询结果:\n{results_text}",
        }],
    )
    answer = resp.content[0].text.strip()

    return {
        "answer": answer,
        "cypher": cypher,
        "results": results[:20],
    }
|