from __future__ import annotations import asyncio import json from collections import defaultdict, deque from datetime import datetime from typing import Any from aiohttp import ClientError, ClientSession, ClientTimeout, web from loguru import logger from pydantic import BaseModel, ConfigDict, Field from pydantic.alias_generators import to_camel from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel class WebChannelConfig(BaseModel): model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True) enabled: bool = False host: str = "127.0.0.1" port: int = 9000 allow_from: list[str] = Field(default_factory=lambda: ["*"]) cors_origin: str = "*" history_size: int = 200 ping_interval_s: int = 15 api_token: str = "" auth_enabled: bool = False auth_service_url: str = "" auth_service_timeout_s: int = 8 class WebChannel(BaseChannel): name = "web" display_name = "Web" @classmethod def default_config(cls) -> dict[str, Any]: return WebChannelConfig().model_dump(by_alias=True) def __init__(self, config: Any, bus: MessageBus): if isinstance(config, dict): config = WebChannelConfig.model_validate(config) elif not isinstance(config, WebChannelConfig): # Extra channel sections in nanobot config can arrive as generic objects. config = WebChannelConfig.model_validate(getattr(config, "model_dump", lambda: {})()) super().__init__(config, bus) self.config: WebChannelConfig = config self._app: web.Application | None = None self._runner: web.AppRunner | None = None self._site: web.TCPSite | None = None self._http: ClientSession | None = None self._listeners: dict[str, set[asyncio.Queue[dict[str, Any]]]] = defaultdict(set) self._history: dict[str, deque[dict[str, Any]]] = defaultdict( lambda: deque(maxlen=self.config.history_size) ) async def start(self) -> None: self._running = True if self.config.auth_enabled: if not self.config.auth_service_url.strip(): raise RuntimeError("authEnabled=true requires channels.web.authServiceUrl") self._http = ClientSession( timeout=ClientTimeout(total=max(1, self.config.auth_service_timeout_s)), ) logger.info("web channel external auth enabled -> {}", self.config.auth_service_url) self._app = web.Application(middlewares=[self._cors_middleware]) self._app.router.add_get("/health", self._health) self._app.router.add_post("/message", self._on_message) self._app.router.add_get("/events/{chat_id}", self._events) self._app.router.add_get("/history/{chat_id}", self._history_api) self._runner = web.AppRunner(self._app) await self._runner.setup() self._site = web.TCPSite(self._runner, self.config.host, self.config.port) await self._site.start() logger.info("web channel listening on http://{}:{}", self.config.host, self.config.port) while self._running: await asyncio.sleep(1) async def stop(self) -> None: self._running = False if self._http: await self._http.close() self._http = None if self._runner: await self._runner.cleanup() self._runner = None self._site = None async def send(self, msg: OutboundMessage) -> None: payload = { "type": "progress" if msg.metadata.get("_progress") else "message", "role": "assistant", "chat_id": msg.chat_id, "content": msg.content, "at": datetime.now().strftime("%H:%M:%S"), "metadata": msg.metadata, } self._append_history(msg.chat_id, payload) await self._fanout(msg.chat_id, payload) def _allowed_cors_origins(self) -> list[str]: configured = [o.strip() for o in self.config.cors_origin.split(",") if o.strip()] if not configured: configured = [str(o).strip() for o in self.config.allow_from if str(o).strip()] return configured or ["*"] def _resolve_cors_origin(self, request: web.Request) -> str: allowed = self._allowed_cors_origins() if "*" in allowed: return "*" req_origin = request.headers.get("Origin", "").strip() if req_origin and req_origin in allowed: return req_origin return allowed[0] def _apply_cors_headers(self, request: web.Request, resp: web.StreamResponse) -> web.StreamResponse: origin = self._resolve_cors_origin(request) resp.headers["Access-Control-Allow-Origin"] = origin if origin != "*": resp.headers["Vary"] = "Origin" resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers["Access-Control-Allow-Methods"] = "GET,POST,OPTIONS" req_headers = request.headers.get("Access-Control-Request-Headers", "").strip() resp.headers["Access-Control-Allow-Headers"] = req_headers or "Content-Type, Authorization" resp.headers["Access-Control-Max-Age"] = "86400" return resp @web.middleware async def _cors_middleware(self, request: web.Request, handler): try: if request.method == "OPTIONS": resp: web.StreamResponse = web.Response(status=204) else: resp = await handler(request) except web.HTTPException as e: resp = web.json_response({"ok": False, "error": e.reason or "http error"}, status=e.status) except Exception as e: logger.exception("web request failed: {}", e) resp = web.json_response({"ok": False, "error": "internal server error"}, status=500) return self._apply_cors_headers(request, resp) def _extract_token(self, request: web.Request) -> str: auth = request.headers.get("Authorization", "") if auth.startswith("Bearer "): return auth[7:].strip() # Native EventSource cannot set Authorization headers, so allow query token. return request.query.get("token", "").strip() async def _verify_external_user(self, token: str) -> dict[str, Any] | None: if not self._http: return None url = f"{self.config.auth_service_url.rstrip('/')}/auth/me" try: async with self._http.get(url, headers={"Authorization": f"Bearer {token}"}) as resp: if resp.status != 200: return None payload = await resp.json() if not payload.get("ok"): return None user = payload.get("user") if not isinstance(user, dict): return None phone = str(user.get("phone", "")) uid = int(user.get("id", 0)) if not phone or uid <= 0: return None return {"id": uid, "phone": phone} except (ClientError, asyncio.TimeoutError, ValueError) as e: logger.warning("web auth service request failed: {}", e) return None async def _require_auth(self, request: web.Request) -> tuple[dict[str, Any] | None, web.Response | None]: if self.config.auth_enabled: token = self._extract_token(request) if not token: return None, self._unauthorized() user = await self._verify_external_user(token) if not user: return None, self._unauthorized() return user, None expected = (self.config.api_token or "").strip() if expected and self._extract_token(request) != expected: return None, self._unauthorized() return None, None def _unauthorized(self) -> web.Response: return web.json_response({"ok": False, "error": "unauthorized"}, status=401) def _is_chat_allowed(self, user_phone: str, chat_id: str) -> bool: return chat_id == user_phone or chat_id.startswith(f"{user_phone}:") async def _health(self, _request: web.Request) -> web.Response: return web.json_response({"ok": True, "channel": self.name}) async def _history_api(self, request: web.Request) -> web.Response: user, err = await self._require_auth(request) if err: return err chat_id = request.match_info["chat_id"] if user and not self._is_chat_allowed(user["phone"], chat_id): return web.json_response({"ok": False, "error": "forbidden"}, status=403) return web.json_response({"messages": list(self._history.get(chat_id, []))}) async def _on_message(self, request: web.Request) -> web.Response: user, err = await self._require_auth(request) if err: return err try: body = await request.json() except Exception: return web.json_response({"ok": False, "error": "invalid json"}, status=400) sender = str(body.get("sender", "web-user")) if user: sender = user["phone"] chat_id = str(body.get("chat_id", sender)) if user and not self._is_chat_allowed(user["phone"], chat_id): return web.json_response({"ok": False, "error": "forbidden"}, status=403) text = str(body.get("text", "")).strip() media = body.get("media", []) metadata = body.get("metadata", {}) if not text: return web.json_response({"ok": False, "error": "text is required"}, status=400) self._append_history( chat_id, { "type": "message", "role": "user", "chat_id": chat_id, "content": text, "at": datetime.now().strftime("%H:%M:%S"), }, ) await self._handle_message( sender_id=sender, chat_id=chat_id, content=text, media=media if isinstance(media, list) else [], metadata=metadata if isinstance(metadata, dict) else {}, ) return web.json_response({"ok": True}) async def _events(self, request: web.Request) -> web.StreamResponse: user, err = await self._require_auth(request) if err: return err chat_id = request.match_info["chat_id"] if user and not self._is_chat_allowed(user["phone"], chat_id): return web.json_response({"ok": False, "error": "forbidden"}, status=403) resp = web.StreamResponse( status=200, headers={ "Content-Type": "text/event-stream", "Cache-Control": "no-cache", "Connection": "keep-alive", }, ) await resp.prepare(request) queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() self._listeners[chat_id].add(queue) try: await self._write_sse( resp, { "type": "system", "role": "system", "chat_id": chat_id, "content": "stream connected", "at": datetime.now().strftime("%H:%M:%S"), }, ) for item in self._history.get(chat_id, []): await self._write_sse(resp, item) while self._running: try: item = await asyncio.wait_for( queue.get(), timeout=max(5, self.config.ping_interval_s) ) await self._write_sse(resp, item) except asyncio.TimeoutError: await resp.write(b": ping\\n\\n") await resp.drain() except asyncio.CancelledError: raise except ConnectionResetError: pass finally: self._listeners[chat_id].discard(queue) if not self._listeners[chat_id]: del self._listeners[chat_id] return resp async def _fanout(self, chat_id: str, payload: dict[str, Any]) -> None: listeners = self._listeners.get(chat_id) if not listeners: return for q in listeners: q.put_nowait(payload) def _append_history(self, chat_id: str, payload: dict[str, Any]) -> None: self._history[chat_id].append(payload) async def _write_sse(self, resp: web.StreamResponse, payload: dict[str, Any]) -> None: body = f"data: {json.dumps(payload, ensure_ascii=False)}\\n\\n".encode("utf-8") await resp.write(body) await resp.drain()