# Source: dz1-spatial-query/stac-fastapi-pgstac/stac_fastapi/pgstac/core.py
# (Python, 654 lines, 22 KiB; last modified 2025-07-03 20:29:02 +08:00)
"""Item crud client."""
import json
import re
from typing import Any, Dict, List, Optional, Set, Type, Union
from urllib.parse import unquote_plus, urljoin
import attr
import orjson
from asyncpg.exceptions import InvalidDatetimeFormatError
from buildpg import render
from cql2 import Expr
from fastapi import HTTPException, Request
from pydantic import ValidationError
from pypgstac.hydration import hydrate
from stac_fastapi.api.models import JSONResponse
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
from stac_fastapi.types.requests import get_base_url
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
from stac_pydantic.shared import BBox, MimeTypes
from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.models.links import (
CollectionLinks,
CollectionSearchPagingLinks,
ItemCollectionLinks,
ItemLinks,
PagingLinks,
SearchLinks,
)
from stac_fastapi.pgstac.types.search import PgstacSearch
from stac_fastapi.pgstac.utils import filter_fields
# Type alias for a numeric value given as either a float or an int.
# NOTE(review): not referenced in the code shown here; presumably kept for
# importers of this module — confirm before removing.
NumType = Union[float, int]
@attr.s
class CoreCrudClient(AsyncBaseCoreClient):
    """Client for core endpoints defined by stac.

    Each endpoint is answered by calling a pgstac SQL function:
    ``collection_search``, ``all_collections``, ``get_collection``,
    ``collection_base_item`` or ``search``. The results are then decorated
    with the collection/item/paging links built by the link helpers.
    """

    # Pydantic model used to validate and serialize item search requests.
    pgstac_search_model: Type[PgstacSearch] = attr.ib(default=PgstacSearch)

    def _queryables_enabled(self) -> bool:
        """Return True when a filter extension exposing queryables is enabled."""
        return self.extension_is_enabled(
            "FilterExtension"
        ) or self.extension_is_enabled("ItemCollectionFilterExtension")

    @staticmethod
    def _queryables_link(base_url: str, collection_id: str) -> Dict[str, Any]:
        """Build the ``queryables`` link advertised on a collection."""
        return {
            "rel": Relations.queryables.value,
            "type": MimeTypes.jsonschema.value,
            "title": "Queryables",
            "href": urljoin(base_url, f"collections/{collection_id}/queryables"),
        }

    @staticmethod
    def _token_from_href(href: Optional[str], marker: str) -> Optional[str]:
        """Return the text following ``marker`` in ``href``, or None.

        None is returned when the href is missing/empty or does not contain
        the marker, rather than raising.
        """
        if not href:
            return None
        parts = href.split(marker)
        return parts[1] if len(parts) > 1 else None

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """JSON fallback encoder: serialize sets (e.g. ``fields`` include/exclude) as sorted lists."""
        if isinstance(obj, (set, frozenset)):
            return sorted(obj)
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    async def all_collections(  # noqa: C901
        self,
        request: Request,
        # Extensions
        bbox: Optional[BBox] = None,
        datetime: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        query: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[str] = None,
        filter_expr: Optional[str] = None,
        filter_lang: Optional[str] = None,
        q: Optional[List[str]] = None,
        **kwargs,
    ) -> Collections:
        """Cross catalog search (GET).

        Called with `GET /collections`.

        When the ``CollectionSearchExtension`` is enabled, the parameters are
        cleaned and forwarded to pgstac's ``collection_search``. Its
        ``next``/``prev`` links drive paging. Otherwise every collection is
        read through ``all_collections()``.

        Returns:
            Collections which match the search criteria, returns all
            collections by default.
        """
        base_url = get_base_url(request)
        next_link: Optional[Dict[str, Any]] = None
        prev_link: Optional[Dict[str, Any]] = None
        collections_result: Collections

        if self.extension_is_enabled("CollectionSearchExtension"):
            base_args = {
                "bbox": bbox,
                "limit": limit,
                "offset": offset,
                "query": orjson.loads(unquote_plus(query)) if query else query,
            }

            clean_args = self._clean_search_args(
                base_args=base_args,
                datetime=datetime,
                fields=fields,
                sortby=sortby,
                filter_query=filter_expr,
                filter_lang=filter_lang,
                q=q,
            )

            async with request.app.state.get_connection(request, "r") as conn:
                # `fields` is cleaned into sets, which plain json.dumps rejects;
                # distinct names also avoid shadowing the `q` parameter.
                sql, params = render(
                    """
                    SELECT * FROM collection_search(:req::text::jsonb);
                    """,
                    req=json.dumps(clean_args, default=self._json_default),
                )
                collections_result = await conn.fetchval(sql, *params)

            for link in collections_result.get("links") or []:
                if link["rel"] == "next":
                    next_link = link
                elif link["rel"] == "prev":
                    prev_link = link
        else:
            async with request.app.state.get_connection(request, "r") as conn:
                cols = await conn.fetchval(
                    """
                    SELECT * FROM all_collections();
                    """
                )
            collections_result = {"collections": cols, "links": []}

        # Loop-invariant: whether to advertise a queryables link per collection.
        add_queryables = self._queryables_enabled()

        linked_collections: List[Collection] = []
        for raw in collections_result["collections"] or []:
            coll = Collection(**raw)
            coll["links"] = await CollectionLinks(
                collection_id=coll["id"], request=request
            ).get_links(extra_links=coll.get("links"))
            if add_queryables:
                coll["links"].append(self._queryables_link(base_url, coll["id"]))
            linked_collections.append(coll)

        links = await CollectionSearchPagingLinks(
            request=request,
            next=next_link,
            prev=prev_link,
        ).get_links()

        return Collections(
            collections=linked_collections,
            links=links,
            numberMatched=collections_result.get(
                "numberMatched", len(linked_collections)
            ),
            numberReturned=collections_result.get(
                "numberReturned", len(linked_collections)
            ),
        )

    async def get_collection(
        self, collection_id: str, request: Request, **kwargs
    ) -> Collection:
        """Get collection by id.

        Called with `GET /collections/{collection_id}`.

        Args:
            collection_id: ID of the collection.

        Returns:
            Collection.

        Raises:
            NotFoundError: if pgstac returns no collection for the id.
        """
        collection: Optional[Dict[str, Any]]
        async with request.app.state.get_connection(request, "r") as conn:
            sql, params = render(
                """
                SELECT * FROM get_collection(:id::text);
                """,
                id=collection_id,
            )
            collection = await conn.fetchval(sql, *params)
        if collection is None:
            raise NotFoundError(f"Collection {collection_id} does not exist.")

        collection["links"] = await CollectionLinks(
            collection_id=collection_id, request=request
        ).get_links(extra_links=collection.get("links"))

        if self._queryables_enabled():
            collection["links"].append(
                self._queryables_link(get_base_url(request), collection_id)
            )

        return Collection(**collection)

    async def _get_base_item(
        self, collection_id: str, request: Request
    ) -> Dict[str, Any]:
        """Get the base item of a collection for use in rehydrating full item collection properties.

        Args:
            collection_id: ID of the collection.

        Returns:
            Item.

        Raises:
            NotFoundError: if pgstac returns no base item for the collection.
        """
        item: Optional[Dict[str, Any]]
        async with request.app.state.get_connection(request, "r") as conn:
            sql, params = render(
                """
                SELECT * FROM collection_base_item(:collection_id::text);
                """,
                collection_id=collection_id,
            )
            item = await conn.fetchval(sql, *params)
        if item is None:
            raise NotFoundError(f"A base item for {collection_id} does not exist.")
        return item

    async def _search_base(  # noqa: C901
        self,
        search_request: PgstacSearch,
        request: Request,
    ) -> ItemCollection:
        """Cross catalog search (POST).

        Called with `POST /search`.

        Args:
            search_request: search request parameters.

        Returns:
            ItemCollection containing items which match the search criteria.

        Raises:
            InvalidQueryParameter: if pgstac rejects the datetime parameter.
        """
        items: Dict[str, Any]
        settings: Settings = request.app.state.settings

        # When the API hydrates items itself, ask pgstac to skip hydration.
        search_request.conf = search_request.conf or {}
        search_request.conf["nohydrate"] = settings.use_api_hydrate

        search_request_json = search_request.model_dump_json(
            exclude_none=True, by_alias=True
        )

        try:
            async with request.app.state.get_connection(request, "r") as conn:
                sql, params = render(
                    """
                    SELECT * FROM search(:req::text::jsonb);
                    """,
                    req=search_request_json,
                )
                items = await conn.fetchval(sql, *params)
        except InvalidDatetimeFormatError as e:
            raise InvalidQueryParameter(
                f"Datetime parameter {search_request.datetime} is invalid."
            ) from e

        # Starting in pgstac 0.9.0, the `next` and `prev` tokens are returned in spec-compliant links with method GET
        next_from_link: Optional[str] = None
        prev_from_link: Optional[str] = None
        for link in items.get("links", []):
            rel = link.get("rel")
            if rel == "next":
                token = self._token_from_href(link.get("href"), "token=next:")
                if token is not None:
                    next_from_link = token
            elif rel == "prev":
                token = self._token_from_href(link.get("href"), "token=prev:")
                if token is not None:
                    prev_from_link = token

        # Older pgstac versions return the tokens as top-level keys instead.
        next_token: Optional[str] = items.pop("next", next_from_link)
        prev_token: Optional[str] = items.pop("prev", prev_from_link)
        collection = ItemCollection(**items)

        fields = getattr(search_request, "fields", None)
        include: Set[str] = fields.include if fields and fields.include else set()
        exclude: Set[str] = fields.exclude if fields and fields.exclude else set()

        async def _add_item_links(
            feature: Item,
            collection_id: Optional[str] = None,
            item_id: Optional[str] = None,
        ) -> None:
            """Add ItemLinks to the Item.

            If the fields extension is excluding links, then don't add them.
            Also skip links if the item doesn't provide collection and item ids.
            """
            collection_id = feature.get("collection") or collection_id
            item_id = feature.get("id") or item_id

            # Both the exclusion check and the id check must hold.
            if "links" not in exclude and collection_id and item_id:
                feature["links"] = await ItemLinks(
                    collection_id=collection_id,
                    item_id=item_id,
                    request=request,
                ).get_links(extra_links=feature.get("links"))

        cleaned_features: List[Item] = []

        if settings.use_api_hydrate:

            async def _fetch_base_item(collection_id: str) -> Dict[str, Any]:
                return await self._get_base_item(collection_id, request=request)

            base_item_cache = settings.base_item_cache(
                fetch_base_item=_fetch_base_item, request=request
            )

            for feature in collection.get("features") or []:
                base_item = await base_item_cache.get(feature.get("collection"))
                # Exclude None values
                base_item = {k: v for k, v in base_item.items() if v is not None}

                feature = hydrate(base_item, feature)

                # Grab ids needed for links that may be removed by the fields extension.
                collection_id = feature.get("collection")
                item_id = feature.get("id")

                feature = filter_fields(feature, include, exclude)
                await _add_item_links(feature, collection_id, item_id)

                cleaned_features.append(feature)
        else:
            for feature in collection.get("features") or []:
                await _add_item_links(feature)
                cleaned_features.append(feature)

        collection["features"] = cleaned_features
        collection["links"] = await PagingLinks(
            request=request,
            next=next_token,
            prev=prev_token,
        ).get_links()
        return collection

    async def item_collection(
        self,
        collection_id: str,
        request: Request,
        bbox: Optional[BBox] = None,
        datetime: Optional[str] = None,
        limit: Optional[int] = None,
        # Extensions
        query: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[str] = None,
        filter_expr: Optional[str] = None,
        filter_lang: Optional[str] = None,
        token: Optional[str] = None,
        **kwargs,
    ) -> ItemCollection:
        """Get all items from a specific collection.

        Called with `GET /collections/{collection_id}/items`

        Args:
            collection_id: id of the collection.
            limit: number of items to return.
            token: pagination token.

        Returns:
            An ItemCollection, or a raw JSONResponse when the fields extension
            includes/excludes fields (items may no longer be valid STAC).

        Raises:
            NotFoundError: if the collection does not exist.
            HTTPException: 400 when the parameters fail model validation.
        """
        # If collection does not exist, NotFoundError will be raised
        await self.get_collection(collection_id, request=request)

        base_args = {
            "collections": [collection_id],
            "bbox": bbox,
            "datetime": datetime,
            "limit": limit,
            "token": token,
            "query": orjson.loads(unquote_plus(query)) if query else query,
        }

        clean = self._clean_search_args(
            base_args=base_args,
            filter_query=filter_expr,
            filter_lang=filter_lang,
            fields=fields,
            sortby=sortby,
        )

        try:
            search_request = self.pgstac_search_model(**clean)
        except ValidationError as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid parameters provided {e}"
            ) from e

        item_collection = await self._search_base(search_request, request=request)

        item_collection["links"] = await ItemCollectionLinks(
            collection_id=collection_id, request=request
        ).get_links(extra_links=item_collection["links"])

        # If we have the `fields` extension enabled
        # we need to avoid Pydantic validation because the
        # Items might not be a valid STAC Item objects
        if fields := getattr(search_request, "fields", None):
            if fields.include or fields.exclude:
                return JSONResponse(item_collection)  # type: ignore

        return ItemCollection(**item_collection)

    async def get_item(
        self, item_id: str, collection_id: str, request: Request, **kwargs
    ) -> Item:
        """Get item by id.

        Called with `GET /collections/{collection_id}/items/{item_id}`.

        Args:
            item_id: ID of the item.
            collection_id: ID of the collection the item is in.

        Returns:
            Item.

        Raises:
            NotFoundError: if the collection or the item does not exist.
        """
        # If collection does not exist, NotFoundError will be raised
        await self.get_collection(collection_id, request=request)

        search_request = self.pgstac_search_model(
            ids=[item_id], collections=[collection_id], limit=1
        )
        item_collection = await self._search_base(search_request, request=request)
        if not item_collection["features"]:
            raise NotFoundError(
                f"Item {item_id} in Collection {collection_id} does not exist."
            )

        return Item(**item_collection["features"][0])

    async def post_search(
        self, search_request: PgstacSearch, request: Request, **kwargs
    ) -> ItemCollection:
        """Cross catalog search (POST).

        Called with `POST /search`.

        Args:
            search_request: search request parameters.

        Returns:
            ItemCollection containing items which match the search criteria,
            or a raw JSONResponse when the fields extension includes/excludes
            fields.
        """
        item_collection = await self._search_base(search_request, request=request)

        # Add search links before the fields short-circuit so both response
        # forms carry them, matching get_search and item_collection.
        item_collection["links"] = await SearchLinks(request=request).get_links(
            extra_links=item_collection["links"]
        )

        # If we have the `fields` extension enabled
        # we need to avoid Pydantic validation because the
        # Items might not be a valid STAC Item objects
        if fields := getattr(search_request, "fields", None):
            if fields.include or fields.exclude:
                return JSONResponse(item_collection)  # type: ignore

        return ItemCollection(**item_collection)

    async def get_search(
        self,
        request: Request,
        collections: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        bbox: Optional[BBox] = None,
        intersects: Optional[str] = None,
        datetime: Optional[str] = None,
        limit: Optional[int] = None,
        # Extensions
        query: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[str] = None,
        filter_expr: Optional[str] = None,
        filter_lang: Optional[str] = None,
        token: Optional[str] = None,
        **kwargs,
    ) -> ItemCollection:
        """Cross catalog search (GET).

        Called with `GET /search`.

        Returns:
            ItemCollection containing items which match the search criteria,
            or a raw JSONResponse when the fields extension includes/excludes
            fields.

        Raises:
            HTTPException: 400 when the parameters fail model validation.
        """
        # Parse request parameters
        base_args = {
            "collections": collections,
            "ids": ids,
            "bbox": bbox,
            "limit": limit,
            "token": token,
            "query": orjson.loads(unquote_plus(query)) if query else query,
        }

        clean = self._clean_search_args(
            base_args=base_args,
            intersects=intersects,
            datetime=datetime,
            fields=fields,
            sortby=sortby,
            filter_query=filter_expr,
            filter_lang=filter_lang,
        )

        try:
            search_request = self.pgstac_search_model(**clean)
        except ValidationError as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid parameters provided {e}"
            ) from e

        item_collection = await self._search_base(search_request, request=request)

        item_collection["links"] = await SearchLinks(request=request).get_links(
            extra_links=item_collection["links"]
        )

        # If we have the `fields` extension enabled
        # we need to avoid Pydantic validation because the
        # Items might not be a valid STAC Item objects
        if fields := getattr(search_request, "fields", None):
            if fields.include or fields.exclude:
                return JSONResponse(item_collection)  # type: ignore

        return ItemCollection(**item_collection)

    def _clean_search_args(  # noqa: C901
        self,
        base_args: Dict[str, Any],
        intersects: Optional[str] = None,
        datetime: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[Union[str, List[str]]] = None,
        filter_query: Optional[str] = None,
        filter_lang: Optional[str] = None,
        q: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Clean up search arguments to match format expected by pgstac.

        Args:
            base_args: Arguments collected by the caller; updated in place
                with the parsed values below before filtering.
            intersects: URL-encoded JSON geometry string.
            datetime: Datetime string, copied through unchanged when truthy.
            fields: Field names, optionally prefixed with ``+`` (include) or
                ``-`` (exclude). Surrounding whitespace is ignored, so a ``+``
                decoded to a space still means include. Empty entries are
                skipped.
            sortby: Sort keys as a list, or as a comma-separated string. Each
                key is optionally prefixed with ``+`` (asc) or ``-`` (desc).
            filter_query: Filter expression. ``cql2-text`` is converted to
                cql2-json; anything else is parsed as JSON.
            filter_lang: Language of ``filter_query``.
            q: Free-text terms, joined with ``" OR "``.

        Returns:
            A new dict of ``base_args`` without ``None`` or empty-list values.
        """
        if filter_query:
            if filter_lang == "cql2-text":
                base_args["filter"] = Expr(filter_query).to_json()
                base_args["filter_lang"] = "cql2-json"
            else:
                base_args["filter"] = orjson.loads(filter_query)
                base_args["filter_lang"] = filter_lang

        if datetime:
            base_args["datetime"] = datetime

        if intersects:
            base_args["intersects"] = orjson.loads(unquote_plus(intersects))

        if sortby:
            # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
            # Accept a raw comma-separated string too; iterating a str would
            # otherwise yield one sort key per character.
            sort_keys = sortby.split(",") if isinstance(sortby, str) else sortby
            sort_param = []
            for sort in sort_keys:
                sortparts = re.match(r"^([+-]?)(.*)$", sort)
                if sortparts:
                    sort_param.append(
                        {
                            "field": sortparts.group(2).strip(),
                            "direction": "desc" if sortparts.group(1) == "-" else "asc",
                        }
                    )
            base_args["sortby"] = sort_param

        if fields:
            includes = set()
            excludes = set()
            for raw_field in fields:
                # A URL-decoded "+" arrives as a space; strip it and skip blanks.
                field = raw_field.strip()
                if not field:
                    continue
                if field[0] == "-":
                    excludes.add(field[1:])
                elif field[0] == "+":
                    includes.add(field[1:])
                else:
                    includes.add(field)
            base_args["fields"] = {"include": includes, "exclude": excludes}

        if q:
            base_args["q"] = " OR ".join(q)

        # Remove None values (and empty lists) from dict
        return {k: v for k, v in base_args.items() if v is not None and v != []}
async def health_check(request: Request) -> Union[Dict, JSONResponse]:
    """PgSTAC HealthCheck.

    Reports the status of the application lifespan and the pgstac database.

    Returns:
        A plain dict with ``status``, ``lifespan`` and ``pgstac`` entries when
        ``SELECT pgstac.get_version()`` succeeds. Otherwise a 503 JSONResponse
        with the same keys marked ``DOWN`` is returned. That happens when the
        app state has no ``get_connection`` (lifespan not run) or when the
        version query raises.
    """
    if not hasattr(request.app.state, "get_connection"):
        # Without the connection factory neither lifespan nor database can be UP.
        return JSONResponse(
            status_code=503,
            content={
                "status": "DOWN",
                "lifespan": {
                    "status": "DOWN",
                    "message": "application lifespan wasn't run",
                },
                "pgstac": {
                    "status": "DOWN",
                    "message": "Could not connect to database",
                },
            },
        )

    report = {"status": "UP", "lifespan": {"status": "UP"}}

    try:
        async with request.app.state.get_connection(request, "r") as conn:
            sql, params = render(
                """SELECT pgstac.get_version();""",
            )
            pgstac_version = await conn.fetchval(sql, *params)
    except Exception as err:
        # Any failure reaching the database marks the whole service DOWN.
        report["status"] = "DOWN"
        report["pgstac"] = {"status": "DOWN", "message": str(err)}
        return JSONResponse(status_code=503, content=report)

    report["pgstac"] = {"status": "UP", "pgstac_version": pgstac_version}
    return report