# dz1-spatial-query/stac-fastapi-pgstac/stac_fastapi/pgstac/transactions.py
#
# 261 lines
# 8.8 KiB
# Python
# Raw Normal View History
#
# 2025-07-03 20:29:02 +08:00
"""transactions extension client."""
import logging
import re
from typing import List, Optional, Union
import attr
from buildpg import render
from fastapi import HTTPException, Request
from stac_fastapi.extensions.core.transaction import AsyncBaseTransactionsClient
from stac_fastapi.extensions.core.transaction.request import (
PartialCollection,
PartialItem,
PatchOperation,
)
from stac_fastapi.extensions.third_party.bulk_transactions import (
AsyncBaseBulkTransactionsClient,
BulkTransactionMethod,
Items,
)
from stac_fastapi.types import stac as stac_types
from stac_pydantic import Collection, Item, ItemCollection
from starlette.responses import JSONResponse, Response
from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.db import dbfunc
from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks
# Module logger: reuses the logger registered under the name "uvicorn" and
# sets its level to INFO when this module is imported.
logger = logging.getLogger("uvicorn")
logger.setLevel(logging.INFO)
class ClientValidateMixIn:
    """Request-body validation helpers shared by the transaction clients."""

    def _validate_id(self, id: str, settings: Settings):
        """Reject IDs that are not strings or that contain a forbidden character.

        Args:
            id: Identifier taken from the request body.
            settings: Settings object; only ``invalid_id_chars`` is read.

        Raises:
            HTTPException: 400 if ``id`` is not a string, or if it contains
                any character listed in ``settings.invalid_id_chars``.
        """
        if not isinstance(id, str):
            # re.search() raises TypeError for None / non-string input, which
            # would surface as a server error rather than a client error.
            raise HTTPException(
                status_code=400,
                detail=f"ID ({id}) must be a string",
            )

        invalid_chars = settings.invalid_id_chars
        if not invalid_chars:
            # An empty character class ("[]") is not a valid regular
            # expression; with nothing forbidden, every string ID passes.
            return

        id_regex = "[" + "".join(re.escape(char) for char in invalid_chars) + "]"
        if re.search(id_regex, id):
            raise HTTPException(
                status_code=400,
                detail=f"ID ({id}) cannot contain the following characters: {' '.join(invalid_chars)}",
            )

    def _validate_collection(self, request: Request, collection: stac_types.Collection):
        """Validate the ``id`` of a collection body.

        The settings are read from ``request.app.state.settings``.

        Raises:
            HTTPException: 400 if the collection ID is rejected by
                ``_validate_id``.
        """
        self._validate_id(collection["id"], request.app.state.settings)

    def _validate_item(
        self,
        request: Request,
        item: stac_types.Item,
        collection_id: str,
        expected_item_id: Optional[str] = None,
    ) -> None:
        """Validate an item body against the request path parameters.

        Checks are applied in this order: item ID characters, presence of a
        non-null ``geometry``, collection ID consistency, item ID consistency.

        Args:
            request: Incoming request; settings are read from
                ``request.app.state.settings``.
            item: Item body as a mapping.
            collection_id: Collection ID taken from the request path.
            expected_item_id: Item ID the body must carry, when the caller
                knows it; ``None`` skips this check.

        Raises:
            HTTPException: 400 on the first failed check.
        """
        body_collection_id = item.get("collection")
        body_item_id = item.get("id")

        self._validate_id(body_item_id, request.app.state.settings)

        if item.get("geometry") is None:
            raise HTTPException(
                status_code=400,
                detail=f"Missing or null `geometry` for Item ({body_item_id}). Geometry is required in pgstac.",
            )

        # A body without a collection is accepted; only an explicit,
        # conflicting collection is an error.
        if body_collection_id is not None and collection_id != body_collection_id:
            raise HTTPException(
                status_code=400,
                detail=f"Collection ID from path parameter ({collection_id}) does not match Collection ID from Item ({body_collection_id})",
            )

        if expected_item_id is not None and expected_item_id != body_item_id:
            raise HTTPException(
                status_code=400,
                detail=f"Item ID from path parameter ({expected_item_id}) does not match Item ID from Item ({body_item_id})",
            )
@attr.s
class TransactionsClient(AsyncBaseTransactionsClient, ClientValidateMixIn):
    """Transactions extension specific CRUD operations."""

    async def create_item(
        self,
        collection_id: str,
        item: Union[Item, ItemCollection],
        request: Request,
        **kwargs,
    ) -> Optional[Union[stac_types.Item, Response]]:
        """Create a single item or every feature of an item collection.

        Args:
            collection_id: Collection ID from the request path; it is written
                into every stored item.
            item: Request body, either an Item or an ItemCollection.
            request: Incoming request, used for settings, the database
                connection and link generation.

        Returns:
            An empty 201 ``Response`` for a ``FeatureCollection`` body, or the
            created item (with links from ``ItemLinks``) for a ``Feature``.

        Raises:
            HTTPException: 400 if validation fails or the body ``type`` is
                neither ``Feature`` nor ``FeatureCollection``.
        """
        body = item.model_dump(mode="json")
        body_type = body["type"]

        if body_type == "FeatureCollection":
            # Validate every feature before any of them is written.
            valid_items = []
            for feature in body["features"]:
                self._validate_item(request, feature, collection_id)
                feature["collection"] = collection_id
                valid_items.append(feature)

            async with request.app.state.get_connection(request, "w") as conn:
                await dbfunc(conn, "create_items", valid_items)

            return Response(status_code=201)

        if body_type == "Feature":
            self._validate_item(request, body, collection_id)
            body["collection"] = collection_id

            async with request.app.state.get_connection(request, "w") as conn:
                await dbfunc(conn, "create_item", body)

            body["links"] = await ItemLinks(
                collection_id=collection_id,
                item_id=body["id"],
                request=request,
            ).get_links(extra_links=body.get("links"))

            return stac_types.Item(**body)

        raise HTTPException(
            status_code=400,
            detail=f"Item body type must be 'Feature' or 'FeatureCollection', not {body_type}",
        )

    async def update_item(
        self,
        request: Request,
        collection_id: str,
        item_id: str,
        item: Item,
        **kwargs,
    ) -> Optional[Union[stac_types.Item, Response]]:
        """Update an item.

        The body is validated against both path parameters (collection ID and
        item ID) before ``update_item`` is invoked through ``dbfunc``.

        Returns:
            The updated item with links from ``ItemLinks``.

        Raises:
            HTTPException: 400 if validation fails.
        """
        body = item.model_dump(mode="json")

        self._validate_item(request, body, collection_id, item_id)
        body["collection"] = collection_id

        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, "update_item", body)

        body["links"] = await ItemLinks(
            collection_id=collection_id,
            item_id=body["id"],
            request=request,
        ).get_links(extra_links=body.get("links"))

        return stac_types.Item(**body)

    async def create_collection(
        self,
        collection: Collection,
        request: Request,
        **kwargs,
    ) -> Optional[Union[stac_types.Collection, Response]]:
        """Create a collection.

        The collection ID is validated before ``create_collection`` is invoked
        through ``dbfunc``.

        Returns:
            The created collection with links from ``CollectionLinks``.

        Raises:
            HTTPException: 400 if the collection ID is invalid.
        """
        body = collection.model_dump(mode="json")

        self._validate_collection(request, body)

        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, "create_collection", body)

        # .get() mirrors update_collection and cannot fail after the write
        # has already happened.
        body["links"] = await CollectionLinks(
            collection_id=body["id"], request=request
        ).get_links(extra_links=body.get("links"))

        return stac_types.Collection(**body)

    async def update_collection(
        self,
        collection: Collection,
        request: Request,
        **kwargs,
    ) -> Optional[Union[stac_types.Collection, Response]]:
        """Update a collection.

        The body is passed to ``update_collection`` through ``dbfunc`` without
        ID validation.

        Returns:
            The updated collection with links from ``CollectionLinks``.
        """
        body = collection.model_dump(mode="json")

        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, "update_collection", body)

        body["links"] = await CollectionLinks(
            collection_id=body["id"], request=request
        ).get_links(extra_links=body.get("links"))

        return stac_types.Collection(**body)

    async def delete_item(
        self,
        item_id: str,
        collection_id: str,
        request: Request,
        **kwargs,
    ) -> Optional[Union[stac_types.Item, Response]]:
        """Delete an item.

        Unlike the other operations, this renders a parameterized
        ``delete_item(item, collection)`` query with ``buildpg.render`` and
        runs it directly on the connection.

        Returns:
            A ``JSONResponse`` of the form ``{"deleted item": item_id}``.
        """
        query, params = render(
            "SELECT * FROM delete_item(:item::text, :collection::text);",
            item=item_id,
            collection=collection_id,
        )

        async with request.app.state.get_connection(request, "w") as conn:
            await conn.fetchval(query, *params)

        return JSONResponse({"deleted item": item_id})

    async def delete_collection(
        self, collection_id: str, request: Request, **kwargs
    ) -> Optional[Union[stac_types.Collection, Response]]:
        """Delete a collection.

        Returns:
            A ``JSONResponse`` of the form
            ``{"deleted collection": collection_id}``.
        """
        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, "delete_collection", collection_id)

        return JSONResponse({"deleted collection": collection_id})

    async def patch_item(
        self,
        collection_id: str,
        item_id: str,
        patch: Union[PartialItem, List[PatchOperation]],
        **kwargs,
    ) -> Optional[Union[stac_types.Item, Response]]:
        """Patch Item.

        Raises:
            NotImplementedError: Always; patching is not implemented here.
        """
        raise NotImplementedError

    async def patch_collection(
        self,
        collection_id: str,
        patch: Union[PartialCollection, List[PatchOperation]],
        **kwargs,
    ) -> Optional[Union[stac_types.Collection, Response]]:
        """Patch Collection.

        Raises:
            NotImplementedError: Always; patching is not implemented here.
        """
        raise NotImplementedError
@attr.s
class BulkTransactionsClient(AsyncBaseBulkTransactionsClient, ClientValidateMixIn):
    """Postgres bulk transactions."""

    async def bulk_item_insert(self, items: Items, request: Request, **kwargs) -> str:
        """Bulk item insertion using pgstac.

        Every item is validated against the ``collection_id`` path parameter
        and its own key in ``items.items`` before anything is written, and the
        path collection ID is written into each item.

        Args:
            items: Bulk payload; ``items.items`` maps item IDs to item bodies
                and ``items.method`` selects insert or upsert.
            request: Incoming request, used for path parameters, settings and
                the database connection.

        Returns:
            A message stating how many items were added or upserted.

        Raises:
            HTTPException: 400 if an item fails validation or ``items.method``
                is neither INSERT nor UPSERT.
        """
        collection_id = request.path_params["collection_id"]

        for item_id, item in items.items.items():
            self._validate_item(request, item, collection_id, item_id)
            item["collection"] = collection_id

        # Resolve the method up front so an unrecognised value is reported
        # explicitly instead of failing on an unbound local after the
        # connection has been opened and nothing was written.
        if items.method == BulkTransactionMethod.INSERT:
            db_function, method_verb = "create_items", "added"
        elif items.method == BulkTransactionMethod.UPSERT:
            db_function, method_verb = "upsert_items", "upserted"
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported bulk transaction method: {items.method}",
            )

        items_to_insert = list(items.items.values())

        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, db_function, items_to_insert)

        return f"Successfully {method_verb} {len(items_to_insert)} items."