""" Text-to-Cypher + 自然语言回答生成。 使用 Claude API(claude-3-5-haiku 生成 Cypher,claude-3-5-sonnet 生成回答)。 """ import os import re import anthropic from dotenv import load_dotenv from graph_query import run_query from prompts import CYPHER_SYSTEM_PROMPT, ANSWER_SYSTEM_PROMPT load_dotenv() _client = None # 禁止写操作的关键字 _WRITE_PATTERN = re.compile( r'\b(CREATE|DELETE|SET|MERGE|REMOVE|DROP|DETACH|CALL)\b', re.IGNORECASE, ) def _get_client() -> anthropic.Anthropic: global _client if _client is None: _client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) return _client def _generate_cypher(question: str, error_hint: str = "") -> str: user_msg = question if error_hint: user_msg += f"\n\n上次生成的 Cypher 执行出错:{error_hint}\n请修正后重新生成。" resp = _get_client().messages.create( model="claude-haiku-4-5", max_tokens=512, system=CYPHER_SYSTEM_PROMPT, messages=[{"role": "user", "content": user_msg}], ) raw = resp.content[0].text.strip() # 去掉可能的 markdown 代码块 raw = re.sub(r"^```(?:cypher)?\s*", "", raw, flags=re.IGNORECASE) raw = re.sub(r"\s*```$", "", raw) return raw.strip() def _format_results(results: list[dict]) -> str: if not results: return "(无查询结果)" lines = [] for i, row in enumerate(results[:30], 1): parts = [f"{k}: {v}" for k, v in row.items()] lines.append(f"{i}. {', '.join(parts)}") return "\n".join(lines) def answer_question(question: str) -> dict: """ Returns: { "answer": str, # 自然语言回答 "cypher": str | None, # 执行的 Cypher "results": list[dict], # 原始查询结果(最多20条) } """ # ── Step 1: 生成 Cypher ────────────────────────────── cypher = _generate_cypher(question) if cypher.upper() == "UNSUPPORTED": return { "answer": "抱歉,这个问题超出了知识图谱的范围。" "你可以询问人物行踪、势力控制区域、地点事件等相关问题。", "cypher": None, "results": [], } # 必须以 MATCH 开头 if not re.match(r"^\s*MATCH\b", cypher, re.IGNORECASE): return { "answer": "生成的查询语句格式有误,请换一种方式提问。", "cypher": cypher, "results": [], } # 安全检查:禁止写操作 if _WRITE_PATTERN.search(cypher): return { "answer": "生成的查询包含不允许的写操作,已拒绝执行。", "cypher": cypher, "results": [], } # ── Step 2: 执行查询 ───────────────────────────────── try: results = run_query(cypher) except Exception as e: # 出错后重试一次,附上错误提示 cypher = _generate_cypher(question, error_hint=str(e)) try: results = run_query(cypher) except Exception as e2: return { "answer": f"查询执行失败,请尝试换一种方式提问。(错误:{e2})", "cypher": cypher, "results": [], } # ── Step 3: 生成自然语言回答 ───────────────────────── results_text = _format_results(results) resp = _get_client().messages.create( model="claude-sonnet-4-5", max_tokens=1024, system=ANSWER_SYSTEM_PROMPT, messages=[{ "role": "user", "content": f"用户问题:{question}\n\n知识图谱查询结果:\n{results_text}", }], ) answer = resp.content[0].text.strip() return { "answer": answer, "cypher": cypher, "results": results[:20], }