49 lines
1.5 KiB
Python
49 lines
1.5 KiB
Python
import os
from typing import Optional

from dotenv import load_dotenv
from neo4j import GraphDatabase
from neo4j.graph import Node, Relationship

# Load variables from a .env file into the process environment so that
# get_driver() can read the NEO4J_* settings via os.getenv.
load_dotenv()

# Lazily-initialised, module-wide Neo4j driver; set on the first get_driver() call.
_driver = None


def get_driver():
    """Return the module-wide Neo4j driver, building it on the first call.

    Connection settings are read from environment variables:

    * ``NEO4J_URI``      - falls back to ``bolt://localhost:7687``
    * ``NEO4J_USER``     - falls back to ``neo4j``
    * ``NEO4J_PASSWORD`` - falls back to a hard-coded password

    Later calls return the same cached driver object without re-reading the
    environment.
    """
    global _driver

    # Already built: hand back the cached instance.
    if _driver is not None:
        return _driver

    connection_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    # NOTE(review): a literal password is used as the fallback when
    # NEO4J_PASSWORD is unset; consider requiring the variable outside local use.
    credentials = (
        os.getenv("NEO4J_USER", "neo4j"),
        os.getenv("NEO4J_PASSWORD", "dtmap2024"),
    )
    _driver = GraphDatabase.driver(connection_uri, auth=credentials)
    return _driver


def run_query(cypher: str, params: Optional[dict] = None) -> list[dict]:
    """Execute a Cypher statement and return every record as a plain dict.

    Args:
        cypher: The Cypher statement to run.
        params: Query parameters forwarded to the driver; ``None`` (the
            default) is sent as an empty parameter map.

    Returns:
        A list with one dict per result record, keyed by the record's return
        keys in order. A top-level Node or Relationship value is replaced by a
        dict of its properties. Every other value is returned exactly as the
        driver produced it. This includes entities nested inside lists or maps
        (e.g. from ``collect(n)``) and Path objects.
    """
    with get_driver().session() as session:
        result = session.run(cypher, params or {})
        # Fully consume the result while the session is still open.
        return [
            {key: _entity_to_plain(value) for key, value in record.items()}
            for record in result
        ]


def _entity_to_plain(value):
    """Return a property dict for a Node/Relationship; pass any other value through."""
    # Use the driver's public graph types and mapping API rather than the
    # private ``_properties`` attribute.
    if isinstance(value, (Node, Relationship)):
        return dict(value.items())
    return value


def get_graph_stats() -> dict:
    """Count nodes per known label and relationships per known type.

    Returns:
        A dict of ``"<Label>_count"`` entries for the labels Character,
        Location, Faction and Event. These are followed by
        ``"<TYPE>_count"`` entries for the relationship types VISITED,
        CONTROLS, HAS_MEMBER, LEADS and OCCURRED_AT. Each value is the
        ``cnt`` column of a ``count(...)`` query. All queries run in one
        session.
    """
    # Fixed, trusted identifiers: they are interpolated into the Cypher text
    # because labels and relationship types cannot be passed as parameters.
    node_labels = ("Character", "Location", "Faction", "Event")
    relationship_types = ("VISITED", "CONTROLS", "HAS_MEMBER", "LEADS", "OCCURRED_AT")

    driver = get_driver()
    counts = {}
    with driver.session() as session:
        for node_label in node_labels:
            node_query = f"MATCH (n:{node_label}) RETURN count(n) AS cnt"
            counts[f"{node_label}_count"] = session.run(node_query).single()["cnt"]
        for rel_type in relationship_types:
            rel_query = f"MATCH ()-[r:{rel_type}]->() RETURN count(r) AS cnt"
            counts[f"{rel_type}_count"] = session.run(rel_query).single()["cnt"]
    return counts