first commit

This commit is contained in:
龙澳
2026-03-23 14:32:15 +08:00
parent 7045866ef3
commit 9ce8858562
3 changed files with 26 additions and 13 deletions

View File

@@ -27,6 +27,11 @@ AUTH_INVITE_CODES = {
for item in os.getenv("AUTH_INVITE_CODES", "").split(",")
if item.strip()
}
AUTH_VERIFICATION_CODES = {
item.strip()
for item in os.getenv("AUTH_VERIFICATION_CODES", "").split(",")
if item.strip()
}
AUTH_ADMIN_KEY = os.getenv("AUTH_ADMIN_KEY", "")
@@ -74,7 +79,8 @@ class AuthPayload(BaseModel):
class RegisterPayload(AuthPayload):
invite_code: str
invite_code: str = ""
verification_code: str = ""
class TokenResponse(BaseModel):
@@ -170,12 +176,17 @@ def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _require_invite_code(invite_code: str) -> str:
code = str(invite_code).strip()
def _allowed_verification_codes() -> set[str]:
return AUTH_VERIFICATION_CODES or AUTH_INVITE_CODES
def _require_verification_code(code_value: str) -> str:
code = str(code_value).strip()
if not code:
raise HTTPException(status_code=400, detail="invite code is required")
if AUTH_INVITE_CODES and code not in AUTH_INVITE_CODES:
raise HTTPException(status_code=400, detail="invalid invite code")
raise HTTPException(status_code=400, detail="verification code is required")
allowed = _allowed_verification_codes()
if allowed and code not in allowed:
raise HTTPException(status_code=400, detail="invalid verification code")
return code
@@ -268,7 +279,7 @@ def health() -> dict:
"ok": True,
"service": "auth",
"db": str(DB_PATH),
"invite_code_check": bool(AUTH_INVITE_CODES),
"verification_code_check": bool(_allowed_verification_codes()),
}
@@ -276,7 +287,9 @@ def health() -> dict:
def register(payload: RegisterPayload):
phone = _normalize_phone(payload.phone)
_validate_password(payload.password)
invite_code = _require_invite_code(payload.invite_code)
verification_code = _require_verification_code(
payload.verification_code or payload.invite_code
)
salt = secrets.token_bytes(16)
password_hash = _hash_password(payload.password, salt)
@@ -308,7 +321,7 @@ def register(payload: RegisterPayload):
INSERT INTO registration_requests(phone, password_hash, salt, invite_code, status, created_at)
VALUES(?, ?, ?, ?, 'pending', ?)
""",
(phone, password_hash, base64.b64encode(salt).decode("ascii"), invite_code, now),
(phone, password_hash, base64.b64encode(salt).decode("ascii"), verification_code, now),
)
request_id = int(cur.lastrowid)
conn.commit()