first commit
This commit is contained in:
421
app/main.py
Normal file
421
app/main.py
Normal file
@@ -0,0 +1,421 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
PHONE_PATTERN = re.compile(r"^\d{6,20}$")
|
||||
|
||||
AUTH_DB_PATH = os.getenv("AUTH_DB_PATH", "auth_service.sqlite3")
|
||||
AUTH_JWT_SECRET = os.getenv("AUTH_JWT_SECRET", "change-this-secret")
|
||||
AUTH_TOKEN_TTL_HOURS = int(os.getenv("AUTH_TOKEN_TTL_HOURS", "24"))
|
||||
AUTH_CORS_ORIGINS = os.getenv("AUTH_CORS_ORIGINS", "*")
|
||||
AUTH_INVITE_CODES = {
|
||||
item.strip()
|
||||
for item in os.getenv("AUTH_INVITE_CODES", "").split(",")
|
||||
if item.strip()
|
||||
}
|
||||
AUTH_ADMIN_KEY = os.getenv("AUTH_ADMIN_KEY", "")
|
||||
|
||||
|
||||
def _ensure_db(path_str: str) -> Path:
|
||||
path = Path(path_str).expanduser()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with sqlite3.connect(path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
phone TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
salt TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
last_login_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS registration_requests (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
phone TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
salt TEXT NOT NULL,
|
||||
invite_code TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
reviewed_at TEXT,
|
||||
review_note TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
return path
|
||||
|
||||
|
||||
DB_PATH = _ensure_db(AUTH_DB_PATH)
|
||||
|
||||
|
||||
class AuthPayload(BaseModel):
|
||||
phone: str
|
||||
password: str
|
||||
|
||||
|
||||
class RegisterPayload(AuthPayload):
|
||||
invite_code: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
ok: bool
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
user: dict
|
||||
|
||||
|
||||
class RegisterPendingResponse(BaseModel):
|
||||
ok: bool
|
||||
status: str
|
||||
request_id: int
|
||||
message: str
|
||||
|
||||
|
||||
def _connect() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def _normalize_phone(phone: str) -> str:
|
||||
p = str(phone).strip()
|
||||
# Accept common user input formats in test environments, then normalize.
|
||||
p = p.replace(" ", "").replace("-", "")
|
||||
if p.startswith("+86"):
|
||||
p = p[3:]
|
||||
if not PHONE_PATTERN.match(p):
|
||||
raise HTTPException(status_code=400, detail="phone must be 6-20 digits")
|
||||
return p
|
||||
|
||||
|
||||
def _validate_password(password: str) -> None:
|
||||
if len(password) < 6:
|
||||
raise HTTPException(status_code=400, detail="password must be at least 6 characters")
|
||||
|
||||
|
||||
def _hash_password(password: str, salt: bytes) -> str:
|
||||
digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 200_000)
|
||||
return base64.b64encode(digest).decode("ascii")
|
||||
|
||||
|
||||
def _create_access_token(user_id: int, phone: str) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
payload = {
|
||||
"uid": user_id,
|
||||
"sub": phone,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int((now + timedelta(hours=max(1, AUTH_TOKEN_TTL_HOURS))).timestamp()),
|
||||
}
|
||||
return jwt.encode(payload, AUTH_JWT_SECRET, algorithm="HS256")
|
||||
|
||||
|
||||
def _token_response(user_id: int, phone: str) -> TokenResponse:
|
||||
token = _create_access_token(user_id, phone)
|
||||
return TokenResponse(
|
||||
ok=True,
|
||||
access_token=token,
|
||||
expires_in=max(1, AUTH_TOKEN_TTL_HOURS) * 3600,
|
||||
user={"id": user_id, "phone": phone},
|
||||
)
|
||||
|
||||
|
||||
def _extract_bearer(authorization: str | None) -> str:
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="missing authorization")
|
||||
prefix = "Bearer "
|
||||
if not authorization.startswith(prefix):
|
||||
raise HTTPException(status_code=401, detail="invalid authorization")
|
||||
token = authorization[len(prefix):].strip()
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="missing token")
|
||||
return token
|
||||
|
||||
|
||||
def _current_user(authorization: str | None = Header(default=None)) -> dict:
|
||||
token = _extract_bearer(authorization)
|
||||
try:
|
||||
payload = jwt.decode(token, AUTH_JWT_SECRET, algorithms=["HS256"])
|
||||
except jwt.PyJWTError as e:
|
||||
raise HTTPException(status_code=401, detail="invalid token") from e
|
||||
|
||||
uid = int(payload.get("uid", 0))
|
||||
phone = str(payload.get("sub", ""))
|
||||
if uid <= 0 or not phone:
|
||||
raise HTTPException(status_code=401, detail="invalid token payload")
|
||||
return {"id": uid, "phone": phone}
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _require_invite_code(invite_code: str) -> str:
|
||||
code = str(invite_code).strip()
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="invite code is required")
|
||||
if AUTH_INVITE_CODES and code not in AUTH_INVITE_CODES:
|
||||
raise HTTPException(status_code=400, detail="invalid invite code")
|
||||
return code
|
||||
|
||||
|
||||
def _admin_ok(x_admin_key: str | None = Header(default=None)) -> None:
|
||||
if not AUTH_ADMIN_KEY:
|
||||
raise HTTPException(status_code=500, detail="AUTH_ADMIN_KEY is not configured")
|
||||
if x_admin_key != AUTH_ADMIN_KEY:
|
||||
raise HTTPException(status_code=401, detail="invalid admin key")
|
||||
|
||||
|
||||
def _approve_request_with_conn(conn: sqlite3.Connection, request_id: int, note: str = "approved") -> None:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT id, phone, password_hash, salt, status
|
||||
FROM registration_requests
|
||||
WHERE id = ?
|
||||
""",
|
||||
(request_id,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="request not found")
|
||||
if row["status"] != "pending":
|
||||
raise HTTPException(status_code=400, detail=f"request already {row['status']}")
|
||||
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO users(phone, password_hash, salt, created_at)
|
||||
VALUES(?, ?, ?, ?)
|
||||
""",
|
||||
(row["phone"], row["password_hash"], row["salt"], _now_iso()),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE registration_requests
|
||||
SET status = 'approved', reviewed_at = ?, review_note = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(_now_iso(), note, request_id),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError as e:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE registration_requests
|
||||
SET status = 'rejected', reviewed_at = ?, review_note = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(_now_iso(), "phone already exists", request_id),
|
||||
)
|
||||
conn.commit()
|
||||
raise HTTPException(status_code=400, detail="phone already exists") from e
|
||||
|
||||
|
||||
def _reject_request_with_conn(conn: sqlite3.Connection, request_id: int, note: str) -> None:
|
||||
row = conn.execute(
|
||||
"SELECT id, status FROM registration_requests WHERE id = ?",
|
||||
(request_id,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="request not found")
|
||||
if row["status"] != "pending":
|
||||
raise HTTPException(status_code=400, detail=f"request already {row['status']}")
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE registration_requests
|
||||
SET status = 'rejected', reviewed_at = ?, review_note = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(_now_iso(), note or "rejected", request_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
app = FastAPI(title="nanobot-auth-service", version="0.1.0")
|
||||
|
||||
origins = ["*"] if AUTH_CORS_ORIGINS.strip() == "*" else [o.strip() for o in AUTH_CORS_ORIGINS.split(",") if o.strip()]
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict:
|
||||
return {
|
||||
"ok": True,
|
||||
"service": "auth",
|
||||
"db": str(DB_PATH),
|
||||
"invite_code_check": bool(AUTH_INVITE_CODES),
|
||||
}
|
||||
|
||||
|
||||
@app.post("/auth/register", response_model=RegisterPendingResponse)
|
||||
def register(payload: RegisterPayload):
|
||||
phone = _normalize_phone(payload.phone)
|
||||
_validate_password(payload.password)
|
||||
invite_code = _require_invite_code(payload.invite_code)
|
||||
|
||||
salt = secrets.token_bytes(16)
|
||||
password_hash = _hash_password(payload.password, salt)
|
||||
now = _now_iso()
|
||||
|
||||
try:
|
||||
with _connect() as conn:
|
||||
existing = conn.execute(
|
||||
"SELECT id FROM users WHERE phone = ?",
|
||||
(phone,),
|
||||
).fetchone()
|
||||
if existing is not None:
|
||||
raise HTTPException(status_code=400, detail="phone already exists")
|
||||
|
||||
pending = conn.execute(
|
||||
"SELECT id FROM registration_requests WHERE phone = ? AND status = 'pending'",
|
||||
(phone,),
|
||||
).fetchone()
|
||||
if pending is not None:
|
||||
return RegisterPendingResponse(
|
||||
ok=True,
|
||||
status="pending",
|
||||
request_id=int(pending["id"]),
|
||||
message="pending review",
|
||||
)
|
||||
|
||||
cur = conn.execute(
|
||||
"""
|
||||
INSERT INTO registration_requests(phone, password_hash, salt, invite_code, status, created_at)
|
||||
VALUES(?, ?, ?, ?, 'pending', ?)
|
||||
""",
|
||||
(phone, password_hash, base64.b64encode(salt).decode("ascii"), invite_code, now),
|
||||
)
|
||||
request_id = int(cur.lastrowid)
|
||||
conn.commit()
|
||||
except HTTPException:
|
||||
raise
|
||||
except sqlite3.IntegrityError as e:
|
||||
raise HTTPException(status_code=400, detail="phone already exists") from e
|
||||
|
||||
return RegisterPendingResponse(
|
||||
ok=True,
|
||||
status="pending",
|
||||
request_id=request_id,
|
||||
message="pending review",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/auth/register/status/{request_id}")
|
||||
def register_status(request_id: int):
|
||||
with _connect() as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT id, phone, status, review_note
|
||||
FROM registration_requests
|
||||
WHERE id = ?
|
||||
""",
|
||||
(request_id,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="request not found")
|
||||
return {
|
||||
"ok": True,
|
||||
"request_id": int(row["id"]),
|
||||
"phone": row["phone"],
|
||||
"status": row["status"],
|
||||
"review_note": row["review_note"] or "",
|
||||
}
|
||||
|
||||
|
||||
@app.post("/auth/login", response_model=TokenResponse)
|
||||
def login(payload: AuthPayload):
|
||||
phone = _normalize_phone(payload.phone)
|
||||
|
||||
with _connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT id, phone, password_hash, salt FROM users WHERE phone = ?",
|
||||
(phone,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=401, detail="invalid credentials")
|
||||
|
||||
salt = base64.b64decode(row["salt"].encode("ascii"))
|
||||
expected = row["password_hash"]
|
||||
actual = _hash_password(payload.password, salt)
|
||||
if not hmac.compare_digest(expected, actual):
|
||||
raise HTTPException(status_code=401, detail="invalid credentials")
|
||||
|
||||
conn.execute(
|
||||
"UPDATE users SET last_login_at = ? WHERE id = ?",
|
||||
(datetime.now(timezone.utc).isoformat(), int(row["id"])),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return _token_response(int(row["id"]), str(row["phone"]))
|
||||
|
||||
|
||||
@app.get("/auth/me")
|
||||
def me(user: dict = Depends(_current_user)):
|
||||
return {"ok": True, "user": user}
|
||||
|
||||
|
||||
@app.get("/admin/requests")
|
||||
def admin_requests(status_filter: str = "pending", _admin: None = Depends(_admin_ok)):
|
||||
with _connect() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT id, phone, invite_code, status, created_at, reviewed_at, review_note
|
||||
FROM registration_requests
|
||||
WHERE status = ?
|
||||
ORDER BY id ASC
|
||||
""",
|
||||
(status_filter,),
|
||||
).fetchall()
|
||||
items: list[dict[str, Any]] = []
|
||||
for r in rows:
|
||||
items.append(
|
||||
{
|
||||
"id": int(r["id"]),
|
||||
"phone": r["phone"],
|
||||
"invite_code": r["invite_code"],
|
||||
"status": r["status"],
|
||||
"created_at": r["created_at"],
|
||||
"reviewed_at": r["reviewed_at"],
|
||||
"review_note": r["review_note"],
|
||||
}
|
||||
)
|
||||
return {"ok": True, "items": items}
|
||||
|
||||
|
||||
@app.post("/admin/requests/{request_id}/approve")
|
||||
def admin_approve(request_id: int, _admin: None = Depends(_admin_ok)):
|
||||
with _connect() as conn:
|
||||
_approve_request_with_conn(conn, request_id, note="approved by admin")
|
||||
return {"ok": True, "status": "approved", "request_id": request_id}
|
||||
|
||||
|
||||
@app.post("/admin/requests/{request_id}/reject")
|
||||
def admin_reject(request_id: int, note: str = "rejected by admin", _admin: None = Depends(_admin_ok)):
|
||||
with _connect() as conn:
|
||||
_reject_request_with_conn(conn, request_id, note=note)
|
||||
return {"ok": True, "status": "rejected", "request_id": request_id}
|
||||
Reference in New Issue
Block a user