Files
Novel-Map/backend/llm_router.py
2026-03-31 19:07:20 +08:00

130 lines
4.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Text-to-Cypher + 自然语言回答生成。
使用 Claude API(claude-haiku-4-5 生成 Cypher,claude-sonnet-4-5 生成回答)。
"""
import os
import re
import anthropic
from dotenv import load_dotenv
from graph_query import run_query
from prompts import CYPHER_SYSTEM_PROMPT, ANSWER_SYSTEM_PROMPT
# Load variables from a .env file into the process environment before any
# client is constructed.
load_dotenv()

# Anthropic client shared by this module; created on first use by _get_client().
_client: "anthropic.Anthropic | None" = None

# Cypher keywords that are not allowed in generated queries (write operations
# and procedure calls). Matched as whole words, case-insensitively.
_WRITE_PATTERN = re.compile(
    r'\b(CREATE|DELETE|SET|MERGE|REMOVE|DROP|DETACH|CALL)\b',
    flags=re.IGNORECASE,
)
def _get_client() -> anthropic.Anthropic:
    """Return the module-wide Anthropic client, creating it on first use.

    The client is constructed with ``base_url`` read from the
    ``ANTHROPIC_BASE_URL`` environment variable (``None`` when unset) and is
    cached in the module-level ``_client`` so later calls reuse it.
    """
    global _client
    if _client is not None:
        return _client
    endpoint = os.getenv("ANTHROPIC_BASE_URL")
    _client = anthropic.Anthropic(base_url=endpoint)
    return _client
def _extract_cypher(raw: str) -> str:
    """Return the bare query from a model reply, dropping Markdown fencing.

    If the reply contains a complete fenced block (optionally tagged
    ``cypher``) anywhere in the text, the content of the first such block is
    returned, so explanatory prose before or after the fence is discarded.
    Otherwise a stray leading and/or trailing fence is trimmed.
    """
    text = raw.strip()
    fenced = re.search(
        r"```(?:cypher)?\s*(.*?)\s*```",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if fenced:
        return fenced.group(1).strip()
    # No complete fenced block (e.g. the closing fence is missing): trim
    # whatever fence markers sit at the very start or end of the reply.
    text = re.sub(r"^```(?:cypher)?\s*", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\s*```$", "", text)
    return text.strip()


def _generate_cypher(question: str, error_hint: str = "") -> str:
    """Ask the model for a Cypher query answering *question*.

    Args:
        question: The user's natural-language question.
        error_hint: Error text from a previous failed execution. When
            non-empty it is appended to the user message together with a
            request to correct the query.

    Returns:
        The model's reply with surrounding whitespace and Markdown code
        fencing removed.
    """
    user_msg = question
    if error_hint:
        user_msg += f"\n\n上次生成的 Cypher 执行出错:{error_hint}\n请修正后重新生成。"
    resp = _get_client().messages.create(
        model="claude-haiku-4-5",
        max_tokens=512,
        system=CYPHER_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": user_msg}],
    )
    return _extract_cypher(resp.content[0].text)
def _format_results(results: list[dict]) -> str:
if not results:
return "(无查询结果)"
lines = []
for i, row in enumerate(results[:30], 1):
parts = [f"{k}: {v}" for k, v in row.items()]
lines.append(f"{i}. {', '.join(parts)}")
return "\n".join(lines)
def _reject_unsafe_cypher(cypher: str) -> "dict | None":
    """Return a refusal response when *cypher* must not be run, else ``None``.

    Three checks are applied in order: the ``UNSUPPORTED`` marker, the
    requirement that the query starts with ``MATCH``, and the module-level
    ``_WRITE_PATTERN`` keyword blacklist.
    """
    if cypher.upper() == "UNSUPPORTED":
        return {
            "answer": "抱歉,这个问题超出了知识图谱的范围。"
                      "你可以询问人物行踪、势力控制区域、地点事件等相关问题。",
            "cypher": None,
            "results": [],
        }
    # The query must start with MATCH.
    if not re.match(r"^\s*MATCH\b", cypher, re.IGNORECASE):
        return {
            "answer": "生成的查询语句格式有误,请换一种方式提问。",
            "cypher": cypher,
            "results": [],
        }
    # Safety check: refuse anything containing a blacklisted keyword.
    if _WRITE_PATTERN.search(cypher):
        return {
            "answer": "生成的查询包含不允许的写操作,已拒绝执行。",
            "cypher": cypher,
            "results": [],
        }
    return None


def answer_question(question: str) -> dict:
    """Answer *question* by generating Cypher, running it and summarising.

    The query is generated by ``_generate_cypher``, validated, executed with
    ``run_query`` and, on an execution error, regenerated once with the error
    text as a hint. The regenerated query goes through the same validation as
    the first one before it is executed.

    Returns:
        {
            "answer": str,          # natural-language answer
            "cypher": str | None,   # the Cypher that was generated/executed
            "results": list[dict],  # raw query rows, at most 20
        }
    """
    # -- Step 1: generate Cypher --------------------------------------
    cypher = _generate_cypher(question)
    rejection = _reject_unsafe_cypher(cypher)
    if rejection is not None:
        return rejection

    # -- Step 2: execute the query ------------------------------------
    try:
        results = run_query(cypher)
    except Exception as first_error:
        # Any execution failure triggers exactly one retry, passing the error
        # text back to the generator as a hint.
        cypher = _generate_cypher(question, error_hint=str(first_error))
        # The regenerated query is just as untrusted as the first one, so it
        # must pass the same checks before being executed.
        rejection = _reject_unsafe_cypher(cypher)
        if rejection is not None:
            return rejection
        try:
            results = run_query(cypher)
        except Exception as retry_error:
            return {
                "answer": f"查询执行失败,请尝试换一种方式提问。(错误:{retry_error})",
                "cypher": cypher,
                "results": [],
            }

    # -- Step 3: generate the natural-language answer -----------------
    results_text = _format_results(results)
    resp = _get_client().messages.create(
        model="claude-sonnet-4-5",
        max_tokens=1024,
        system=ANSWER_SYSTEM_PROMPT,
        messages=[{
            "role": "user",
            "content": f"用户问题:{question}\n\n知识图谱查询结果:\n{results_text}",
        }],
    )
    answer = resp.content[0].text.strip()
    return {
        "answer": answer,
        "cypher": cypher,
        "results": results[:20],
    }