first commit
This commit is contained in:
3
nanobot-channel-web/nanobot_channel_web/__init__.py
Normal file
3
nanobot-channel-web/nanobot_channel_web/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .channel import WebChannel
|
||||
|
||||
__all__ = ["WebChannel"]
|
||||
327
nanobot-channel-web/nanobot_channel_web/channel.py
Normal file
327
nanobot-channel-web/nanobot_channel_web/channel.py
Normal file
@@ -0,0 +1,327 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import defaultdict, deque
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import ClientError, ClientSession, ClientTimeout, web
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
|
||||
|
||||
class WebChannelConfig(BaseModel):
|
||||
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
|
||||
|
||||
enabled: bool = False
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 9000
|
||||
allow_from: list[str] = Field(default_factory=lambda: ["*"])
|
||||
cors_origin: str = "*"
|
||||
history_size: int = 200
|
||||
ping_interval_s: int = 15
|
||||
api_token: str = ""
|
||||
auth_enabled: bool = False
|
||||
auth_service_url: str = ""
|
||||
auth_service_timeout_s: int = 8
|
||||
|
||||
|
||||
class WebChannel(BaseChannel):
|
||||
name = "web"
|
||||
display_name = "Web"
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return WebChannelConfig().model_dump(by_alias=True)
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = WebChannelConfig.model_validate(config)
|
||||
elif not isinstance(config, WebChannelConfig):
|
||||
# Extra channel sections in nanobot config can arrive as generic objects.
|
||||
config = WebChannelConfig.model_validate(getattr(config, "model_dump", lambda: {})())
|
||||
super().__init__(config, bus)
|
||||
self.config: WebChannelConfig = config
|
||||
self._app: web.Application | None = None
|
||||
self._runner: web.AppRunner | None = None
|
||||
self._site: web.TCPSite | None = None
|
||||
self._http: ClientSession | None = None
|
||||
self._listeners: dict[str, set[asyncio.Queue[dict[str, Any]]]] = defaultdict(set)
|
||||
self._history: dict[str, deque[dict[str, Any]]] = defaultdict(
|
||||
lambda: deque(maxlen=self.config.history_size)
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
self._running = True
|
||||
|
||||
if self.config.auth_enabled:
|
||||
if not self.config.auth_service_url.strip():
|
||||
raise RuntimeError("authEnabled=true requires channels.web.authServiceUrl")
|
||||
self._http = ClientSession(
|
||||
timeout=ClientTimeout(total=max(1, self.config.auth_service_timeout_s)),
|
||||
)
|
||||
logger.info("web channel external auth enabled -> {}", self.config.auth_service_url)
|
||||
|
||||
self._app = web.Application(middlewares=[self._cors_middleware])
|
||||
self._app.router.add_get("/health", self._health)
|
||||
self._app.router.add_post("/message", self._on_message)
|
||||
self._app.router.add_get("/events/{chat_id}", self._events)
|
||||
self._app.router.add_get("/history/{chat_id}", self._history_api)
|
||||
|
||||
self._runner = web.AppRunner(self._app)
|
||||
await self._runner.setup()
|
||||
self._site = web.TCPSite(self._runner, self.config.host, self.config.port)
|
||||
await self._site.start()
|
||||
|
||||
logger.info("web channel listening on http://{}:{}", self.config.host, self.config.port)
|
||||
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
if self._http:
|
||||
await self._http.close()
|
||||
self._http = None
|
||||
if self._runner:
|
||||
await self._runner.cleanup()
|
||||
self._runner = None
|
||||
self._site = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
payload = {
|
||||
"type": "progress" if msg.metadata.get("_progress") else "message",
|
||||
"role": "assistant",
|
||||
"chat_id": msg.chat_id,
|
||||
"content": msg.content,
|
||||
"at": datetime.now().strftime("%H:%M:%S"),
|
||||
"metadata": msg.metadata,
|
||||
}
|
||||
self._append_history(msg.chat_id, payload)
|
||||
await self._fanout(msg.chat_id, payload)
|
||||
|
||||
def _allowed_cors_origins(self) -> list[str]:
|
||||
configured = [o.strip() for o in self.config.cors_origin.split(",") if o.strip()]
|
||||
if not configured:
|
||||
configured = [str(o).strip() for o in self.config.allow_from if str(o).strip()]
|
||||
return configured or ["*"]
|
||||
|
||||
def _resolve_cors_origin(self, request: web.Request) -> str:
|
||||
allowed = self._allowed_cors_origins()
|
||||
if "*" in allowed:
|
||||
return "*"
|
||||
req_origin = request.headers.get("Origin", "").strip()
|
||||
if req_origin and req_origin in allowed:
|
||||
return req_origin
|
||||
return allowed[0]
|
||||
|
||||
def _apply_cors_headers(self, request: web.Request, resp: web.StreamResponse) -> web.StreamResponse:
|
||||
origin = self._resolve_cors_origin(request)
|
||||
resp.headers["Access-Control-Allow-Origin"] = origin
|
||||
if origin != "*":
|
||||
resp.headers["Vary"] = "Origin"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
resp.headers["Access-Control-Allow-Methods"] = "GET,POST,OPTIONS"
|
||||
req_headers = request.headers.get("Access-Control-Request-Headers", "").strip()
|
||||
resp.headers["Access-Control-Allow-Headers"] = req_headers or "Content-Type, Authorization"
|
||||
resp.headers["Access-Control-Max-Age"] = "86400"
|
||||
return resp
|
||||
|
||||
@web.middleware
|
||||
async def _cors_middleware(self, request: web.Request, handler):
|
||||
try:
|
||||
if request.method == "OPTIONS":
|
||||
resp: web.StreamResponse = web.Response(status=204)
|
||||
else:
|
||||
resp = await handler(request)
|
||||
except web.HTTPException as e:
|
||||
resp = web.json_response({"ok": False, "error": e.reason or "http error"}, status=e.status)
|
||||
except Exception as e:
|
||||
logger.exception("web request failed: {}", e)
|
||||
resp = web.json_response({"ok": False, "error": "internal server error"}, status=500)
|
||||
|
||||
return self._apply_cors_headers(request, resp)
|
||||
|
||||
def _extract_token(self, request: web.Request) -> str:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if auth.startswith("Bearer "):
|
||||
return auth[7:].strip()
|
||||
# Native EventSource cannot set Authorization headers, so allow query token.
|
||||
return request.query.get("token", "").strip()
|
||||
|
||||
async def _verify_external_user(self, token: str) -> dict[str, Any] | None:
|
||||
if not self._http:
|
||||
return None
|
||||
url = f"{self.config.auth_service_url.rstrip('/')}/auth/me"
|
||||
try:
|
||||
async with self._http.get(url, headers={"Authorization": f"Bearer {token}"}) as resp:
|
||||
if resp.status != 200:
|
||||
return None
|
||||
payload = await resp.json()
|
||||
if not payload.get("ok"):
|
||||
return None
|
||||
user = payload.get("user")
|
||||
if not isinstance(user, dict):
|
||||
return None
|
||||
phone = str(user.get("phone", ""))
|
||||
uid = int(user.get("id", 0))
|
||||
if not phone or uid <= 0:
|
||||
return None
|
||||
return {"id": uid, "phone": phone}
|
||||
except (ClientError, asyncio.TimeoutError, ValueError) as e:
|
||||
logger.warning("web auth service request failed: {}", e)
|
||||
return None
|
||||
|
||||
async def _require_auth(self, request: web.Request) -> tuple[dict[str, Any] | None, web.Response | None]:
|
||||
if self.config.auth_enabled:
|
||||
token = self._extract_token(request)
|
||||
if not token:
|
||||
return None, self._unauthorized()
|
||||
user = await self._verify_external_user(token)
|
||||
if not user:
|
||||
return None, self._unauthorized()
|
||||
return user, None
|
||||
|
||||
expected = (self.config.api_token or "").strip()
|
||||
if expected and self._extract_token(request) != expected:
|
||||
return None, self._unauthorized()
|
||||
return None, None
|
||||
|
||||
def _unauthorized(self) -> web.Response:
|
||||
return web.json_response({"ok": False, "error": "unauthorized"}, status=401)
|
||||
|
||||
def _is_chat_allowed(self, user_phone: str, chat_id: str) -> bool:
|
||||
return chat_id == user_phone or chat_id.startswith(f"{user_phone}:")
|
||||
|
||||
async def _health(self, _request: web.Request) -> web.Response:
|
||||
return web.json_response({"ok": True, "channel": self.name})
|
||||
|
||||
async def _history_api(self, request: web.Request) -> web.Response:
|
||||
user, err = await self._require_auth(request)
|
||||
if err:
|
||||
return err
|
||||
chat_id = request.match_info["chat_id"]
|
||||
if user and not self._is_chat_allowed(user["phone"], chat_id):
|
||||
return web.json_response({"ok": False, "error": "forbidden"}, status=403)
|
||||
return web.json_response({"messages": list(self._history.get(chat_id, []))})
|
||||
|
||||
async def _on_message(self, request: web.Request) -> web.Response:
|
||||
user, err = await self._require_auth(request)
|
||||
if err:
|
||||
return err
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return web.json_response({"ok": False, "error": "invalid json"}, status=400)
|
||||
|
||||
sender = str(body.get("sender", "web-user"))
|
||||
if user:
|
||||
sender = user["phone"]
|
||||
chat_id = str(body.get("chat_id", sender))
|
||||
if user and not self._is_chat_allowed(user["phone"], chat_id):
|
||||
return web.json_response({"ok": False, "error": "forbidden"}, status=403)
|
||||
text = str(body.get("text", "")).strip()
|
||||
media = body.get("media", [])
|
||||
metadata = body.get("metadata", {})
|
||||
|
||||
if not text:
|
||||
return web.json_response({"ok": False, "error": "text is required"}, status=400)
|
||||
|
||||
self._append_history(
|
||||
chat_id,
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"chat_id": chat_id,
|
||||
"content": text,
|
||||
"at": datetime.now().strftime("%H:%M:%S"),
|
||||
},
|
||||
)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
media=media if isinstance(media, list) else [],
|
||||
metadata=metadata if isinstance(metadata, dict) else {},
|
||||
)
|
||||
|
||||
return web.json_response({"ok": True})
|
||||
|
||||
async def _events(self, request: web.Request) -> web.StreamResponse:
|
||||
user, err = await self._require_auth(request)
|
||||
if err:
|
||||
return err
|
||||
chat_id = request.match_info["chat_id"]
|
||||
if user and not self._is_chat_allowed(user["phone"], chat_id):
|
||||
return web.json_response({"ok": False, "error": "forbidden"}, status=403)
|
||||
|
||||
resp = web.StreamResponse(
|
||||
status=200,
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
await resp.prepare(request)
|
||||
|
||||
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
self._listeners[chat_id].add(queue)
|
||||
|
||||
try:
|
||||
await self._write_sse(
|
||||
resp,
|
||||
{
|
||||
"type": "system",
|
||||
"role": "system",
|
||||
"chat_id": chat_id,
|
||||
"content": "stream connected",
|
||||
"at": datetime.now().strftime("%H:%M:%S"),
|
||||
},
|
||||
)
|
||||
|
||||
for item in self._history.get(chat_id, []):
|
||||
await self._write_sse(resp, item)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
item = await asyncio.wait_for(
|
||||
queue.get(), timeout=max(5, self.config.ping_interval_s)
|
||||
)
|
||||
await self._write_sse(resp, item)
|
||||
except asyncio.TimeoutError:
|
||||
await resp.write(b": ping\\n\\n")
|
||||
await resp.drain()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except ConnectionResetError:
|
||||
pass
|
||||
finally:
|
||||
self._listeners[chat_id].discard(queue)
|
||||
if not self._listeners[chat_id]:
|
||||
del self._listeners[chat_id]
|
||||
|
||||
return resp
|
||||
|
||||
async def _fanout(self, chat_id: str, payload: dict[str, Any]) -> None:
|
||||
listeners = self._listeners.get(chat_id)
|
||||
if not listeners:
|
||||
return
|
||||
for q in listeners:
|
||||
q.put_nowait(payload)
|
||||
|
||||
def _append_history(self, chat_id: str, payload: dict[str, Any]) -> None:
|
||||
self._history[chat_id].append(payload)
|
||||
|
||||
async def _write_sse(self, resp: web.StreamResponse, payload: dict[str, Any]) -> None:
|
||||
body = f"data: {json.dumps(payload, ensure_ascii=False)}\\n\\n".encode("utf-8")
|
||||
await resp.write(body)
|
||||
await resp.drain()
|
||||
Reference in New Issue
Block a user