130 lines
4.0 KiB
Python
130 lines
4.0 KiB
Python
"""
|
||
Text-to-Cypher + natural-language answer generation.

Uses the Claude API (claude-haiku-4-5 generates the Cypher query,
claude-sonnet-4-5 generates the final answer).
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import anthropic
|
||
from dotenv import load_dotenv
|
||
from graph_query import run_query
|
||
from prompts import CYPHER_SYSTEM_PROMPT, ANSWER_SYSTEM_PROMPT
|
||
|
||
# Load variables from a local .env file into the process environment so that
# later os.getenv() lookups (e.g. ANTHROPIC_BASE_URL in _get_client) see them.
load_dotenv()

# Shared Anthropic client; stays None until _get_client() first builds it.
_client: "anthropic.Anthropic | None" = None
|
||
|
||
# 禁止写操作的关键字
|
||
_WRITE_PATTERN = re.compile(
|
||
r'\b(CREATE|DELETE|SET|MERGE|REMOVE|DROP|DETACH|CALL)\b',
|
||
re.IGNORECASE,
|
||
)
|
||
|
||
|
||
def _get_client() -> anthropic.Anthropic:
    """Return the module-wide Anthropic client, building it on the first call.

    The base URL is taken from the ``ANTHROPIC_BASE_URL`` environment variable;
    when it is unset, ``None`` is passed through to the SDK constructor.
    NOTE: creation is not guarded by a lock.
    """
    global _client
    if _client is not None:
        return _client
    base_url = os.getenv("ANTHROPIC_BASE_URL")
    _client = anthropic.Anthropic(base_url=base_url)
    return _client
|
||
|
||
|
||
def _generate_cypher(question: str, error_hint: str = "") -> str:
    """Ask the Cypher model to translate *question* into a Cypher query.

    Args:
        question: The user's natural-language question.
        error_hint: Feedback from a previous failed execution. When non-empty,
            it is appended to the prompt so the model can correct its query.

    Returns:
        The model's reply with any markdown code fence removed. This is
        presumably a Cypher query or the ``"UNSUPPORTED"`` sentinel that
        answer_question checks for (depends on CYPHER_SYSTEM_PROMPT).
        Returns ``""`` if the reply contained no text.
    """
    user_msg = question
    if error_hint:
        user_msg += f"\n\n上次生成的 Cypher 执行出错:{error_hint}\n请修正后重新生成。"

    resp = _get_client().messages.create(
        model="claude-haiku-4-5",
        max_tokens=512,
        system=CYPHER_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": user_msg}],
    )
    # Join all text blocks rather than indexing content[0]: the reply may be
    # empty or begin with a non-text block, which would otherwise raise.
    raw = "".join(getattr(block, "text", "") for block in resp.content).strip()

    # Prefer the body of the first fenced code block anywhere in the reply.
    # This tolerates any language tag (```cypher, ```sql, ...) and prose
    # around the fence.
    fenced = re.search(r"```[\w-]*[ \t]*\r?\n(.*?)```", raw, flags=re.DOTALL)
    if fenced:
        return fenced.group(1).strip()

    # Fallback: strip a bare leading/trailing fence (e.g. an unclosed fence
    # from a truncated reply, or a fence with no newline after it).
    raw = re.sub(r"^```(?:cypher)?\s*", "", raw, flags=re.IGNORECASE)
    raw = re.sub(r"\s*```$", "", raw)
    return raw.strip()
|
||
|
||
|
||
def _format_results(results: list[dict], limit: int = 30) -> str:
    """Render query rows as numbered plain-text lines for the answer prompt.

    Args:
        results: Rows returned by the graph query, each a mapping of column
            name to value.
        limit: Maximum number of rows to render (default 30).

    Returns:
        ``"(无查询结果)"`` when there are no rows. Otherwise, one line per
        rendered row in the form ``"1. key: value, key: value"``. If rows were
        dropped because of *limit*, a final line states the total count, so
        the answer model does not mistake the partial list for the full result.
    """
    if not results:
        return "(无查询结果)"

    shown = results[:limit]
    lines = []
    for idx, row in enumerate(shown, 1):
        fields = ", ".join(f"{key}: {value}" for key, value in row.items())
        lines.append(f"{idx}. {fields}")

    if len(results) > len(shown):
        lines.append(f"(共 {len(results)} 条结果,仅显示前 {len(shown)} 条)")
    return "\n".join(lines)
|
||
|
||
|
||
def _validate_cypher(cypher: str) -> "dict | None":
    """Return a refusal result if *cypher* must not be executed, else None.

    Checks, in order: the ``UNSUPPORTED`` sentinel, the must-start-with-MATCH
    rule, and the ``_WRITE_PATTERN`` keyword guard.
    """
    if cypher.upper() == "UNSUPPORTED":
        return {
            "answer": "抱歉,这个问题超出了知识图谱的范围。"
            "你可以询问人物行踪、势力控制区域、地点事件等相关问题。",
            "cypher": None,
            "results": [],
        }

    # The query must start with MATCH. Write clauses appearing later in the
    # query are caught by the keyword guard below.
    if not re.match(r"^\s*MATCH\b", cypher, re.IGNORECASE):
        return {
            "answer": "生成的查询语句格式有误,请换一种方式提问。",
            "cypher": cypher,
            "results": [],
        }

    # Safety check: refuse any write / procedure / file-loading keyword.
    if _WRITE_PATTERN.search(cypher):
        return {
            "answer": "生成的查询包含不允许的写操作,已拒绝执行。",
            "cypher": cypher,
            "results": [],
        }

    return None


def answer_question(question: str) -> dict:
    """Answer a natural-language question from the knowledge graph.

    Pipeline:
      1. An LLM generates Cypher, which is validated.
      2. The query is run. If execution raises, the Cypher is regenerated
         once with error feedback and validated again.
      3. A second model phrases the answer from the query results.

    Args:
        question: The user's question.

    Returns:
        {
            "answer": str,          # natural-language answer or refusal/error message
            "cypher": str | None,   # the executed (or rejected) Cypher; None if out of scope
            "results": list[dict],  # raw query rows (at most 20)
        }
    """
    # ── Step 1: generate and validate Cypher ─────────────
    cypher = _generate_cypher(question)
    rejection = _validate_cypher(cypher)
    if rejection is not None:
        return rejection

    # ── Step 2: execute the query (one retry) ────────────
    try:
        results = run_query(cypher)
    except Exception as e:
        # The retry prompt is stateless, so pass the failing query along with
        # the error; otherwise the model cannot see what it needs to fix.
        cypher = _generate_cypher(
            question, error_hint=f"{e}\n出错的 Cypher:\n{cypher}"
        )
        # The regenerated query is just as untrusted: it must pass the same
        # checks before it is executed.
        rejection = _validate_cypher(cypher)
        if rejection is not None:
            return rejection
        try:
            results = run_query(cypher)
        except Exception as e2:
            return {
                "answer": f"查询执行失败,请尝试换一种方式提问。(错误:{e2})",
                "cypher": cypher,
                "results": [],
            }

    # ── Step 3: generate the natural-language answer ─────
    results_text = _format_results(results)

    resp = _get_client().messages.create(
        model="claude-sonnet-4-5",
        max_tokens=1024,
        system=ANSWER_SYSTEM_PROMPT,
        messages=[{
            "role": "user",
            "content": f"用户问题:{question}\n\n知识图谱查询结果:\n{results_text}",
        }],
    )
    # Join all text blocks rather than indexing content[0], which raises when
    # the reply is empty or starts with a non-text block.
    answer = "".join(getattr(block, "text", "") for block in resp.content).strip()

    return {
        "answer": answer,
        "cypher": cypher,
        "results": results[:20],
    }
|