From d2cf97387ba86c956c97115108e5eb6210960400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BE=99=E6=BE=B3?= Date: Tue, 31 Mar 2026 17:18:30 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0GraphRAG=E5=90=8E=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/.env.example | 4 + backend/app.py | 80 +++++++++++++++ backend/docker-compose.yml | 23 +++++ backend/graph_builder.py | 197 +++++++++++++++++++++++++++++++++++++ backend/graph_query.py | 48 +++++++++ backend/llm_router.py | 129 ++++++++++++++++++++++++ backend/prompts.py | 78 +++++++++++++++ backend/requirements.txt | 5 + backend/run_import.py | 37 +++++++ 9 files changed, 601 insertions(+) create mode 100644 backend/.env.example create mode 100644 backend/app.py create mode 100644 backend/docker-compose.yml create mode 100644 backend/graph_builder.py create mode 100644 backend/graph_query.py create mode 100644 backend/llm_router.py create mode 100644 backend/prompts.py create mode 100644 backend/requirements.txt create mode 100644 backend/run_import.py diff --git a/backend/.env.example b/backend/.env.example new file mode 100644 index 0000000..cf26b35 --- /dev/null +++ b/backend/.env.example @@ -0,0 +1,4 @@ +ANTHROPIC_API_KEY=sk-ant-... 
# 大唐双龙传 GraphRAG — FastAPI backend.
#
# Endpoints:
#   GET  /api/health — health check (includes Neo4j connectivity)
#   GET  /api/stats  — node / relationship counts of the graph
#   POST /api/import — trigger the data import (one-off operation)
#   POST /api/chat   — question answering (Text-to-Cypher + LLM answer)

app = FastAPI(title="大唐双龙传 GraphRAG API", version="1.0.0")

# Only the local Vite dev/preview origins are allowed to call the API
# from a browser.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:5173",  # Vite dev server
        "http://localhost:4173",  # Vite preview
        "http://127.0.0.1:5173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── Models ────────────────────────────────────────────────

class ChatRequest(BaseModel):
    """Request body of POST /api/chat."""

    question: str


class ImportRequest(BaseModel):
    """Request body of POST /api/import."""

    clear: bool = False  # True = wipe the graph before re-importing


# ── Endpoints ─────────────────────────────────────────────

@app.get("/api/health")
def health():
    """Report whether Neo4j is reachable; respond 503 when it is not."""
    try:
        # Driver creation is inside the try so that a bad URI / credentials
        # configuration is reported as 503 as well, not as an unhandled 500.
        get_driver().verify_connectivity()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Neo4j 连接失败: {e}") from e
    return {"status": "ok", "neo4j": "connected"}


@app.get("/api/stats")
def stats():
    """Return node and relationship counts; respond 503 on any failure."""
    try:
        return get_graph_stats()
    except Exception as e:
        raise HTTPException(status_code=503, detail=str(e)) from e


@app.post("/api/import")
def import_data(req: ImportRequest = ImportRequest()):
    """Import every volume into Neo4j (takes roughly 1-3 minutes; do not call repeatedly).

    With ``req.clear`` set, the existing graph is deleted first.
    Responds 500 with the error text if the import or the stats query fails.
    """
    try:
        build_graph(get_driver(), clear=req.clear)
        # Named graph_stats so the local does not shadow the `stats` endpoint.
        graph_stats = get_graph_stats()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "ok", "stats": graph_stats}


@app.post("/api/chat")
def chat(req: ChatRequest):
    """Answer a question about the novel; 400 on a blank question.

    Failures of the LLM API or of Neo4j that escape ``answer_question``
    are reported as 502, consistent with the explicit error mapping of the
    other endpoints.
    """
    if not req.question.strip():
        raise HTTPException(status_code=400, detail="问题不能为空")
    try:
        return answer_question(req.question)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"问答服务调用失败: {e}") from e
# JSON → Neo4j import.
#
# Graph schema:
#   nodes:         Character, Location, Faction, Event
#   relationships: VISITED, CONTROLS, HAS_MEMBER, LEADS, OCCURRED_AT

DATA_DIR = Path(__file__).parent.parent / "data"


# ── Helpers ───────────────────────────────────────────────

def _split_characters(name: str) -> list[str]:
    """Split a combined character label into individual names.

    '寇仲 & 徐子陵' → ['寇仲', '徐子陵'].  ``None`` (a JSON null) or an empty
    string yields an empty list instead of raising.
    """
    return [part.strip() for part in (name or "").split("&") if part.strip()]


def _split_leaders(leader: str) -> list[str]:
    """Split a '/'-separated leader field and drop the '未提及' placeholder.

    '翟让/李密' → ['翟让', '李密'].  ``None`` (a JSON null) or an empty string
    yields an empty list instead of raising.
    """
    names = (part.strip() for part in (leader or "").split("/"))
    return [n for n in names if n and n != "未提及"]


# ── Schema initialisation ─────────────────────────────────

def setup_schema(driver: "Driver"):
    """Create the uniqueness constraints and indexes (idempotent)."""
    statements = (
        "CREATE CONSTRAINT IF NOT EXISTS FOR (n:Character) REQUIRE n.name IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (n:Location) REQUIRE n.id IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (n:Faction) REQUIRE n.id IS UNIQUE",
        "CREATE CONSTRAINT IF NOT EXISTS FOR (n:Event) REQUIRE n.id IS UNIQUE",
        "CREATE INDEX IF NOT EXISTS FOR (e:Event) ON (e.vol)",
        "CREATE INDEX IF NOT EXISTS FOR ()-[r:VISITED]-() ON (r.vol)",
        "CREATE INDEX IF NOT EXISTS FOR ()-[r:CONTROLS]-() ON (r.vol)",
    )
    with driver.session() as s:
        for statement in statements:
            s.run(statement)


# ── Per-type importers ────────────────────────────────────

def _import_locations(session, locations: list[dict]):
    """Upsert one Location node per entry, keyed by its ``id``."""
    for loc in locations:
        session.run(
            """
            MERGE (l:Location {id: $id})
            SET l.name = $name,
                l.type = $type,
                l.lat = $lat,
                l.lng = $lng
            """,
            id=loc["id"],
            name=loc["name"],
            type=loc.get("type", ""),
            lat=loc.get("lat"),
            lng=loc.get("lng"),
        )


def _import_factions(session, factions: list[dict], vol: int):
    """Upsert Faction nodes plus their per-volume relationships.

    For every faction this writes CONTROLS (territory), HAS_MEMBER
    (key figures) and LEADS (leaders) relationships tagged with ``vol``.
    """
    for f in factions:
        session.run(
            """
            MERGE (n:Faction {id: $id})
            SET n.name = $name, n.type = $type, n.color = $color
            """,
            id=f["id"], name=f["name"],
            type=f.get("type", ""), color=f.get("color", ""),
        )

        # Faction → CONTROLS → Location
        for loc_id in f.get("territory", []):
            session.run(
                """
                MATCH (fac:Faction {id: $fid})
                MATCH (loc:Location {id: $lid})
                MERGE (fac)-[:CONTROLS {vol: $vol}]->(loc)
                """,
                fid=f["id"], lid=loc_id, vol=vol,
            )

        # Faction → HAS_MEMBER → Character
        for figure in f.get("key_figures", []):
            if not figure:
                continue
            session.run(
                """
                MERGE (c:Character {name: $name})
                WITH c
                MATCH (fac:Faction {id: $fid})
                MERGE (fac)-[:HAS_MEMBER {vol: $vol}]->(c)
                """,
                name=figure, fid=f["id"], vol=vol,
            )

        # Character → LEADS → Faction
        for leader_name in _split_leaders(f.get("leader", "")):
            session.run(
                """
                MERGE (c:Character {name: $name})
                WITH c
                MATCH (fac:Faction {id: $fid})
                MERGE (c)-[:LEADS {vol: $vol}]->(fac)
                """,
                name=leader_name, fid=f["id"], vol=vol,
            )


def _import_routes(session, routes: list[dict], vol: int):
    """Upsert Character nodes and their VISITED relationships for one volume.

    A route whose ``character`` field names several people ('A & B') is
    applied to each of them.  Waypoints without a named location id
    (lat/lng-only entries) are skipped because no Location node exists
    for them.
    """
    for route in routes:
        char_color = route.get("color", "")
        char_names = _split_characters(route.get("character", ""))

        # Resolve the usable waypoints once per route instead of once per
        # character.  Only string ids can match a Location node; this is the
        # same filter _import_events applies.
        waypoints = [
            (wp["location"], wp.get("chapter", 0), wp.get("event", ""))
            for wp in route.get("route", [])
            if isinstance(wp.get("location"), str) and wp["location"]
        ]

        for char_name in char_names:
            session.run(
                "MERGE (c:Character {name: $name}) SET c.color = $color",
                name=char_name, color=char_color,
            )

            for loc_id, chapter, event in waypoints:
                session.run(
                    """
                    MATCH (c:Character {name: $char})
                    MATCH (l:Location {id: $lid})
                    MERGE (c)-[v:VISITED {vol: $vol, chapter: $chapter}]->(l)
                    SET v.event = $event
                    """,
                    char=char_name, lid=loc_id,
                    vol=vol, chapter=chapter, event=event,
                )


def _import_events(session, events: list[dict], vol: int):
    """Upsert Event nodes for one volume and link them to their location.

    Event ids are derived from the volume number and the position of the
    event in the list (``v01_e000`` …).
    """
    for i, evt in enumerate(events):
        event_id = f"v{vol:02d}_e{i:03d}"
        chapter = evt.get("chapter", 0)
        description = evt.get("event", "")

        session.run(
            """
            MERGE (e:Event {id: $id})
            SET e.vol = $vol, e.chapter = $chapter, e.description = $description
            """,
            id=event_id, vol=vol, chapter=chapter, description=description,
        )

        # Link only when the event references a named location id
        # (lat/lng entries are skipped).
        loc_ref = evt.get("location")
        if isinstance(loc_ref, str) and loc_ref:
            session.run(
                """
                MATCH (e:Event {id: $eid})
                MATCH (l:Location {id: $lid})
                MERGE (e)-[:OCCURRED_AT]->(l)
                """,
                eid=event_id, lid=loc_ref,
            )


# ── Entry point ───────────────────────────────────────────

def build_graph(driver: "Driver", clear: bool = False):
    """Import every ``volNN.json`` file found in DATA_DIR into Neo4j.

    Args:
        driver: an open Neo4j driver.
        clear: when true, delete all nodes and relationships first.

    Volumes 1-63 are probed; missing files are skipped silently.
    Progress is printed to stdout.
    """
    if clear:
        print("Clearing existing graph data...")
        with driver.session() as s:
            s.run("MATCH (n) DETACH DELETE n")

    print("Setting up schema constraints and indexes...")
    setup_schema(driver)

    imported = 0
    for vol_num in range(1, 64):
        filepath = DATA_DIR / f"vol{vol_num:02d}.json"
        if not filepath.exists():
            continue

        with open(filepath, encoding="utf-8") as f:
            data = json.load(f)

        # Locations and factions go first: routes and events MATCH on them.
        with driver.session() as session:
            _import_locations(session, data.get("locations", []))
            _import_factions(session, data.get("factions", []), vol_num)
            _import_routes(session, data.get("character_routes", []), vol_num)
            _import_events(session, data.get("key_events", []), vol_num)

        imported += 1
        print(f" [✓] vol{vol_num:02d} imported")

    print(f"\nDone. Imported {imported} volumes.")
_driver = None


def get_driver():
    """Return the process-wide Neo4j driver, creating it on first use.

    Connection settings come from NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD,
    falling back to the local development defaults.
    """
    global _driver
    if _driver is None:
        uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
        user = os.getenv("NEO4J_USER", "neo4j")
        password = os.getenv("NEO4J_PASSWORD", "dtmap2024")
        _driver = GraphDatabase.driver(uri, auth=(user, password))
    return _driver


def _to_plain(value):
    """Recursively turn mapping-like values into plain dicts.

    Anything implementing the Mapping interface becomes a plain ``dict`` of
    its items, and lists are converted element by element, so a
    ``collect(n)`` column is converted too.  Every other value is returned
    unchanged.

    NOTE(review): this assumes the driver's Node / Relationship objects
    implement ``collections.abc.Mapping`` over their properties — confirm
    against the neo4j driver version pinned for this project.
    """
    from collections.abc import Mapping

    if isinstance(value, Mapping):
        return {key: _to_plain(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_to_plain(item) for item in value]
    return value


def run_query(cypher: str, params: "dict | None" = None) -> list[dict]:
    """Run a Cypher statement and return its rows as plain dicts.

    Args:
        cypher: the statement to execute.
        params: optional query parameters.

    Returns:
        One dict per record, keyed by the returned column names.
    """
    with get_driver().session() as session:
        result = session.run(cypher, params or {})
        # The result is consumed here, while the session is still open.
        return [
            {key: _to_plain(record[key]) for key in record.keys()}
            for record in result
        ]


def get_graph_stats() -> dict:
    """Count the nodes per label and the relationships per type.

    Returns a dict with keys such as ``Character_count`` and
    ``VISITED_count``.
    """
    driver = get_driver()
    stats = {}
    with driver.session() as session:
        # Labels and types are fixed literals, so interpolating them into
        # the query text is safe.
        for label in ["Character", "Location", "Faction", "Event"]:
            r = session.run(f"MATCH (n:{label}) RETURN count(n) AS cnt")
            stats[f"{label}_count"] = r.single()["cnt"]
        for rel in ["VISITED", "CONTROLS", "HAS_MEMBER", "LEADS", "OCCURRED_AT"]:
            r = session.run(f"MATCH ()-[r:{rel}]->() RETURN count(r) AS cnt")
            stats[f"{rel}_count"] = r.single()["cnt"]
    return stats
# Text-to-Cypher plus natural-language answer generation on the Claude API:
# "claude-haiku-4-5" writes the Cypher, "claude-sonnet-4-5" writes the answer.

_client = None

# Keywords that must never appear in a generated query: graph writes,
# procedure calls, and LOAD CSV (which reads external files / URLs).
_WRITE_PATTERN = re.compile(
    r'\b(CREATE|DELETE|SET|MERGE|REMOVE|DROP|DETACH|CALL|LOAD)\b',
    re.IGNORECASE,
)

_MAX_ATTEMPTS = 2          # first try + one retry that carries the error text
_PROMPT_ROW_LIMIT = 30     # rows shown to the answering model
_RESPONSE_ROW_LIMIT = 20   # rows returned to the API caller


def _get_client() -> "anthropic.Anthropic":
    """Return the shared Anthropic client, creating it on first use."""
    global _client
    if _client is None:
        _client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
    return _client


def _generate_cypher(question: str, error_hint: str = "") -> str:
    """Ask the model for a Cypher query answering *question*.

    Args:
        question: the user's question.
        error_hint: error text of a previous failed execution; when given
            it is appended so the model can correct its query.

    Returns:
        The model output with surrounding whitespace and any markdown code
        fence removed.
    """
    user_msg = question
    if error_hint:
        user_msg += f"\n\n上次生成的 Cypher 执行出错:{error_hint}\n请修正后重新生成。"

    resp = _get_client().messages.create(
        model="claude-haiku-4-5",
        max_tokens=512,
        system=CYPHER_SYSTEM_PROMPT,
        messages=[{"role": "user", "content": user_msg}],
    )
    raw = resp.content[0].text.strip()

    # Strip a markdown code fence if the model added one anyway.
    raw = re.sub(r"^```(?:cypher)?\s*", "", raw, flags=re.IGNORECASE)
    raw = re.sub(r"\s*```$", "", raw)
    return raw.strip()


def _format_results(results: list[dict]) -> str:
    """Render at most 30 rows as a numbered list for the answering prompt."""
    if not results:
        return "(无查询结果)"
    lines = []
    for i, row in enumerate(results[:_PROMPT_ROW_LIMIT], 1):
        parts = [f"{k}: {v}" for k, v in row.items()]
        lines.append(f"{i}. {', '.join(parts)}")
    return "\n".join(lines)


def _validate_cypher(cypher: str) -> "str | None":
    """Return a user-facing rejection message, or None if *cypher* may run.

    A query is accepted only when it starts with MATCH and contains none of
    the keywords in ``_WRITE_PATTERN``.
    """
    if not re.match(r"^\s*MATCH\b", cypher, re.IGNORECASE):
        return "生成的查询语句格式有误,请换一种方式提问。"
    if _WRITE_PATTERN.search(cypher):
        return "生成的查询包含不允许的写操作,已拒绝执行。"
    return None


def answer_question(question: str) -> dict:
    """Answer *question* via Text-to-Cypher and an LLM-written summary.

    Returns:
        {
            "answer": str,          # natural-language answer or error text
            "cypher": str | None,   # the last generated Cypher, if any
            "results": list[dict],  # raw query rows (at most 20)
        }
    """
    error_hint = ""
    for _ in range(_MAX_ATTEMPTS):
        # ── Step 1: generate the Cypher ──────────────────────
        cypher = _generate_cypher(question, error_hint=error_hint)

        if cypher.upper() == "UNSUPPORTED":
            return {
                "answer": "抱歉,这个问题超出了知识图谱的范围。"
                          "你可以询问人物行踪、势力控制区域、地点事件等相关问题。",
                "cypher": None,
                "results": [],
            }

        # Every generated query is validated, including the one produced by
        # the retry: nothing reaches Neo4j without passing these checks.
        rejection = _validate_cypher(cypher)
        if rejection is not None:
            return {"answer": rejection, "cypher": cypher, "results": []}

        # ── Step 2: run the query ────────────────────────────
        try:
            results = run_query(cypher)
            break
        except Exception as exc:
            # Feed the error back to the model for the next attempt.
            error_hint = str(exc)
    else:
        return {
            "answer": f"查询执行失败,请尝试换一种方式提问。(错误:{error_hint})",
            "cypher": cypher,
            "results": [],
        }

    # ── Step 3: write the natural-language answer ────────────
    results_text = _format_results(results)

    resp = _get_client().messages.create(
        model="claude-sonnet-4-5",
        max_tokens=1024,
        system=ANSWER_SYSTEM_PROMPT,
        messages=[{
            "role": "user",
            "content": f"用户问题:{question}\n\n知识图谱查询结果:\n{results_text}",
        }],
    )
    answer = resp.content[0].text.strip()

    return {
        "answer": answer,
        "cypher": cypher,
        "results": results[:_RESPONSE_ROW_LIMIT],
    }
forest / region + 主要城市:扬州(yangzhou)、洛阳(luoyang)、长安/大兴(daxing)、丹阳(danyang)、 + 梁都、历阳(liyang)、江陵 + +- Faction {id, name, type, color} + type 取值:朝廷 / 门阀 / 义军 / 游牧政权 / 江湖势力 / 地方军阀 / 帮会 / 外族 + 主要势力:隋朝(sui)、李阀(li_clan)、宋阀(song_clan)、宇文阀(yuwen)、 + 瓦岗军(wagang_army)、突厥(turks)、慈航静斋、阴癸派 + +- Event {id, vol, chapter, description} + vol 是卷号(整数 1-63),chapter 是章节号 + +关系类型: +- (Character)-[:VISITED {vol, chapter, event}]->(Location) + 人物在某卷某章到访某地 + +- (Faction)-[:CONTROLS {vol}]->(Location) + 势力在某卷控制某地 + +- (Faction)-[:HAS_MEMBER {vol}]->(Character) + 势力在某卷拥有某成员 + +- (Character)-[:LEADS {vol}]->(Faction) + 人物在某卷领导某势力 + +- (Event)-[:OCCURRED_AT]->(Location) + 事件发生于某地 + +注意:vol 属性用整数表示(如 vol=1 代表第一卷,vol=20 代表第二十卷) +""" + +CYPHER_SYSTEM_PROMPT = f"""你是大唐双龙传知识图谱的 Cypher 查询专家。 + +{SCHEMA_DESCRIPTION} + +生成 Cypher 查询的规则: +1. 只输出 Cypher 语句,不要任何解释或 markdown 代码块 +2. 只使用 MATCH / RETURN / WHERE / WITH / ORDER BY / LIMIT / DISTINCT / COLLECT +3. 严禁使用 CREATE / SET / DELETE / MERGE / REMOVE / DROP +4. 默认加 LIMIT 30,除非用户指定数量 +5. 使用 DISTINCT 去重 +6. 属性名用 n.name、r.vol 格式,不要用整个节点 +7. 如果问题完全无法用图谱回答,只输出单词:UNSUPPORTED + +示例: +Q: 寇仲去过哪些地方? +A: MATCH (c:Character {{name: "寇仲"}})-[v:VISITED]->(l:Location) RETURN DISTINCT l.name, l.type, min(v.vol) AS first_vol ORDER BY first_vol LIMIT 30 + +Q: 第30卷时宇文阀控制哪些城市? +A: MATCH (f:Faction {{name: "宇文阀"}})-[r:CONTROLS]->(l:Location) WHERE r.vol <= 30 AND l.type = "city" RETURN DISTINCT l.name, r.vol ORDER BY r.vol LIMIT 30 + +Q: 扬州发生过哪些重要事件? +A: MATCH (e:Event)-[:OCCURRED_AT]->(l:Location {{name: "扬州"}}) RETURN e.description, e.vol, e.chapter ORDER BY e.vol, e.chapter LIMIT 30 + +Q: 谁领导过瓦岗军? +A: MATCH (c:Character)-[r:LEADS]->(f:Faction {{name: "瓦岗军"}}) RETURN DISTINCT c.name, r.vol ORDER BY r.vol LIMIT 30 + +Q: 寇仲和哪些势力有过关联? 
# Stand-alone command-line import script.
#
# Usage:
#   python run_import.py           # incremental import (MERGE, keeps existing data)
#   python run_import.py --clear   # wipe the graph, then re-import everything

def main():
    """Connect to Neo4j, run the import, then print the graph statistics.

    Passing ``--clear`` on the command line wipes the graph first.  The
    driver is closed even when the connectivity check or the import fails.
    """
    from graph_query import get_graph_stats

    clear = "--clear" in sys.argv

    print("Connecting to Neo4j...")
    driver = get_driver()
    try:
        driver.verify_connectivity()
        print("Connected.\n")

        build_graph(driver, clear=clear)

        print("\nGraph stats:")
        for k, v in get_graph_stats().items():
            print(f" {k}: {v}")
    finally:
        # Release the connection pool on every exit path.
        driver.close()


if __name__ == "__main__":
    main()