# Source: mars-nanobot/nanobot-channel-web/nanobot_channel_web/channel.py
# Snapshot: 2026-03-27 16:10:45 +08:00 (328 lines, 12 KiB, Python)

from __future__ import annotations

import asyncio
import hmac
import json
from collections import defaultdict, deque
from datetime import datetime
from typing import Any

from aiohttp import ClientError, ClientSession, ClientTimeout, web
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_camel

from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel

class WebChannelConfig(BaseModel):
    """Settings for the web (HTTP + SSE) channel.

    Attributes are snake_case in Python. ``to_camel`` exposes them as camelCase
    aliases (e.g. ``historySize``) for dumping/validating by alias, and
    ``populate_by_name`` also accepts the snake_case names on input.
    """

    model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

    # Whether the channel is switched on.
    enabled: bool = Field(default=False)
    # Bind address and port of the embedded aiohttp server.
    host: str = Field(default="127.0.0.1")
    port: int = Field(default=9000)
    # Allow-list; each instance gets its own fresh ["*"] list. The web channel
    # also uses it as the CORS origin fallback when cors_origin is empty.
    allow_from: list[str] = Field(default_factory=lambda: ["*"])
    # Comma-separated allowed CORS origins; "*" allows any origin.
    cors_origin: str = Field(default="*")
    # Max events kept in memory per chat for replay/history.
    history_size: int = Field(default=200)
    # Seconds between SSE keep-alive pings (the channel enforces a floor of 5).
    ping_interval_s: int = Field(default=15)
    # Static shared token required when external auth is off; "" disables it.
    api_token: str = Field(default="")
    # Delegate token validation to an external service at auth_service_url.
    auth_enabled: bool = Field(default=False)
    auth_service_url: str = Field(default="")
    # Total timeout (seconds, floor 1) for calls to the auth service.
    auth_service_timeout_s: int = Field(default=8)


class WebChannel(BaseChannel):
    """nanobot channel exposing chats over plain HTTP plus Server-Sent Events.

    Routes (registered in :meth:`start`):

    * ``GET  /health``            -- liveness probe, never authenticated.
    * ``POST /message``           -- accept a user message and publish it to the bus.
    * ``GET  /events/{chat_id}``  -- SSE stream: a "stream connected" event, the
      chat's stored history, then live events until client or server goes away.
    * ``GET  /history/{chat_id}`` -- JSON snapshot of the chat's stored history.

    Authentication: with ``auth_enabled`` the caller's token is validated by the
    external service at ``{auth_service_url}/auth/me`` and the chat id must be the
    returned user's phone (or start with ``"<phone>:"``). Otherwise, if
    ``api_token`` is non-empty, the presented token must equal it. History lives
    in memory only and is bounded per chat by ``history_size``.
    """

    name = "web"
    display_name = "Web"

    @classmethod
    def default_config(cls) -> dict[str, Any]:
        """Return the default settings as a camelCase (alias) dict."""
        return WebChannelConfig().model_dump(by_alias=True)

    def __init__(self, config: Any, bus: MessageBus):
        # Normalise whatever the host passes in into a validated WebChannelConfig.
        if isinstance(config, dict):
            config = WebChannelConfig.model_validate(config)
        elif not isinstance(config, WebChannelConfig):
            # Extra channel sections in nanobot config can arrive as generic objects.
            config = WebChannelConfig.model_validate(getattr(config, "model_dump", lambda: {})())
        super().__init__(config, bus)
        self.config: WebChannelConfig = config
        self._app: web.Application | None = None
        self._runner: web.AppRunner | None = None
        self._site: web.TCPSite | None = None
        self._http: ClientSession | None = None
        # chat_id -> queues of currently connected SSE clients.
        self._listeners: dict[str, set[asyncio.Queue[dict[str, Any]]]] = defaultdict(set)
        # chat_id -> bounded replay buffer. Clamp at 0: deque(maxlen=<negative>)
        # raises ValueError, which would otherwise surface as a 500 on first use.
        self._history: dict[str, deque[dict[str, Any]]] = defaultdict(
            lambda: deque(maxlen=max(0, self.config.history_size))
        )

    async def start(self) -> None:
        """Start the HTTP server and block until :meth:`stop` clears ``_running``.

        Raises:
            RuntimeError: ``auth_enabled`` is set but ``auth_service_url`` is blank.
        """
        # Validate before flagging the channel as running, so a config error
        # does not leave ``_running`` stuck at True.
        if self.config.auth_enabled and not self.config.auth_service_url.strip():
            raise RuntimeError("authEnabled=true requires channels.web.authServiceUrl")
        self._running = True
        if self.config.auth_enabled:
            self._http = ClientSession(
                timeout=ClientTimeout(total=max(1, self.config.auth_service_timeout_s)),
            )
            logger.info("web channel external auth enabled -> {}", self.config.auth_service_url)
        self._app = web.Application(middlewares=[self._cors_middleware])
        self._app.router.add_get("/health", self._health)
        self._app.router.add_post("/message", self._on_message)
        self._app.router.add_get("/events/{chat_id}", self._events)
        self._app.router.add_get("/history/{chat_id}", self._history_api)
        self._runner = web.AppRunner(self._app)
        await self._runner.setup()
        self._site = web.TCPSite(self._runner, self.config.host, self.config.port)
        await self._site.start()
        logger.info("web channel listening on http://{}:{}", self.config.host, self.config.port)
        while self._running:
            await asyncio.sleep(1)

    async def stop(self) -> None:
        """Stop the serve loop, close the auth HTTP client and tear down the server."""
        self._running = False
        if self._http:
            await self._http.close()
            self._http = None
        if self._runner:
            await self._runner.cleanup()
            self._runner = None
            self._site = None

    async def send(self, msg: OutboundMessage) -> None:
        """Record an outbound (assistant) message and push it to live SSE clients.

        Messages whose metadata has a truthy ``_progress`` are tagged
        ``type="progress"``; everything else is ``type="message"``.
        """
        payload = {
            "type": "progress" if msg.metadata.get("_progress") else "message",
            "role": "assistant",
            "chat_id": msg.chat_id,
            "content": msg.content,
            "at": self._timestamp(),
            "metadata": msg.metadata,
        }
        self._append_history(msg.chat_id, payload)
        await self._fanout(msg.chat_id, payload)

    @staticmethod
    def _timestamp() -> str:
        """Local wall-clock time of day (``HH:MM:SS``) stamped on every event."""
        return datetime.now().strftime("%H:%M:%S")

    def _allowed_cors_origins(self) -> list[str]:
        """Parse ``cors_origin`` (comma-separated), falling back to ``allow_from``, then ``*``."""
        configured = [o.strip() for o in self.config.cors_origin.split(",") if o.strip()]
        if not configured:
            configured = [str(o).strip() for o in self.config.allow_from if str(o).strip()]
        return configured or ["*"]

    def _resolve_cors_origin(self, request: web.Request) -> str:
        """Pick the ``Access-Control-Allow-Origin`` value for this request.

        Returns ``*`` if any origin is allowed, the request's Origin if it is
        allow-listed, else the first configured origin (which a browser will reject).
        """
        allowed = self._allowed_cors_origins()
        if "*" in allowed:
            return "*"
        req_origin = request.headers.get("Origin", "").strip()
        if req_origin and req_origin in allowed:
            return req_origin
        return allowed[0]

    def _apply_cors_headers(self, request: web.Request, resp: web.StreamResponse) -> web.StreamResponse:
        """Set CORS headers on ``resp`` in place and return it.

        Must be called before ``resp.prepare()`` for the headers to reach the client.
        """
        origin = self._resolve_cors_origin(request)
        resp.headers["Access-Control-Allow-Origin"] = origin
        if origin != "*":
            # Per-origin answers must not be shared by caches; credentials are
            # only legal with an explicit origin, never with "*".
            resp.headers["Vary"] = "Origin"
            resp.headers["Access-Control-Allow-Credentials"] = "true"
        resp.headers["Access-Control-Allow-Methods"] = "GET,POST,OPTIONS"
        req_headers = request.headers.get("Access-Control-Request-Headers", "").strip()
        resp.headers["Access-Control-Allow-Headers"] = req_headers or "Content-Type, Authorization"
        resp.headers["Access-Control-Max-Age"] = "86400"
        return resp

    @web.middleware
    async def _cors_middleware(self, request: web.Request, handler):
        """Answer preflights, map exceptions to JSON errors and add CORS headers.

        Headers set here only reach the client for responses the handler has not
        already prepared; streaming handlers (``_events``) apply CORS themselves.
        """
        try:
            if request.method == "OPTIONS":
                resp: web.StreamResponse = web.Response(status=204)
            else:
                resp = await handler(request)
        except web.HTTPException as e:
            resp = web.json_response({"ok": False, "error": e.reason or "http error"}, status=e.status)
        except Exception as e:
            logger.exception("web request failed: {}", e)
            resp = web.json_response({"ok": False, "error": "internal server error"}, status=500)
        return self._apply_cors_headers(request, resp)

    def _extract_token(self, request: web.Request) -> str:
        """Return the bearer token from ``Authorization`` or the ``token`` query param.

        The auth scheme is matched case-insensitively (RFC 7235). Returns "" if
        no token was presented.
        """
        scheme, sep, credentials = request.headers.get("Authorization", "").partition(" ")
        if sep and scheme.lower() == "bearer":
            return credentials.strip()
        # Native EventSource cannot set Authorization headers, so allow query token.
        return request.query.get("token", "").strip()

    async def _verify_external_user(self, token: str) -> dict[str, Any] | None:
        """Validate ``token`` against ``{auth_service_url}/auth/me``.

        Returns ``{"id": int, "phone": str}`` on success, or ``None`` on any
        failure (no client, non-200, transport error, malformed payload).
        """
        if not self._http:
            return None
        url = f"{self.config.auth_service_url.rstrip('/')}/auth/me"
        try:
            async with self._http.get(url, headers={"Authorization": f"Bearer {token}"}) as resp:
                if resp.status != 200:
                    return None
                payload = await resp.json()
        except (ClientError, asyncio.TimeoutError, ValueError) as e:
            logger.warning("web auth service request failed: {}", e)
            return None
        return self._parse_auth_user(payload)

    @staticmethod
    def _parse_auth_user(payload: Any) -> dict[str, Any] | None:
        """Extract ``{"id", "phone"}`` from an ``/auth/me`` payload; ``None`` if malformed."""
        if not isinstance(payload, dict) or not payload.get("ok"):
            return None
        user = payload.get("user")
        if not isinstance(user, dict):
            return None
        raw_phone = user.get("phone")
        # str(None) would be the truthy "None" and grant access to chat ids "None:*".
        phone = "" if raw_phone is None else str(raw_phone)
        try:
            uid = int(user.get("id", 0))
        except (TypeError, ValueError):
            logger.warning("web auth service returned malformed user id: {!r}", user.get("id"))
            return None
        if not phone or uid <= 0:
            return None
        return {"id": uid, "phone": phone}

    async def _require_auth(self, request: web.Request) -> tuple[dict[str, Any] | None, web.Response | None]:
        """Authenticate ``request``.

        Returns ``(user, None)`` when allowed (``user`` is ``None`` unless external
        auth is enabled) or ``(None, 401 response)`` when rejected.
        """
        if self.config.auth_enabled:
            token = self._extract_token(request)
            if not token:
                return None, self._unauthorized()
            user = await self._verify_external_user(token)
            if not user:
                return None, self._unauthorized()
            return user, None
        expected = (self.config.api_token or "").strip()
        if expected:
            presented = self._extract_token(request)
            # Constant-time compare; bytes so non-ASCII input cannot raise TypeError.
            if not hmac.compare_digest(presented.encode("utf-8"), expected.encode("utf-8")):
                return None, self._unauthorized()
        return None, None

    def _unauthorized(self) -> web.Response:
        """Build the standard 401 JSON error response."""
        return web.json_response({"ok": False, "error": "unauthorized"}, status=401)

    def _is_chat_allowed(self, user_phone: str, chat_id: str) -> bool:
        """A user may access the chat named by their phone or any ``"<phone>:..."`` sub-chat."""
        return chat_id == user_phone or chat_id.startswith(f"{user_phone}:")

    async def _health(self, _request: web.Request) -> web.Response:
        """Unauthenticated liveness probe."""
        return web.json_response({"ok": True, "channel": self.name})

    async def _history_api(self, request: web.Request) -> web.Response:
        """Return the stored history of ``chat_id`` as ``{"messages": [...]}``."""
        user, err = await self._require_auth(request)
        if err:
            return err
        chat_id = request.match_info["chat_id"]
        if user and not self._is_chat_allowed(user["phone"], chat_id):
            return web.json_response({"ok": False, "error": "forbidden"}, status=403)
        return web.json_response({"messages": list(self._history.get(chat_id, []))})

    async def _on_message(self, request: web.Request) -> web.Response:
        """Accept ``{"text", "chat_id"?, "sender"?, "media"?, "metadata"?}`` and publish it.

        With external auth the sender is forced to the authenticated phone.
        Responds 400 on invalid JSON or empty text, 403 on a foreign chat id.
        """
        user, err = await self._require_auth(request)
        if err:
            return err
        try:
            body = await request.json()
        except Exception:
            return web.json_response({"ok": False, "error": "invalid json"}, status=400)
        if not isinstance(body, dict):
            # Valid JSON but not an object (e.g. a list): the .get() calls below would 500.
            return web.json_response({"ok": False, "error": "expected a json object"}, status=400)
        sender = str(body.get("sender", "web-user"))
        if user:
            sender = user["phone"]
        chat_id = str(body.get("chat_id", sender))
        if user and not self._is_chat_allowed(user["phone"], chat_id):
            return web.json_response({"ok": False, "error": "forbidden"}, status=403)
        text = str(body.get("text", "")).strip()
        media = body.get("media", [])
        metadata = body.get("metadata", {})
        if not text:
            return web.json_response({"ok": False, "error": "text is required"}, status=400)
        self._append_history(
            chat_id,
            {
                "type": "message",
                "role": "user",
                "chat_id": chat_id,
                "content": text,
                "at": self._timestamp(),
            },
        )
        await self._handle_message(
            sender_id=sender,
            chat_id=chat_id,
            content=text,
            media=media if isinstance(media, list) else [],
            metadata=metadata if isinstance(metadata, dict) else {},
        )
        return web.json_response({"ok": True})

    async def _events(self, request: web.Request) -> web.StreamResponse:
        """Stream a chat as Server-Sent Events until the client disconnects or stop()."""
        user, err = await self._require_auth(request)
        if err:
            return err
        chat_id = request.match_info["chat_id"]
        if user and not self._is_chat_allowed(user["phone"], chat_id):
            return web.json_response({"ok": False, "error": "forbidden"}, status=403)
        resp = web.StreamResponse(
            status=200,
            headers={
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
        # Headers hit the wire in prepare(); the CORS middleware only runs after
        # this handler returns (when the stream is over), so add them now or
        # cross-origin EventSource clients are rejected by the browser.
        self._apply_cors_headers(request, resp)
        await resp.prepare(request)
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        # Snapshot history and register the queue with no await in between:
        # earlier items are replayed from the snapshot, later ones arrive via the
        # queue, so nothing is lost or duplicated. (Iterating the live deque
        # across awaits would also raise "deque mutated during iteration".)
        backlog = list(self._history.get(chat_id, ()))
        self._listeners[chat_id].add(queue)
        try:
            await self._write_sse(
                resp,
                {
                    "type": "system",
                    "role": "system",
                    "chat_id": chat_id,
                    "content": "stream connected",
                    "at": self._timestamp(),
                },
            )
            for item in backlog:
                await self._write_sse(resp, item)
            ping_every = max(5, self.config.ping_interval_s)
            while self._running:
                try:
                    item = await asyncio.wait_for(queue.get(), timeout=ping_every)
                except asyncio.TimeoutError:
                    # SSE comment line: keeps idle connections/proxies alive.
                    await resp.write(b": ping\n\n")
                    await resp.drain()
                    continue
                await self._write_sse(resp, item)
        except ConnectionResetError:
            # Client went away mid-write; fall through to listener cleanup.
            pass
        finally:
            self._listeners[chat_id].discard(queue)
            if not self._listeners[chat_id]:
                del self._listeners[chat_id]
        return resp

    async def _fanout(self, chat_id: str, payload: dict[str, Any]) -> None:
        """Enqueue ``payload`` for every SSE client currently watching ``chat_id``."""
        listeners = self._listeners.get(chat_id)
        if not listeners:
            return
        for q in listeners:
            q.put_nowait(payload)

    def _append_history(self, chat_id: str, payload: dict[str, Any]) -> None:
        """Append to the chat's bounded history (oldest entries drop off when full)."""
        self._history[chat_id].append(payload)

    async def _write_sse(self, resp: web.StreamResponse, payload: dict[str, Any]) -> None:
        """Write ``payload`` as one SSE ``data:`` event.

        ``json.dumps`` escapes embedded newlines, so the event is a single line;
        the trailing blank line (real ``\\n\\n``) terminates the event.
        """
        body = f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")
        await resp.write(body)
        await resp.drain()