328 lines
12 KiB
Python
328 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from collections import defaultdict, deque
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from aiohttp import ClientError, ClientSession, ClientTimeout, web
|
|
from loguru import logger
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
from pydantic.alias_generators import to_camel
|
|
|
|
from nanobot.bus.events import OutboundMessage
|
|
from nanobot.bus.queue import MessageBus
|
|
from nanobot.channels.base import BaseChannel
|
|
|
|
|
|
class WebChannelConfig(BaseModel):
    """Settings for the web (HTTP + SSE) channel.

    Accepts keys in snake_case or camelCase: the ``to_camel`` alias generator
    is combined with ``populate_by_name``. ``model_dump(by_alias=True)`` emits
    camelCase keys.
    """

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # Not read in this module; presumably consulted by the channel manager.
    enabled: bool = Field(default=False)
    # Bind address and port for the aiohttp TCP site.
    host: str = Field(default="127.0.0.1")
    port: int = Field(default=9000)
    # Fallback list of CORS origins, used when cors_origin is blank.
    # NOTE(review): BaseChannel may also use this for sender allow-listing; confirm.
    allow_from: list[str] = Field(default_factory=lambda: ["*"])
    # Comma-separated allowed CORS origins; "*" allows any origin.
    cors_origin: str = Field(default="*")
    # Maximum number of events retained per chat (deque maxlen) and replayed to new SSE clients.
    history_size: int = Field(default=200)
    # SSE keep-alive ping interval in seconds (the stream enforces a minimum of 5).
    ping_interval_s: int = Field(default=15)
    # Optional shared bearer token; empty means no token check (when auth_enabled is off).
    api_token: str = Field(default="")
    # Delegate authentication to the external service at auth_service_url.
    auth_enabled: bool = Field(default=False)
    # Base URL of the auth service; "/auth/me" is appended when verifying tokens.
    auth_service_url: str = Field(default="")
    # Total timeout in seconds for auth-service requests (at least 1 is used).
    auth_service_timeout_s: int = Field(default=8)
|
|
|
|
|
|
class WebChannel(BaseChannel):
    """HTTP + Server-Sent Events channel for browser clients.

    Routes registered by :meth:`start`:

    * ``GET  /health``            - liveness probe.
    * ``POST /message``           - submit a user message (JSON object body).
    * ``GET  /events/{chat_id}``  - SSE stream: a "stream connected" event,
      the buffered history, then live events until the client disconnects
      or the channel stops.
    * ``GET  /history/{chat_id}`` - buffered history as JSON.

    Authentication is either delegated to an external service
    (``auth_enabled``) or checked against a single shared ``api_token``; with
    neither configured, the endpoints are open.
    """

    name = "web"
    display_name = "Web"

    @classmethod
    def default_config(cls) -> dict[str, Any]:
        """Return the default configuration as a camelCase-keyed dict."""
        return WebChannelConfig().model_dump(by_alias=True)

    def __init__(self, config: Any, bus: MessageBus):
        if isinstance(config, dict):
            config = WebChannelConfig.model_validate(config)
        elif not isinstance(config, WebChannelConfig):
            # Extra channel sections in nanobot config can arrive as generic objects.
            config = WebChannelConfig.model_validate(getattr(config, "model_dump", lambda: {})())
        super().__init__(config, bus)
        self.config: WebChannelConfig = config
        self._app: web.Application | None = None
        self._runner: web.AppRunner | None = None
        self._site: web.TCPSite | None = None
        self._http: ClientSession | None = None
        # chat_id -> queues of currently connected SSE clients.
        self._listeners: dict[str, set[asyncio.Queue[dict[str, Any]]]] = defaultdict(set)
        # chat_id -> bounded buffer of recent events, replayed to new SSE clients.
        # Clamped to >= 0 because deque(maxlen=<negative>) raises ValueError.
        self._history: dict[str, deque[dict[str, Any]]] = defaultdict(
            lambda: deque(maxlen=max(0, self.config.history_size))
        )

    async def start(self) -> None:
        """Start the HTTP server, then block until :meth:`stop` is called.

        Raises:
            RuntimeError: ``auth_enabled`` is set but ``auth_service_url`` is blank.
        """
        # Validate before flagging the channel as running.
        if self.config.auth_enabled and not self.config.auth_service_url.strip():
            raise RuntimeError("authEnabled=true requires channels.web.authServiceUrl")

        self._running = True
        try:
            if self.config.auth_enabled:
                self._http = ClientSession(
                    timeout=ClientTimeout(total=max(1, self.config.auth_service_timeout_s)),
                )
                logger.info("web channel external auth enabled -> {}", self.config.auth_service_url)

            self._app = web.Application(middlewares=[self._cors_middleware])
            self._app.router.add_get("/health", self._health)
            self._app.router.add_post("/message", self._on_message)
            self._app.router.add_get("/events/{chat_id}", self._events)
            self._app.router.add_get("/history/{chat_id}", self._history_api)

            self._runner = web.AppRunner(self._app)
            await self._runner.setup()
            self._site = web.TCPSite(self._runner, self.config.host, self.config.port)
            await self._site.start()
        except Exception:
            # Release the auth client session and runner if setup fails (e.g. port in use).
            await self.stop()
            raise

        logger.info("web channel listening on http://{}:{}", self.config.host, self.config.port)

        while self._running:
            await asyncio.sleep(1)

    async def stop(self) -> None:
        """Stop the server and close the auth client; safe to call repeatedly."""
        self._running = False
        if self._http:
            await self._http.close()
            self._http = None
        if self._runner:
            await self._runner.cleanup()
            self._runner = None
            self._site = None

    async def send(self, msg: OutboundMessage) -> None:
        """Record an assistant message in history and push it to live SSE clients.

        A truthy ``metadata["_progress"]`` marks the event as ``type: "progress"``.
        """
        payload = {
            "type": "progress" if msg.metadata.get("_progress") else "message",
            "role": "assistant",
            "chat_id": msg.chat_id,
            "content": msg.content,
            "at": self._now_hms(),
            "metadata": msg.metadata,
        }
        self._append_history(msg.chat_id, payload)
        await self._fanout(msg.chat_id, payload)

    @staticmethod
    def _now_hms() -> str:
        """Return the current local wall-clock time as ``HH:MM:SS``."""
        return datetime.now().strftime("%H:%M:%S")

    def _allowed_cors_origins(self) -> list[str]:
        """Return allowed origins: ``cors_origin`` (comma-separated), else ``allow_from``, else ``["*"]``."""
        configured = [o.strip() for o in self.config.cors_origin.split(",") if o.strip()]
        if not configured:
            configured = [str(o).strip() for o in self.config.allow_from if str(o).strip()]
        return configured or ["*"]

    def _resolve_cors_origin(self, request: web.Request) -> str:
        """Pick the ``Access-Control-Allow-Origin`` value for *request*.

        Returns ``"*"`` when any origin is allowed; echoes the request Origin
        when it is allowed; otherwise returns the first configured origin,
        which the browser will then reject.
        """
        allowed = self._allowed_cors_origins()
        if "*" in allowed:
            return "*"
        req_origin = request.headers.get("Origin", "").strip()
        if req_origin and req_origin in allowed:
            return req_origin
        return allowed[0]

    def _apply_cors_headers(self, request: web.Request, resp: web.StreamResponse) -> web.StreamResponse:
        """Set CORS headers on *resp* (must run before the response is prepared)."""
        origin = self._resolve_cors_origin(request)
        resp.headers["Access-Control-Allow-Origin"] = origin
        if origin != "*":
            # Credentials are only permitted with an explicit origin, never with "*".
            resp.headers["Vary"] = "Origin"
            resp.headers["Access-Control-Allow-Credentials"] = "true"
        resp.headers["Access-Control-Allow-Methods"] = "GET,POST,OPTIONS"
        req_headers = request.headers.get("Access-Control-Request-Headers", "").strip()
        resp.headers["Access-Control-Allow-Headers"] = req_headers or "Content-Type, Authorization"
        resp.headers["Access-Control-Max-Age"] = "86400"
        return resp

    @web.middleware
    async def _cors_middleware(self, request: web.Request, handler):
        """Answer preflights, map errors to JSON, and add CORS headers.

        Streaming responses (SSE) are already sent by the time this runs, so
        they must set their CORS headers themselves before ``prepare()``.
        """
        try:
            if request.method == "OPTIONS":
                resp: web.StreamResponse = web.Response(status=204)
            else:
                resp = await handler(request)
        except web.HTTPException as e:
            resp = web.json_response({"ok": False, "error": e.reason or "http error"}, status=e.status)
        except Exception as e:
            logger.exception("web request failed: {}", e)
            resp = web.json_response({"ok": False, "error": "internal server error"}, status=500)

        return self._apply_cors_headers(request, resp)

    def _extract_token(self, request: web.Request) -> str:
        """Return the bearer token from ``Authorization`` or the ``token`` query param."""
        scheme, sep, credentials = request.headers.get("Authorization", "").partition(" ")
        # Auth schemes are case-insensitive (RFC 7235), so accept "bearer" too.
        if sep and scheme.lower() == "bearer":
            return credentials.strip()
        # Native EventSource cannot set Authorization headers, so allow query token.
        return request.query.get("token", "").strip()

    async def _verify_external_user(self, token: str) -> dict[str, Any] | None:
        """Resolve *token* via ``{auth_service_url}/auth/me``.

        Returns ``{"id": int, "phone": str}`` for a valid user, or ``None``
        when the service is unavailable, rejects the token, or returns a
        malformed payload.
        """
        if not self._http:
            return None
        url = f"{self.config.auth_service_url.rstrip('/')}/auth/me"
        try:
            async with self._http.get(url, headers={"Authorization": f"Bearer {token}"}) as resp:
                if resp.status != 200:
                    return None
                payload = await resp.json()
        except (ClientError, asyncio.TimeoutError, ValueError) as e:
            logger.warning("web auth service request failed: {}", e)
            return None

        # Guard against non-object JSON (e.g. a list), which has no .get().
        if not isinstance(payload, dict) or not payload.get("ok"):
            return None
        user = payload.get("user")
        if not isinstance(user, dict):
            return None

        raw_phone = user.get("phone")
        # str(None) would be the truthy string "None"; treat null as missing.
        phone = "" if raw_phone is None else str(raw_phone)
        try:
            uid = int(user.get("id", 0))
        except (TypeError, ValueError):
            logger.warning("web auth service returned invalid user id: {!r}", user.get("id"))
            return None
        if not phone or uid <= 0:
            return None
        return {"id": uid, "phone": phone}

    async def _require_auth(self, request: web.Request) -> tuple[dict[str, Any] | None, web.Response | None]:
        """Authenticate *request*.

        Returns ``(user, None)`` when external auth succeeds, ``(None, None)``
        when the shared token matches or no auth is configured, and
        ``(None, <401 response>)`` on failure.
        """
        if self.config.auth_enabled:
            token = self._extract_token(request)
            if not token:
                return None, self._unauthorized()
            user = await self._verify_external_user(token)
            if not user:
                return None, self._unauthorized()
            return user, None

        expected = (self.config.api_token or "").strip()
        if expected:
            import hmac  # constant-time comparison for the shared secret

            provided = self._extract_token(request)
            # Compare bytes: compare_digest rejects non-ASCII str arguments.
            if not hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8")):
                return None, self._unauthorized()
        return None, None

    def _unauthorized(self) -> web.Response:
        """Build the standard 401 JSON response."""
        return web.json_response({"ok": False, "error": "unauthorized"}, status=401)

    def _forbidden(self) -> web.Response:
        """Build the standard 403 JSON response."""
        return web.json_response({"ok": False, "error": "forbidden"}, status=403)

    def _is_chat_allowed(self, user_phone: str, chat_id: str) -> bool:
        """A user may access the chat named by their phone, or any ``"<phone>:..."`` sub-chat."""
        return chat_id == user_phone or chat_id.startswith(f"{user_phone}:")

    async def _health(self, _request: web.Request) -> web.Response:
        """Liveness probe."""
        return web.json_response({"ok": True, "channel": self.name})

    async def _history_api(self, request: web.Request) -> web.Response:
        """Return the buffered events for ``chat_id`` as ``{"messages": [...]}``."""
        user, err = await self._require_auth(request)
        if err:
            return err
        chat_id = request.match_info["chat_id"]
        if user and not self._is_chat_allowed(user["phone"], chat_id):
            return self._forbidden()
        return web.json_response({"messages": list(self._history.get(chat_id, []))})

    async def _on_message(self, request: web.Request) -> web.Response:
        """Accept a user message and hand it to the message bus.

        Body fields: ``text`` (required), ``sender``, ``chat_id``, ``media``
        (list) and ``metadata`` (dict). With external auth, ``sender`` is
        forced to the user's phone and ``chat_id`` must belong to that user.
        """
        user, err = await self._require_auth(request)
        if err:
            return err
        try:
            body = await request.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            return web.json_response({"ok": False, "error": "invalid json"}, status=400)
        if not isinstance(body, dict):
            return web.json_response({"ok": False, "error": "json body must be an object"}, status=400)

        sender = str(body.get("sender", "web-user"))
        if user:
            sender = user["phone"]
        chat_id = str(body.get("chat_id", sender))
        if user and not self._is_chat_allowed(user["phone"], chat_id):
            return self._forbidden()
        text = str(body.get("text", "")).strip()
        media = body.get("media", [])
        metadata = body.get("metadata", {})

        if not text:
            return web.json_response({"ok": False, "error": "text is required"}, status=400)

        self._append_history(
            chat_id,
            {
                "type": "message",
                "role": "user",
                "chat_id": chat_id,
                "content": text,
                "at": self._now_hms(),
            },
        )

        await self._handle_message(
            sender_id=sender,
            chat_id=chat_id,
            content=text,
            media=media if isinstance(media, list) else [],
            metadata=metadata if isinstance(metadata, dict) else {},
        )

        return web.json_response({"ok": True})

    async def _events(self, request: web.Request) -> web.StreamResponse:
        """Stream events for ``chat_id`` as Server-Sent Events.

        Sends a "stream connected" system event, then the buffered history,
        then live events. A comment ping is sent after each
        ``max(5, ping_interval_s)`` seconds of inactivity.
        """
        user, err = await self._require_auth(request)
        if err:
            return err
        chat_id = request.match_info["chat_id"]
        if user and not self._is_chat_allowed(user["phone"], chat_id):
            return self._forbidden()

        resp = web.StreamResponse(
            status=200,
            headers={
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
        # prepare() sends the headers; the middleware runs only after this
        # handler returns, which is too late, so set CORS headers here.
        self._apply_cors_headers(request, resp)
        await resp.prepare(request)

        # Snapshot history and register the queue with no await in between.
        # Every event then lands in exactly one of the two (no gap, no
        # duplicate), and concurrent send() calls cannot mutate the deque
        # while it is being replayed.
        backlog = list(self._history.get(chat_id, ()))
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._listeners[chat_id].add(queue)

        ping_timeout = max(5, self.config.ping_interval_s)
        try:
            await self._write_sse(
                resp,
                {
                    "type": "system",
                    "role": "system",
                    "chat_id": chat_id,
                    "content": "stream connected",
                    "at": self._now_hms(),
                },
            )

            for item in backlog:
                await self._write_sse(resp, item)

            while self._running:
                try:
                    item = await asyncio.wait_for(queue.get(), timeout=ping_timeout)
                except asyncio.TimeoutError:
                    # SSE comment line keeps proxies and the browser from timing out.
                    await resp.write(b": ping\n\n")
                    await resp.drain()
                else:
                    await self._write_sse(resp, item)
        except ConnectionResetError:
            pass  # client went away
        finally:
            listeners = self._listeners.get(chat_id)
            if listeners is not None:
                listeners.discard(queue)
                if not listeners:
                    self._listeners.pop(chat_id, None)

        return resp

    async def _fanout(self, chat_id: str, payload: dict[str, Any]) -> None:
        """Enqueue *payload* for every SSE client currently watching *chat_id*."""
        listeners = self._listeners.get(chat_id)
        if not listeners:
            return
        for q in listeners:
            q.put_nowait(payload)

    def _append_history(self, chat_id: str, payload: dict[str, Any]) -> None:
        """Append *payload* to the bounded history buffer of *chat_id*."""
        self._history[chat_id].append(payload)

    async def _write_sse(self, resp: web.StreamResponse, payload: dict[str, Any]) -> None:
        """Write *payload* as one SSE ``data:`` event terminated by a blank line."""
        # A real "\n\n" terminator is required; a literal backslash-n would
        # never end the event on the client side.
        body = f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")
        await resp.write(body)
        await resp.drain()
|