"""transactions extension client.""" import logging import re from typing import List, Optional, Union import attr from buildpg import render from fastapi import HTTPException, Request from stac_fastapi.extensions.core.transaction import AsyncBaseTransactionsClient from stac_fastapi.extensions.core.transaction.request import ( PartialCollection, PartialItem, PatchOperation, ) from stac_fastapi.extensions.third_party.bulk_transactions import ( AsyncBaseBulkTransactionsClient, BulkTransactionMethod, Items, ) from stac_fastapi.types import stac as stac_types from stac_pydantic import Collection, Item, ItemCollection from starlette.responses import JSONResponse, Response from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.db import dbfunc from stac_fastapi.pgstac.models.links import CollectionLinks, ItemLinks logger = logging.getLogger("uvicorn") logger.setLevel(logging.INFO) class ClientValidateMixIn: def _validate_id(self, id: str, settings: Settings): invalid_chars = settings.invalid_id_chars id_regex = "[" + "".join(re.escape(char) for char in invalid_chars) + "]" if bool(re.search(id_regex, id)): raise HTTPException( status_code=400, detail=f"ID ({id}) cannot contain the following characters: {' '.join(invalid_chars)}", ) def _validate_collection(self, request: Request, collection: stac_types.Collection): self._validate_id(collection["id"], request.app.state.settings) def _validate_item( self, request: Request, item: stac_types.Item, collection_id: str, expected_item_id: Optional[str] = None, ) -> None: """Validate item.""" body_collection_id = item.get("collection") body_item_id = item.get("id") self._validate_id(body_item_id, request.app.state.settings) if item.get("geometry", None) is None: raise HTTPException( status_code=400, detail=f"Missing or null `geometry` for Item ({body_item_id}). Geometry is required in pgstac.", ) if body_collection_id is not None and collection_id != body_collection_id: raise HTTPException( status_code=400, detail=f"Collection ID from path parameter ({collection_id}) does not match Collection ID from Item ({body_collection_id})", ) if expected_item_id is not None and expected_item_id != body_item_id: raise HTTPException( status_code=400, detail=f"Item ID from path parameter ({expected_item_id}) does not match Item ID from Item ({body_item_id})", ) @attr.s class TransactionsClient(AsyncBaseTransactionsClient, ClientValidateMixIn): """Transactions extension specific CRUD operations.""" async def create_item( self, collection_id: str, item: Union[Item, ItemCollection], request: Request, **kwargs, ) -> Optional[Union[stac_types.Item, Response]]: """Create item.""" item = item.model_dump(mode="json") if item["type"] == "FeatureCollection": valid_items = [] for item in item["features"]: # noqa: B020 self._validate_item(request, item, collection_id) item["collection"] = collection_id valid_items.append(item) async with request.app.state.get_connection(request, "w") as conn: await dbfunc(conn, "create_items", valid_items) return Response(status_code=201) elif item["type"] == "Feature": self._validate_item(request, item, collection_id) item["collection"] = collection_id async with request.app.state.get_connection(request, "w") as conn: await dbfunc(conn, "create_item", item) item["links"] = await ItemLinks( collection_id=collection_id, item_id=item["id"], request=request, ).get_links(extra_links=item.get("links")) return stac_types.Item(**item) else: raise HTTPException( status_code=400, detail=f"Item body type must be 'Feature' or 'FeatureCollection', not {item['type']}", ) async def update_item( self, request: Request, collection_id: str, item_id: str, item: Item, **kwargs, ) -> Optional[Union[stac_types.Item, Response]]: """Update item.""" item = item.model_dump(mode="json") self._validate_item(request, item, collection_id, item_id) item["collection"] = collection_id async with request.app.state.get_connection(request, "w") as conn: await dbfunc(conn, "update_item", item) item["links"] = await ItemLinks( collection_id=collection_id, item_id=item["id"], request=request, ).get_links(extra_links=item.get("links")) return stac_types.Item(**item) async def create_collection( self, collection: Collection, request: Request, **kwargs, ) -> Optional[Union[stac_types.Collection, Response]]: """Create collection.""" collection = collection.model_dump(mode="json") self._validate_collection(request, collection) async with request.app.state.get_connection(request, "w") as conn: await dbfunc(conn, "create_collection", collection) collection["links"] = await CollectionLinks( collection_id=collection["id"], request=request ).get_links(extra_links=collection["links"]) return stac_types.Collection(**collection) async def update_collection( self, collection: Collection, request: Request, **kwargs, ) -> Optional[Union[stac_types.Collection, Response]]: """Update collection.""" col = collection.model_dump(mode="json") async with request.app.state.get_connection(request, "w") as conn: await dbfunc(conn, "update_collection", col) col["links"] = await CollectionLinks( collection_id=col["id"], request=request ).get_links(extra_links=col.get("links")) return stac_types.Collection(**col) async def delete_item( self, item_id: str, collection_id: str, request: Request, **kwargs, ) -> Optional[Union[stac_types.Item, Response]]: """Delete item.""" q, p = render( "SELECT * FROM delete_item(:item::text, :collection::text);", item=item_id, collection=collection_id, ) async with request.app.state.get_connection(request, "w") as conn: await conn.fetchval(q, *p) return JSONResponse({"deleted item": item_id}) async def delete_collection( self, collection_id: str, request: Request, **kwargs ) -> Optional[Union[stac_types.Collection, Response]]: """Delete collection.""" async with request.app.state.get_connection(request, "w") as conn: await dbfunc(conn, "delete_collection", collection_id) return JSONResponse({"deleted collection": collection_id}) async def patch_item( self, collection_id: str, item_id: str, patch: Union[PartialItem, List[PatchOperation]], **kwargs, ) -> Optional[Union[stac_types.Item, Response]]: """Patch Item.""" raise NotImplementedError async def patch_collection( self, collection_id: str, patch: Union[PartialCollection, List[PatchOperation]], **kwargs, ) -> Optional[Union[stac_types.Collection, Response]]: """Patch Collection.""" raise NotImplementedError @attr.s class BulkTransactionsClient(AsyncBaseBulkTransactionsClient, ClientValidateMixIn): """Postgres bulk transactions.""" async def bulk_item_insert(self, items: Items, request: Request, **kwargs) -> str: """Bulk item insertion using pgstac.""" collection_id = request.path_params["collection_id"] for item_id, item in items.items.items(): self._validate_item(request, item, collection_id, item_id) item["collection"] = collection_id items_to_insert = list(items.items.values()) async with request.app.state.get_connection(request, "w") as conn: if items.method == BulkTransactionMethod.INSERT: method_verb = "added" await dbfunc(conn, "create_items", items_to_insert) elif items.method == BulkTransactionMethod.UPSERT: method_verb = "upserted" await dbfunc(conn, "upsert_items", items_to_insert) return_msg = f"Successfully {method_verb} {len(items_to_insert)} items." return return_msg