dz1-spatial-query/stac-fastapi-pgstac/stac_fastapi/pgstac/transactions.py
weixin_46229132 5bc6302955 first commit
2025-07-03 20:29:02 +08:00

261 lines
8.8 KiB
Python

"""transactions extension client."""
import logging
import re
from typing import List, Optional, Union
import attr
from buildpg import render
from fastapi import HTTPException, Request
from stac_fastapi.extensions.core.transaction import AsyncBaseTransactionsClient
from stac_fastapi.extensions.core.transaction.request import (
PartialCollection,
PartialItem,
PatchOperation,
)
from stac_fastapi.extensions.third_party.bulk_transactions import (
AsyncBaseBulkTransactionsClient,
BulkTransactionMethod,
Items,
)
from stac_fastapi.types import stac as stac_types
from stac_pydantic import Collection, Item, ItemCollection
from starlette.responses import JSONResponse, Response
from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.db import dbfunc
from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks
# Module logger: reuses the logger named "uvicorn" (presumably so messages share
# the server's log output — confirm). Loggers are process-wide singletons per
# name, so setting the level here affects every user of "uvicorn".
logger = logging.getLogger(name="uvicorn")
logger.setLevel(level=logging.INFO)
class ClientValidateMixIn:
    """Request-body validation shared by the transaction clients.

    Reads its configuration from ``request.app.state.settings``.
    """

    def _validate_id(self, id: str, settings: Settings):
        """Ensure ``id`` is a string free of the configured forbidden characters.

        Args:
            id: the Item or Collection id taken from the request body.
            settings: application settings; ``settings.invalid_id_chars`` is an
                iterable of characters that may not appear in an id.

        Raises:
            HTTPException: 400 if ``id`` is missing / not a string, or contains
                one of the forbidden characters.
        """
        # A missing id arrives here as None (raw bulk-item dicts are not
        # schema-validated); re.search would raise TypeError and surface as a
        # 500, so report it as a client error instead.
        if not isinstance(id, str):
            raise HTTPException(
                status_code=400,
                detail=f"ID ({id}) is missing or is not a string",
            )

        invalid_chars = settings.invalid_id_chars
        # "[]" is not a valid regular expression, so an empty list of
        # forbidden characters must short-circuit instead of being compiled.
        if not invalid_chars:
            return

        # Escape each character so regex metacharacters such as "]" or "^"
        # are matched literally inside the character class.
        id_regex = "[" + "".join(re.escape(char) for char in invalid_chars) + "]"
        if re.search(id_regex, id):
            raise HTTPException(
                status_code=400,
                detail=f"ID ({id}) cannot contain the following characters: {' '.join(invalid_chars)}",
            )

    def _validate_collection(self, request: Request, collection: stac_types.Collection):
        """Validate the ``id`` of a collection body.

        Raises:
            HTTPException: 400 if the id is missing, not a string, or contains
                a forbidden character.
        """
        # .get() so a missing id is reported as a 400 by _validate_id rather
        # than escaping as a KeyError.
        self._validate_id(collection.get("id"), request.app.state.settings)

    def _validate_item(
        self,
        request: Request,
        item: stac_types.Item,
        collection_id: str,
        expected_item_id: Optional[str] = None,
    ) -> None:
        """Validate an item body against the request path.

        Args:
            request: the incoming request (used for application settings).
            item: the item body as a plain dict.
            collection_id: collection id from the URL path.
            expected_item_id: item id from the URL path (or the bulk-insert
                mapping key); when given, the body ``id`` must equal it.

        Raises:
            HTTPException: 400 if the id is invalid, ``geometry`` is missing or
                null, the body ``collection`` is set and differs from
                ``collection_id``, or the body ``id`` differs from
                ``expected_item_id``.
        """
        body_collection_id = item.get("collection")
        body_item_id = item.get("id")

        self._validate_id(body_item_id, request.app.state.settings)

        if item.get("geometry", None) is None:
            raise HTTPException(
                status_code=400,
                detail=f"Missing or null `geometry` for Item ({body_item_id}). Geometry is required in pgstac.",
            )

        # An absent body collection is allowed; callers fill it in from the path.
        if body_collection_id is not None and collection_id != body_collection_id:
            raise HTTPException(
                status_code=400,
                detail=f"Collection ID from path parameter ({collection_id}) does not match Collection ID from Item ({body_collection_id})",
            )

        if expected_item_id is not None and expected_item_id != body_item_id:
            raise HTTPException(
                status_code=400,
                detail=f"Item ID from path parameter ({expected_item_id}) does not match Item ID from Item ({body_item_id})",
            )
@attr.s
class TransactionsClient(AsyncBaseTransactionsClient, ClientValidateMixIn):
    """Transactions extension specific CRUD operations.

    Each write opens a connection via
    ``request.app.state.get_connection(request, "w")`` (presumably a writer
    connection — confirm in ``stac_fastapi.pgstac.db``) and calls a pgstac
    database function through ``dbfunc``.
    """

    async def create_item(
        self,
        collection_id: str,
        item: Union[Item, ItemCollection],
        request: Request,
        **kwargs,
    ) -> Optional[Union[stac_types.Item, Response]]:
        """Create a single Item, or every Item of a FeatureCollection.

        Args:
            collection_id: collection id from the URL path; it is written into
                each item's ``collection`` field (a conflicting body value is
                rejected by validation).
            item: a STAC Item (``type == "Feature"``) or ItemCollection
                (``type == "FeatureCollection"``).
            request: the incoming request.

        Returns:
            The created Item with generated links for a Feature, or an empty
            201 response for a FeatureCollection.

        Raises:
            HTTPException: 400 if any item fails validation or the body type
                is neither ``Feature`` nor ``FeatureCollection``.
        """
        body = item.model_dump(mode="json")
        body_type = body["type"]

        if body_type == "FeatureCollection":
            # Every feature is validated before the database is touched, so a
            # single bad feature rejects the whole batch.
            valid_items = []
            for feature in body["features"]:
                self._validate_item(request, feature, collection_id)
                feature["collection"] = collection_id
                valid_items.append(feature)

            async with request.app.state.get_connection(request, "w") as conn:
                await dbfunc(conn, "create_items", valid_items)

            return Response(status_code=201)

        if body_type == "Feature":
            self._validate_item(request, body, collection_id)
            body["collection"] = collection_id

            async with request.app.state.get_connection(request, "w") as conn:
                await dbfunc(conn, "create_item", body)

            body["links"] = await ItemLinks(
                collection_id=collection_id,
                item_id=body["id"],
                request=request,
            ).get_links(extra_links=body.get("links"))

            return stac_types.Item(**body)

        raise HTTPException(
            status_code=400,
            detail=f"Item body type must be 'Feature' or 'FeatureCollection', not {body_type}",
        )

    async def update_item(
        self,
        request: Request,
        collection_id: str,
        item_id: str,
        item: Item,
        **kwargs,
    ) -> Optional[Union[stac_types.Item, Response]]:
        """Replace an existing Item.

        Args:
            request: the incoming request.
            collection_id: collection id from the URL path.
            item_id: item id from the URL path; must equal the body ``id``.
            item: the new item document.

        Returns:
            The updated Item with generated links.

        Raises:
            HTTPException: 400 if the item fails validation.
        """
        body = item.model_dump(mode="json")
        self._validate_item(request, body, collection_id, item_id)
        body["collection"] = collection_id

        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, "update_item", body)

        body["links"] = await ItemLinks(
            collection_id=collection_id,
            item_id=body["id"],
            request=request,
        ).get_links(extra_links=body.get("links"))

        return stac_types.Item(**body)

    async def create_collection(
        self,
        collection: Collection,
        request: Request,
        **kwargs,
    ) -> Optional[Union[stac_types.Collection, Response]]:
        """Create a collection.

        Returns:
            The created Collection with generated links.

        Raises:
            HTTPException: 400 if the collection id is invalid.
        """
        body = collection.model_dump(mode="json")
        self._validate_collection(request, body)

        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, "create_collection", body)

        # .get() for consistency with update_collection: tolerate a body
        # without a "links" key instead of raising KeyError.
        body["links"] = await CollectionLinks(
            collection_id=body["id"], request=request
        ).get_links(extra_links=body.get("links"))

        return stac_types.Collection(**body)

    async def update_collection(
        self,
        collection: Collection,
        request: Request,
        **kwargs,
    ) -> Optional[Union[stac_types.Collection, Response]]:
        """Replace an existing collection.

        Args:
            collection: the new collection document.
            request: the incoming request.
            **kwargs: may carry ``collection_id`` from the URL path (presumably
                supplied by the transaction router — confirm); when present it
                must equal the body ``id``.

        Returns:
            The updated Collection with generated links.

        Raises:
            HTTPException: 400 if the collection id is invalid or does not
                match the path parameter.
        """
        col = collection.model_dump(mode="json")
        # Same id rules as create_collection.
        self._validate_collection(request, col)

        # The database update is keyed by the body id, so a mismatch with the
        # path would silently modify a different collection than requested.
        path_collection_id = kwargs.get("collection_id")
        if path_collection_id is not None and path_collection_id != col["id"]:
            raise HTTPException(
                status_code=400,
                detail=f"Collection ID from path parameter ({path_collection_id}) does not match Collection ID from Collection ({col['id']})",
            )

        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, "update_collection", col)

        col["links"] = await CollectionLinks(
            collection_id=col["id"], request=request
        ).get_links(extra_links=col.get("links"))

        return stac_types.Collection(**col)

    async def delete_item(
        self,
        item_id: str,
        collection_id: str,
        request: Request,
        **kwargs,
    ) -> Optional[Union[stac_types.Item, Response]]:
        """Delete an Item.

        Returns:
            A JSON response ``{"deleted item": item_id}``.
        """
        # render() turns the named placeholders into positional parameters
        # that are passed separately to fetchval, so the ids are never
        # interpolated into the SQL text.
        q, p = render(
            "SELECT * FROM delete_item(:item::text, :collection::text);",
            item=item_id,
            collection=collection_id,
        )

        async with request.app.state.get_connection(request, "w") as conn:
            await conn.fetchval(q, *p)

        return JSONResponse({"deleted item": item_id})

    async def delete_collection(
        self, collection_id: str, request: Request, **kwargs
    ) -> Optional[Union[stac_types.Collection, Response]]:
        """Delete a collection.

        Returns:
            A JSON response ``{"deleted collection": collection_id}``.
        """
        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, "delete_collection", collection_id)

        return JSONResponse({"deleted collection": collection_id})

    async def patch_item(
        self,
        collection_id: str,
        item_id: str,
        patch: Union[PartialItem, List[PatchOperation]],
        **kwargs,
    ) -> Optional[Union[stac_types.Item, Response]]:
        """Patch Item (not supported by this backend).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    async def patch_collection(
        self,
        collection_id: str,
        patch: Union[PartialCollection, List[PatchOperation]],
        **kwargs,
    ) -> Optional[Union[stac_types.Collection, Response]]:
        """Patch Collection (not supported by this backend).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
@attr.s
class BulkTransactionsClient(AsyncBaseBulkTransactionsClient, ClientValidateMixIn):
    """Postgres bulk transactions."""

    async def bulk_item_insert(self, items: Items, request: Request, **kwargs) -> str:
        """Bulk item insertion using pgstac.

        Args:
            items: mapping of item id -> item dict, plus the bulk ``method``
                (insert or upsert).
            request: the incoming request; ``collection_id`` is read from its
                path parameters.

        Returns:
            A human-readable summary, e.g. ``"Successfully added 3 items."``.

        Raises:
            HTTPException: 400 if the method is unsupported or any item fails
                validation (the mapping key must equal the item's ``id``).
        """
        collection_id = request.path_params["collection_id"]

        # Resolve the write strategy up front: previously an unrecognised
        # method left ``method_verb`` unbound and failed with
        # UnboundLocalError after a write connection had been opened.
        if items.method == BulkTransactionMethod.INSERT:
            method_verb, db_function = "added", "create_items"
        elif items.method == BulkTransactionMethod.UPSERT:
            method_verb, db_function = "upserted", "upsert_items"
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported bulk transaction method: {items.method}",
            )

        # Validate every item before touching the database so one bad item
        # rejects the whole batch.
        for item_id, item in items.items.items():
            self._validate_item(request, item, collection_id, item_id)
            item["collection"] = collection_id

        items_to_insert = list(items.items.values())

        async with request.app.state.get_connection(request, "w") as conn:
            await dbfunc(conn, db_function, items_to_insert)

        return f"Successfully {method_verb} {len(items_to_insert)} items."