464 lines
14 KiB
Python
464 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import os
|
|
import re
|
|
import secrets
|
|
import sqlite3
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import jwt
|
|
from fastapi import Depends, FastAPI, Header, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
|
|
# Accepted shape of a normalized phone number: 6-20 characters matched by \d.
# Note that for str patterns Python's \d matches any Unicode decimal digit.
PHONE_PATTERN = re.compile(r"^\d{6,20}$")


def _env_csv_set(name: str) -> set[str]:
    """Return the non-empty, whitespace-trimmed items of comma-separated env var *name*."""
    pieces = (part.strip() for part in os.getenv(name, "").split(","))
    return {piece for piece in pieces if piece}


# NOTE(review): AUTH_JWT_SECRET falls back to a publicly visible placeholder and
# AUTH_CORS_ORIGINS to "*"; any real deployment must override both.
AUTH_DB_PATH = os.getenv("AUTH_DB_PATH", "auth_service.sqlite3")
AUTH_JWT_SECRET = os.getenv("AUTH_JWT_SECRET", "change-this-secret")
AUTH_TOKEN_TTL_HOURS = int(os.getenv("AUTH_TOKEN_TTL_HOURS", "24"))
AUTH_CORS_ORIGINS = os.getenv("AUTH_CORS_ORIGINS", "*")
AUTH_INVITE_CODES = _env_csv_set("AUTH_INVITE_CODES")
AUTH_VERIFICATION_CODES = _env_csv_set("AUTH_VERIFICATION_CODES")
AUTH_ADMIN_KEY = os.getenv("AUTH_ADMIN_KEY", "")
|
|
|
|
|
|
def _parse_cors_origins(raw: str) -> tuple[bool, list[str]]:
|
|
value = str(raw or "").strip()
|
|
if not value:
|
|
return True, ["*"]
|
|
items = [o.strip() for o in value.split(",") if o.strip()]
|
|
if not items or "*" in items:
|
|
return True, ["*"]
|
|
return False, items
|
|
|
|
|
|
def _ensure_db(path_str: str) -> Path:
|
|
path = Path(path_str).expanduser()
|
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
with sqlite3.connect(path) as conn:
|
|
conn.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
phone TEXT NOT NULL UNIQUE,
|
|
password_hash TEXT NOT NULL,
|
|
salt TEXT NOT NULL,
|
|
created_at TEXT NOT NULL,
|
|
last_login_at TEXT
|
|
)
|
|
"""
|
|
)
|
|
conn.execute(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS registration_requests (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
phone TEXT NOT NULL UNIQUE,
|
|
password_hash TEXT NOT NULL,
|
|
salt TEXT NOT NULL,
|
|
invite_code TEXT NOT NULL,
|
|
status TEXT NOT NULL,
|
|
created_at TEXT NOT NULL,
|
|
reviewed_at TEXT,
|
|
review_note TEXT
|
|
)
|
|
"""
|
|
)
|
|
conn.commit()
|
|
return path
|
|
|
|
|
|
# Resolve the database path and create the schema once, when the module is
# imported, before any request handler opens a connection.
DB_PATH: Path = _ensure_db(path_str=AUTH_DB_PATH)
|
|
|
|
|
|
class AuthPayload(BaseModel):
    """Request body with a phone number and password (the /auth/login body)."""

    phone: str  # raw user input; handlers pass it through _normalize_phone
    password: str
|
|
|
|
|
|
class RegisterPayload(AuthPayload):
    """Registration body: credentials plus a verification or invite code.

    The register handler uses ``verification_code`` when it is non-empty and
    otherwise falls back to ``invite_code``.
    """

    invite_code: str = ""
    verification_code: str = ""
|
|
|
|
|
|
class TokenResponse(BaseModel):
    """Successful login response carrying a signed access token."""

    ok: bool
    access_token: str  # HS256 JWT produced by _create_access_token
    token_type: str = "bearer"
    expires_in: int  # token lifetime in seconds
    user: dict  # {"id": ..., "phone": ...} as built by _token_response
|
|
|
|
|
|
class RegisterPendingResponse(BaseModel):
    """Response to /auth/register: the registration awaits admin review."""

    ok: bool
    status: str  # the register handler always sets "pending"
    request_id: int  # registration_requests.id; poll /auth/register/status/{request_id}
    message: str
|
|
|
|
|
|
class SendCodePayload(BaseModel):
    """Request body for /auth/send-code."""

    phone: str
    # NOTE(review): sending a code does not need the password; consider dropping
    # this field. Whatever happens, it must never be logged.
    password: str
|
|
|
|
|
|
def _connect() -> sqlite3.Connection:
|
|
conn = sqlite3.connect(DB_PATH)
|
|
conn.row_factory = sqlite3.Row
|
|
return conn
|
|
|
|
|
|
def _normalize_phone(phone: str) -> str:
|
|
p = str(phone).strip()
|
|
# Accept common user input formats in test environments, then normalize.
|
|
p = p.replace(" ", "").replace("-", "")
|
|
if p.startswith("+86"):
|
|
p = p[3:]
|
|
if not PHONE_PATTERN.match(p):
|
|
raise HTTPException(status_code=400, detail="phone must be 6-20 digits")
|
|
return p
|
|
|
|
|
|
def _validate_password(password: str) -> None:
|
|
if len(password) < 6:
|
|
raise HTTPException(status_code=400, detail="password must be at least 6 characters")
|
|
|
|
|
|
def _hash_password(password: str, salt: bytes) -> str:
|
|
digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 200_000)
|
|
return base64.b64encode(digest).decode("ascii")
|
|
|
|
|
|
def _create_access_token(user_id: int, phone: str) -> str:
    """Sign an HS256 JWT for *user_id* / *phone* using ``AUTH_JWT_SECRET``.

    Claims: ``uid`` (user id), ``sub`` (phone), ``iat`` and ``exp`` as integer
    Unix timestamps. The lifetime is ``AUTH_TOKEN_TTL_HOURS``, clamped to at
    least one hour.
    """
    issued = datetime.now(timezone.utc)
    lifetime = timedelta(hours=max(1, AUTH_TOKEN_TTL_HOURS))
    claims = {
        "uid": user_id,
        "sub": phone,
        "iat": int(issued.timestamp()),
        "exp": int((issued + lifetime).timestamp()),
    }
    return jwt.encode(claims, AUTH_JWT_SECRET, algorithm="HS256")
|
|
|
|
|
|
def _token_response(user_id: int, phone: str) -> TokenResponse:
    """Build the login response: a fresh token plus its lifetime in seconds."""
    ttl_seconds = max(1, AUTH_TOKEN_TTL_HOURS) * 3600
    return TokenResponse(
        ok=True,
        access_token=_create_access_token(user_id, phone),
        expires_in=ttl_seconds,
        user={"id": user_id, "phone": phone},
    )
|
|
|
|
|
|
def _extract_bearer(authorization: str | None) -> str:
|
|
if not authorization:
|
|
raise HTTPException(status_code=401, detail="missing authorization")
|
|
prefix = "Bearer "
|
|
if not authorization.startswith(prefix):
|
|
raise HTTPException(status_code=401, detail="invalid authorization")
|
|
token = authorization[len(prefix):].strip()
|
|
if not token:
|
|
raise HTTPException(status_code=401, detail="missing token")
|
|
return token
|
|
|
|
|
|
def _current_user(authorization: str | None = Header(default=None)) -> dict:
    """FastAPI dependency: decode the bearer token into ``{"id", "phone"}``.

    Raises:
        HTTPException: 401 when the header or token is missing or invalid, or
            when the claims do not carry a positive integer ``uid`` and a
            non-empty ``sub``.
    """
    token = _extract_bearer(authorization)
    try:
        payload = jwt.decode(token, AUTH_JWT_SECRET, algorithms=["HS256"])
    except jwt.PyJWTError as e:
        raise HTTPException(status_code=401, detail="invalid token") from e

    # A malformed uid (e.g. "abc", null, Infinity) must be a 401, not a 500.
    try:
        uid = int(payload.get("uid", 0))
    except (TypeError, ValueError, OverflowError) as e:
        raise HTTPException(status_code=401, detail="invalid token payload") from e
    raw_sub = payload.get("sub")
    # Avoid turning a null subject into the truthy string "None".
    phone = "" if raw_sub is None else str(raw_sub)
    if uid <= 0 or not phone:
        raise HTTPException(status_code=401, detail="invalid token payload")
    return {"id": uid, "phone": phone}
|
|
|
|
|
|
def _now_iso() -> str:
|
|
return datetime.now(timezone.utc).isoformat()
|
|
|
|
|
|
def _allowed_verification_codes() -> set[str]:
    """Return the accepted verification codes.

    ``AUTH_VERIFICATION_CODES`` wins when non-empty; otherwise the invite codes
    double as verification codes. An empty result means "accept any code".
    """
    if AUTH_VERIFICATION_CODES:
        return AUTH_VERIFICATION_CODES
    return AUTH_INVITE_CODES
|
|
|
|
|
|
def _require_verification_code(code_value: str) -> str:
    """Return the trimmed code if it is present and allowed.

    Raises:
        HTTPException: 400 when the code is blank, or when an allow-list is
            configured and the code is not on it.
    """
    code = str(code_value).strip()
    if code == "":
        raise HTTPException(status_code=400, detail="verification code is required")
    accepted = _allowed_verification_codes()
    # An empty allow-list disables the membership check entirely.
    if not accepted or code in accepted:
        return code
    raise HTTPException(status_code=400, detail="invalid verification code")
|
|
|
|
|
|
def _admin_ok(x_admin_key: str | None = Header(default=None)) -> None:
    """FastAPI dependency guarding admin routes via the ``X-Admin-Key`` header.

    Raises:
        HTTPException: 500 if ``AUTH_ADMIN_KEY`` is unset, 401 if the supplied
            key does not match it.
    """
    if not AUTH_ADMIN_KEY:
        raise HTTPException(status_code=500, detail="AUTH_ADMIN_KEY is not configured")
    # Constant-time comparison so response timing does not leak the key.
    # Compare UTF-8 bytes: compare_digest rejects non-ASCII str arguments.
    supplied = (x_admin_key or "").encode("utf-8")
    if not hmac.compare_digest(supplied, AUTH_ADMIN_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="invalid admin key")
|
|
|
|
|
|
def _approve_request_with_conn(conn: sqlite3.Connection, request_id: int, note: str = "approved") -> None:
    """Approve a pending registration request and create the matching user.

    The request is claimed with a conditional ``UPDATE ... AND status =
    'pending'`` before the user row is inserted. Two reviewers acting on the
    same request therefore cannot both succeed, and a failed approval can
    never overwrite a decision another reviewer already committed.

    Args:
        conn: Open connection; this function commits or rolls back itself.
        request_id: Primary key of the ``registration_requests`` row.
        note: Review note stored when the approval succeeds.

    Raises:
        HTTPException: 404 if the request does not exist; 400 if it is no
            longer pending; 400 if a user with the same phone already exists
            (the request is then marked rejected).
    """
    row = conn.execute(
        """
        SELECT id, phone, password_hash, salt, status
        FROM registration_requests
        WHERE id = ?
        """,
        (request_id,),
    ).fetchone()
    if row is None:
        raise HTTPException(status_code=404, detail="request not found")
    if row["status"] != "pending":
        raise HTTPException(status_code=400, detail=f"request already {row['status']}")

    reviewed_at = _now_iso()
    try:
        # Claim first: only one transaction can move the row out of 'pending'.
        claimed = conn.execute(
            """
            UPDATE registration_requests
            SET status = 'approved', reviewed_at = ?, review_note = ?
            WHERE id = ? AND status = 'pending'
            """,
            (reviewed_at, note, request_id),
        )
        if claimed.rowcount != 1:
            conn.rollback()
            raise HTTPException(status_code=400, detail="request already reviewed")
        conn.execute(
            """
            INSERT INTO users(phone, password_hash, salt, created_at)
            VALUES(?, ?, ?, ?)
            """,
            (row["phone"], row["password_hash"], row["salt"], reviewed_at),
        )
        conn.commit()
    except sqlite3.IntegrityError as e:
        # Undo the claim, then record the rejection only if the request is
        # still pending (nobody else reviewed it in the meantime).
        conn.rollback()
        conn.execute(
            """
            UPDATE registration_requests
            SET status = 'rejected', reviewed_at = ?, review_note = ?
            WHERE id = ? AND status = 'pending'
            """,
            (_now_iso(), "phone already exists", request_id),
        )
        conn.commit()
        raise HTTPException(status_code=400, detail="phone already exists") from e
|
|
|
|
|
|
def _reject_request_with_conn(conn: sqlite3.Connection, request_id: int, note: str) -> None:
    """Mark a pending registration request as rejected.

    The UPDATE is conditional on ``status = 'pending'``. A request that a
    concurrent reviewer already approved (and whose user row already exists)
    is therefore never flipped to 'rejected'.

    Args:
        conn: Open connection; this function commits or rolls back itself.
        request_id: Primary key of the ``registration_requests`` row.
        note: Review note; an empty note is stored as ``"rejected"``.

    Raises:
        HTTPException: 404 if the request does not exist, 400 if it is no
            longer pending.
    """
    row = conn.execute(
        "SELECT id, status FROM registration_requests WHERE id = ?",
        (request_id,),
    ).fetchone()
    if row is None:
        raise HTTPException(status_code=404, detail="request not found")
    if row["status"] != "pending":
        raise HTTPException(status_code=400, detail=f"request already {row['status']}")

    cur = conn.execute(
        """
        UPDATE registration_requests
        SET status = 'rejected', reviewed_at = ?, review_note = ?
        WHERE id = ? AND status = 'pending'
        """,
        (_now_iso(), note or "rejected", request_id),
    )
    if cur.rowcount != 1:
        # Another reviewer changed the status between our SELECT and UPDATE.
        conn.rollback()
        raise HTTPException(status_code=400, detail="request already reviewed")
    conn.commit()
|
|
|
|
|
|
# Decide the CORS policy up front from AUTH_CORS_ORIGINS.
allow_all_origins, origins = _parse_cors_origins(AUTH_CORS_ORIGINS)

app = FastAPI(title="nanobot-auth-service", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # Credentials are only enabled for an explicit origin list, because
    # browsers refuse credentialed CORS responses with a wildcard origin.
    allow_credentials=not allow_all_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.get("/health")
def health() -> dict:
    """Liveness probe reporting the database path and whether codes are checked.

    NOTE(review): this unauthenticated endpoint also reveals the server-side
    database path; consider dropping the "db" field.
    """
    return dict(
        ok=True,
        service="auth",
        db=str(DB_PATH),
        verification_code_check=bool(_allowed_verification_codes()),
    )
|
|
|
|
|
|
@app.post("/auth/send-code")
def send_code(payload: SendCodePayload):
    """Acknowledge a front-end "send verification code" click.

    Only the normalized phone number is logged. ``payload.password`` is
    deliberately never printed: writing plaintext credentials to stdout puts
    them in server/container logs, readable by anyone with log access.

    Raises:
        HTTPException: 400 when the phone number is malformed.
    """
    phone = _normalize_phone(payload.phone)
    print("前端人员以点击验证码发送按钮!!!", flush=True)
    print(f"账号: {phone}", flush=True)
    return {"ok": True, "message": "验证码发送请求已接收"}
|
|
|
|
|
|
@app.post("/auth/register", response_model=RegisterPendingResponse)
def register(payload: RegisterPayload):
    """Submit a registration request for admin approval.

    Validates the phone, password and verification code, then stores a
    ``pending`` row in ``registration_requests``. Re-submitting while a
    request for the same phone is still pending returns that request.

    Returns:
        RegisterPendingResponse with ``status="pending"`` and the request id.

    Raises:
        HTTPException: 400 on validation failure, when the phone already
            belongs to a user, or when an earlier request for the phone was
            already reviewed (e.g. rejected).
    """
    phone = _normalize_phone(payload.phone)
    _validate_password(payload.password)
    # verification_code takes precedence; invite_code is the fallback.
    verification_code = _require_verification_code(
        payload.verification_code or payload.invite_code
    )
    print("前端人员已点击提交注册按钮!!!", flush=True)
    print(f"注册账号: {phone}", flush=True)
    print(f"提交验证码: {verification_code}", flush=True)

    salt = secrets.token_bytes(16)
    password_hash = _hash_password(payload.password, salt)
    now = _now_iso()

    try:
        with _connect() as conn:
            existing = conn.execute(
                "SELECT id FROM users WHERE phone = ?",
                (phone,),
            ).fetchone()
            if existing is not None:
                raise HTTPException(status_code=400, detail="phone already exists")

            # registration_requests.phone is UNIQUE, so at most one row matches.
            previous = conn.execute(
                "SELECT id, status FROM registration_requests WHERE phone = ?",
                (phone,),
            ).fetchone()
            if previous is not None:
                if previous["status"] == "pending":
                    return RegisterPendingResponse(
                        ok=True,
                        status="pending",
                        request_id=int(previous["id"]),
                        message="pending review",
                    )
                # A reviewed request still occupies the UNIQUE phone slot.
                # Report its real status instead of letting the INSERT fail
                # with a misleading "phone already exists".
                raise HTTPException(
                    status_code=400,
                    detail=f"registration request already {previous['status']}",
                )

            cur = conn.execute(
                """
                INSERT INTO registration_requests(phone, password_hash, salt, invite_code, status, created_at)
                VALUES(?, ?, ?, ?, 'pending', ?)
                """,
                (phone, password_hash, base64.b64encode(salt).decode("ascii"), verification_code, now),
            )
            request_id = int(cur.lastrowid)
            conn.commit()
    except sqlite3.IntegrityError as e:
        # Presumably a concurrent registration for the same phone won the
        # UNIQUE race between our SELECTs and the INSERT.
        raise HTTPException(status_code=400, detail="phone already exists") from e

    return RegisterPendingResponse(
        ok=True,
        status="pending",
        request_id=request_id,
        message="pending review",
    )
|
|
|
|
|
|
@app.get("/auth/register/status/{request_id}")
def register_status(request_id: int):
    """Report the review status of a registration request.

    NOTE(review): this endpoint is unauthenticated and request ids are
    AUTOINCREMENT integers, so a caller can enumerate ids and read every
    applicant's phone number. Consider masking ``phone`` or requiring a
    per-request secret.

    Raises:
        HTTPException: 404 if no request has this id.
    """
    query = """
        SELECT id, phone, status, review_note
        FROM registration_requests
        WHERE id = ?
    """
    with _connect() as conn:
        record = conn.execute(query, (request_id,)).fetchone()

    if record is None:
        raise HTTPException(status_code=404, detail="request not found")
    return {
        "ok": True,
        "request_id": int(record["id"]),
        "phone": record["phone"],
        "status": record["status"],
        "review_note": record["review_note"] or "",
    }
|
|
|
|
|
|
@app.post("/auth/login", response_model=TokenResponse)
def login(payload: AuthPayload):
    """Verify a phone/password pair, record the login time, and issue a token.

    Raises:
        HTTPException: 400 for a malformed phone number; 401 with the same
            "invalid credentials" detail for an unknown phone or a wrong
            password.
    """
    phone = _normalize_phone(payload.phone)

    with _connect() as conn:
        row = conn.execute(
            "SELECT id, phone, password_hash, salt FROM users WHERE phone = ?",
            (phone,),
        ).fetchone()
        if row is None:
            # Spend the same PBKDF2 work as a real check so the response time
            # does not reveal whether the phone number is registered.
            _hash_password(payload.password, bytes(16))
            raise HTTPException(status_code=401, detail="invalid credentials")

        salt = base64.b64decode(row["salt"].encode("ascii"))
        actual = _hash_password(payload.password, salt)
        if not hmac.compare_digest(row["password_hash"], actual):
            raise HTTPException(status_code=401, detail="invalid credentials")

        conn.execute(
            "UPDATE users SET last_login_at = ? WHERE id = ?",
            (_now_iso(), int(row["id"])),
        )
        conn.commit()

    return _token_response(int(row["id"]), str(row["phone"]))
|
|
|
|
|
|
@app.get("/auth/me")
def me(user: dict = Depends(_current_user)):
    """Echo the identity decoded from the caller's bearer token."""
    response = {"ok": True}
    response["user"] = user
    return response
|
|
|
|
|
|
@app.get("/admin/requests")
def admin_requests(status_filter: str = "pending", _admin: None = Depends(_admin_ok)):
    """List registration requests with the given status, oldest first (admin only).

    ``status_filter`` is bound as a SQL parameter, never interpolated.
    """
    columns = ("id", "phone", "invite_code", "status", "created_at", "reviewed_at", "review_note")

    def _as_item(record: sqlite3.Row) -> dict[str, Any]:
        # Keep the column order; only the id is coerced to int.
        item = {column: record[column] for column in columns}
        item["id"] = int(item["id"])
        return item

    with _connect() as conn:
        rows = conn.execute(
            """
            SELECT id, phone, invite_code, status, created_at, reviewed_at, review_note
            FROM registration_requests
            WHERE status = ?
            ORDER BY id ASC
            """,
            (status_filter,),
        ).fetchall()

    items: list[dict[str, Any]] = [_as_item(r) for r in rows]
    return {"ok": True, "items": items}
|
|
|
|
|
|
@app.post("/admin/requests/{request_id}/approve")
def admin_approve(request_id: int, _admin: None = Depends(_admin_ok)):
    """Approve a pending registration request (admin only) and create its user."""
    with _connect() as conn:
        _approve_request_with_conn(conn, request_id, note="approved by admin")
    result = {"ok": True, "status": "approved", "request_id": request_id}
    return result
|
|
|
|
|
|
@app.post("/admin/requests/{request_id}/reject")
def admin_reject(request_id: int, note: str = "rejected by admin", _admin: None = Depends(_admin_ok)):
    """Reject a pending registration request (admin only) with an optional note."""
    with _connect() as conn:
        _reject_request_with_conn(conn, request_id, note=note)
    result = {"ok": True, "status": "rejected", "request_id": request_id}
    return result
|