dz1-spatial-query/stac-fastapi-pgstac/stac_fastapi/pgstac/db.py
weixin_46229132 5bc6302955 first commit
2025-07-03 20:29:02 +08:00

154 lines
4.4 KiB
Python

"""Database connection handling."""
import json
from contextlib import asynccontextmanager, contextmanager
from typing import (
AsyncIterator,
Callable,
Dict,
Generator,
List,
Literal,
Optional,
Union,
)
import orjson
from asyncpg import Connection, Pool, exceptions
from buildpg import V, asyncpg, render
from fastapi import FastAPI, HTTPException, Request
from stac_fastapi.types.errors import (
ConflictError,
DatabaseError,
ForeignKeyError,
NotFoundError,
)
from stac_fastapi.pgstac.config import PostgresSettings
async def con_init(conn):
"""Use orjson for json returns."""
await conn.set_type_codec(
"json",
encoder=orjson.dumps,
decoder=orjson.loads,
schema="pg_catalog",
)
await conn.set_type_codec(
"jsonb",
encoder=orjson.dumps,
decoder=orjson.loads,
schema="pg_catalog",
)
ConnectionGetter = Callable[[Request, Literal["r", "w"]], AsyncIterator[Connection]]
async def _create_pool(settings: PostgresSettings) -> Pool:
"""Create a connection pool."""
return await asyncpg.create_pool(
settings.connection_string,
min_size=settings.db_min_conn_size,
max_size=settings.db_max_conn_size,
max_queries=settings.db_max_queries,
max_inactive_connection_lifetime=settings.db_max_inactive_conn_lifetime,
init=con_init,
server_settings=settings.server_settings.model_dump(),
)
async def connect_to_db(
app: FastAPI,
get_conn: Optional[ConnectionGetter] = None,
postgres_settings: Optional[PostgresSettings] = None,
add_write_connection_pool: bool = False,
write_postgres_settings: Optional[PostgresSettings] = None,
) -> None:
"""Create connection pools & connection retriever on application."""
if not postgres_settings:
postgres_settings = PostgresSettings()
app.state.readpool = await _create_pool(postgres_settings)
if add_write_connection_pool:
if not write_postgres_settings:
write_postgres_settings = postgres_settings
app.state.writepool = await _create_pool(write_postgres_settings)
app.state.get_connection = get_conn if get_conn else get_connection
async def close_db_connection(app: FastAPI) -> None:
"""Close connection."""
await app.state.readpool.close()
if pool := getattr(app.state, "writepool", None):
await pool.close()
@asynccontextmanager
async def get_connection(
request: Request,
readwrite: Literal["r", "w"] = "r",
) -> AsyncIterator[Connection]:
"""Retrieve connection from database conection pool."""
pool = request.app.state.readpool
if readwrite == "w":
pool = getattr(request.app.state, "writepool", None)
if not pool:
raise HTTPException(
status_code=500,
detail="Could not find connection pool for write operations",
)
with translate_pgstac_errors():
async with pool.acquire() as conn:
yield conn
async def dbfunc(conn: Connection, func: str, arg: Union[str, Dict, List]):
"""Wrap PLPGSQL Functions.
Keyword arguments:
pool -- the asyncpg pool to use to connect to the database
func -- the name of the PostgreSQL function to call
arg -- the argument to the PostgreSQL function as either a string
or a dict that will be converted into jsonb
"""
with translate_pgstac_errors():
if isinstance(arg, str):
q, p = render(
"""
SELECT * FROM :func(:item::text);
""",
func=V(func),
item=arg,
)
return await conn.fetchval(q, *p)
else:
q, p = render(
"""
SELECT * FROM :func(:item::text::jsonb);
""",
func=V(func),
item=json.dumps(arg),
)
return await conn.fetchval(q, *p)
@contextmanager
def translate_pgstac_errors() -> Generator[None, None, None]:
"""Context manager that translates pgstac errors into FastAPI errors."""
try:
yield
except exceptions.UniqueViolationError as e:
raise ConflictError from e
except exceptions.NoDataFoundError as e:
raise NotFoundError from e
except exceptions.NotNullViolationError as e:
raise DatabaseError from e
except exceptions.ForeignKeyViolationError as e:
raise ForeignKeyError from e