154 lines
4.4 KiB
Python
154 lines
4.4 KiB
Python
"""Database connection handling."""
|
|
|
|
import json
|
|
from contextlib import asynccontextmanager, contextmanager
|
|
from typing import (
|
|
AsyncIterator,
|
|
Callable,
|
|
Dict,
|
|
Generator,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Union,
|
|
)
|
|
|
|
import orjson
|
|
from asyncpg import Connection, Pool, exceptions
|
|
from buildpg import V, asyncpg, render
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from stac_fastapi.types.errors import (
|
|
ConflictError,
|
|
DatabaseError,
|
|
ForeignKeyError,
|
|
NotFoundError,
|
|
)
|
|
|
|
from stac_fastapi.pgstac.config import PostgresSettings
|
|
|
|
|
|
async def con_init(conn):
|
|
"""Use orjson for json returns."""
|
|
await conn.set_type_codec(
|
|
"json",
|
|
encoder=orjson.dumps,
|
|
decoder=orjson.loads,
|
|
schema="pg_catalog",
|
|
)
|
|
await conn.set_type_codec(
|
|
"jsonb",
|
|
encoder=orjson.dumps,
|
|
decoder=orjson.loads,
|
|
schema="pg_catalog",
|
|
)
|
|
|
|
|
|
ConnectionGetter = Callable[[Request, Literal["r", "w"]], AsyncIterator[Connection]]
|
|
|
|
|
|
async def _create_pool(settings: PostgresSettings) -> Pool:
|
|
"""Create a connection pool."""
|
|
return await asyncpg.create_pool(
|
|
settings.connection_string,
|
|
min_size=settings.db_min_conn_size,
|
|
max_size=settings.db_max_conn_size,
|
|
max_queries=settings.db_max_queries,
|
|
max_inactive_connection_lifetime=settings.db_max_inactive_conn_lifetime,
|
|
init=con_init,
|
|
server_settings=settings.server_settings.model_dump(),
|
|
)
|
|
|
|
|
|
async def connect_to_db(
|
|
app: FastAPI,
|
|
get_conn: Optional[ConnectionGetter] = None,
|
|
postgres_settings: Optional[PostgresSettings] = None,
|
|
add_write_connection_pool: bool = False,
|
|
write_postgres_settings: Optional[PostgresSettings] = None,
|
|
) -> None:
|
|
"""Create connection pools & connection retriever on application."""
|
|
if not postgres_settings:
|
|
postgres_settings = PostgresSettings()
|
|
|
|
app.state.readpool = await _create_pool(postgres_settings)
|
|
|
|
if add_write_connection_pool:
|
|
if not write_postgres_settings:
|
|
write_postgres_settings = postgres_settings
|
|
|
|
app.state.writepool = await _create_pool(write_postgres_settings)
|
|
|
|
app.state.get_connection = get_conn if get_conn else get_connection
|
|
|
|
|
|
async def close_db_connection(app: FastAPI) -> None:
|
|
"""Close connection."""
|
|
await app.state.readpool.close()
|
|
if pool := getattr(app.state, "writepool", None):
|
|
await pool.close()
|
|
|
|
|
|
@asynccontextmanager
|
|
async def get_connection(
|
|
request: Request,
|
|
readwrite: Literal["r", "w"] = "r",
|
|
) -> AsyncIterator[Connection]:
|
|
"""Retrieve connection from database conection pool."""
|
|
pool = request.app.state.readpool
|
|
if readwrite == "w":
|
|
pool = getattr(request.app.state, "writepool", None)
|
|
if not pool:
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail="Could not find connection pool for write operations",
|
|
)
|
|
|
|
with translate_pgstac_errors():
|
|
async with pool.acquire() as conn:
|
|
yield conn
|
|
|
|
|
|
async def dbfunc(conn: Connection, func: str, arg: Union[str, Dict, List]):
|
|
"""Wrap PLPGSQL Functions.
|
|
|
|
Keyword arguments:
|
|
pool -- the asyncpg pool to use to connect to the database
|
|
func -- the name of the PostgreSQL function to call
|
|
arg -- the argument to the PostgreSQL function as either a string
|
|
or a dict that will be converted into jsonb
|
|
"""
|
|
with translate_pgstac_errors():
|
|
if isinstance(arg, str):
|
|
q, p = render(
|
|
"""
|
|
SELECT * FROM :func(:item::text);
|
|
""",
|
|
func=V(func),
|
|
item=arg,
|
|
)
|
|
return await conn.fetchval(q, *p)
|
|
else:
|
|
q, p = render(
|
|
"""
|
|
SELECT * FROM :func(:item::text::jsonb);
|
|
""",
|
|
func=V(func),
|
|
item=json.dumps(arg),
|
|
)
|
|
return await conn.fetchval(q, *p)
|
|
|
|
|
|
@contextmanager
|
|
def translate_pgstac_errors() -> Generator[None, None, None]:
|
|
"""Context manager that translates pgstac errors into FastAPI errors."""
|
|
try:
|
|
yield
|
|
except exceptions.UniqueViolationError as e:
|
|
raise ConflictError from e
|
|
except exceptions.NoDataFoundError as e:
|
|
raise NotFoundError from e
|
|
except exceptions.NotNullViolationError as e:
|
|
raise DatabaseError from e
|
|
except exceptions.ForeignKeyViolationError as e:
|
|
raise ForeignKeyError from e
|