654 lines
22 KiB
Python
654 lines
22 KiB
Python
"""Item crud client."""
|
|
|
|
import json
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Set, Type, Union
|
|
from urllib.parse import unquote_plus, urljoin
|
|
|
|
import attr
|
|
import orjson
|
|
from asyncpg.exceptions import InvalidDatetimeFormatError
|
|
from buildpg import render
|
|
from cql2 import Expr
|
|
from fastapi import HTTPException, Request
|
|
from pydantic import ValidationError
|
|
from pypgstac.hydration import hydrate
|
|
from stac_fastapi.api.models import JSONResponse
|
|
from stac_fastapi.types.core import AsyncBaseCoreClient, Relations
|
|
from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError
|
|
from stac_fastapi.types.requests import get_base_url
|
|
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
|
|
from stac_pydantic.shared import BBox, MimeTypes
|
|
|
|
from stac_fastapi.pgstac.config import Settings
|
|
from stac_fastapi.pgstac.models.links import (
|
|
CollectionLinks,
|
|
CollectionSearchPagingLinks,
|
|
ItemCollectionLinks,
|
|
ItemLinks,
|
|
PagingLinks,
|
|
SearchLinks,
|
|
)
|
|
from stac_fastapi.pgstac.types.search import PgstacSearch
|
|
from stac_fastapi.pgstac.utils import filter_fields
|
|
|
|
# Type alias for a numeric scalar: either a float or an int.
NumType = Union[float, int]
|
|
|
|
|
|
@attr.s
class CoreCrudClient(AsyncBaseCoreClient):
    """Client for core endpoints defined by stac.

    Every read goes through ``request.app.state.get_connection(request, "r")``
    and calls a PgSTAC SQL function (``collection_search``, ``all_collections``,
    ``get_collection``, ``collection_base_item``, ``search``).
    """

    # Model used to validate/serialize search parameters before they are sent
    # to the database; overridable per client instance.
    pgstac_search_model: Type[PgstacSearch] = attr.ib(default=PgstacSearch)

    @staticmethod
    def _token_from_link(link: Dict[str, Any], rel: str) -> Optional[str]:
        """Return the paging token carried by a link's ``href``.

        Args:
            link: link mapping as returned by the database.
            rel: ``"next"`` or ``"prev"``; selects the ``token=<rel>:`` marker.

        Returns:
            Everything after the marker, or ``None`` when the href is missing
            or does not contain the marker.
        """
        href = link.get("href") or ""
        _, marker, token = href.partition(f"token={rel}:")
        return token if marker else None

    async def all_collections(  # noqa: C901
        self,
        request: Request,
        # Extensions
        bbox: Optional[BBox] = None,
        datetime: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        query: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[str] = None,
        filter_expr: Optional[str] = None,
        filter_lang: Optional[str] = None,
        q: Optional[List[str]] = None,
        **kwargs,
    ) -> Collections:
        """Cross catalog search (GET).

        Called with `GET /collections`.

        The search parameters are only used when the
        ``CollectionSearchExtension`` is enabled; otherwise every collection is
        returned and the parameters are ignored.

        Args:
            request: incoming request (provides the app state and base url).
            bbox: bounding box filter.
            datetime: datetime filter.
            limit: maximum number of collections to return.
            offset: number of collections to skip.
            query: URL-encoded JSON string for the query extension.
            fields: field names, optionally prefixed with ``+`` or ``-``.
            sortby: sort fields, optionally prefixed with ``+`` or ``-``.
            filter_expr: filter expression (cql2-text or JSON).
            filter_lang: language of ``filter_expr``.
            q: free-text search terms.

        Returns:
            Collections which match the search criteria, returns all
            collections by default.
        """
        base_url = get_base_url(request)

        next_link: Optional[Dict[str, Any]] = None
        prev_link: Optional[Dict[str, Any]] = None
        collections_result: Collections

        if self.extension_is_enabled("CollectionSearchExtension"):
            base_args = {
                "bbox": bbox,
                "limit": limit,
                "offset": offset,
                "query": orjson.loads(unquote_plus(query)) if query else query,
            }

            clean_args = self._clean_search_args(
                base_args=base_args,
                datetime=datetime,
                fields=fields,
                sortby=sortby,
                filter_query=filter_expr,
                filter_lang=filter_lang,
                q=q,
            )

            # `_clean_search_args` stores include/exclude names as sets, which
            # `json.dumps` cannot encode; convert them to (sorted) lists first.
            if "fields" in clean_args:
                clean_args["fields"] = {
                    key: sorted(names) for key, names in clean_args["fields"].items()
                }

            async with request.app.state.get_connection(request, "r") as conn:
                # Distinct local names so the `q` parameter is not shadowed.
                sql, params = render(
                    """
                    SELECT * FROM collection_search(:req::text::jsonb);
                    """,
                    req=json.dumps(clean_args),
                )
                collections_result = await conn.fetchval(sql, *params)

            # Keep the database-provided paging links for the response.
            if links := collections_result.get("links"):
                for link in links:
                    if link["rel"] == "next":
                        next_link = link
                    elif link["rel"] == "prev":
                        prev_link = link

        else:
            async with request.app.state.get_connection(request, "r") as conn:
                cols = await conn.fetchval(
                    """
                    SELECT * FROM all_collections();
                    """
                )
                collections_result = {"collections": cols, "links": []}

        # Loop-invariant: evaluated once instead of once per collection.
        queryables_enabled = self.extension_is_enabled(
            "FilterExtension"
        ) or self.extension_is_enabled("ItemCollectionFilterExtension")

        linked_collections: List[Collection] = []
        # The database may return NULL instead of an empty array.
        for c in collections_result["collections"] or []:
            coll = Collection(**c)
            coll["links"] = await CollectionLinks(
                collection_id=coll["id"], request=request
            ).get_links(extra_links=coll.get("links"))

            if queryables_enabled:
                coll["links"].append(
                    {
                        "rel": Relations.queryables.value,
                        "type": MimeTypes.jsonschema.value,
                        "title": "Queryables",
                        "href": urljoin(
                            base_url, f"collections/{coll['id']}/queryables"
                        ),
                    }
                )

            linked_collections.append(coll)

        links = await CollectionSearchPagingLinks(
            request=request,
            next=next_link,
            prev=prev_link,
        ).get_links()

        return Collections(
            collections=linked_collections or [],
            links=links,
            # Fall back to the local count when the database gives no counters.
            numberMatched=collections_result.get(
                "numberMatched", len(linked_collections)
            ),
            numberReturned=collections_result.get(
                "numberReturned", len(linked_collections)
            ),
        )

    async def get_collection(
        self, collection_id: str, request: Request, **kwargs
    ) -> Collection:
        """Get collection by id.

        Called with `GET /collections/{collection_id}`.

        Args:
            collection_id: ID of the collection.
            request: incoming request (provides the app state and base url).

        Returns:
            Collection.

        Raises:
            NotFoundError: if the database returns no collection for the id.
        """
        collection: Optional[Dict[str, Any]]

        async with request.app.state.get_connection(request, "r") as conn:
            sql, params = render(
                """
                SELECT * FROM get_collection(:id::text);
                """,
                id=collection_id,
            )
            collection = await conn.fetchval(sql, *params)

        if collection is None:
            raise NotFoundError(f"Collection {collection_id} does not exist.")

        collection["links"] = await CollectionLinks(
            collection_id=collection_id, request=request
        ).get_links(extra_links=collection.get("links"))

        if self.extension_is_enabled("FilterExtension") or self.extension_is_enabled(
            "ItemCollectionFilterExtension"
        ):
            base_url = get_base_url(request)
            collection["links"].append(
                {
                    "rel": Relations.queryables.value,
                    "type": MimeTypes.jsonschema.value,
                    "title": "Queryables",
                    "href": urljoin(base_url, f"collections/{collection_id}/queryables"),
                }
            )

        return Collection(**collection)

    async def _get_base_item(
        self, collection_id: str, request: Request
    ) -> Dict[str, Any]:
        """Get the base item of a collection for use in rehydrating full item collection properties.

        Args:
            collection_id: ID of the collection.
            request: incoming request (provides the app state).

        Returns:
            Item.

        Raises:
            NotFoundError: if the database returns no base item for the id.
        """
        item: Optional[Dict[str, Any]]

        async with request.app.state.get_connection(request, "r") as conn:
            sql, params = render(
                """
                SELECT * FROM collection_base_item(:collection_id::text);
                """,
                collection_id=collection_id,
            )
            item = await conn.fetchval(sql, *params)

        if item is None:
            raise NotFoundError(f"A base item for {collection_id} does not exist.")

        return item

    async def _search_base(  # noqa: C901
        self,
        search_request: PgstacSearch,
        request: Request,
    ) -> ItemCollection:
        """Cross catalog search (POST).

        Called with `POST /search`.

        Shared implementation behind the item search endpoints: runs the
        ``search`` SQL function, optionally hydrates items in the API, applies
        the fields include/exclude sets and attaches item and paging links.

        Args:
            search_request: search request parameters.
            request: incoming request (provides the app state and settings).

        Returns:
            ItemCollection containing items which match the search criteria.

        Raises:
            InvalidQueryParameter: if the database rejects the datetime value.
        """
        items: Dict[str, Any]

        settings: Settings = request.app.state.settings

        # When the API hydrates items itself, ask the database not to.
        search_request.conf = search_request.conf or {}
        search_request.conf["nohydrate"] = settings.use_api_hydrate

        search_request_json = search_request.model_dump_json(
            exclude_none=True, by_alias=True
        )

        try:
            async with request.app.state.get_connection(request, "r") as conn:
                sql, params = render(
                    """
                    SELECT * FROM search(:req::text::jsonb);
                    """,
                    req=search_request_json,
                )
                items = await conn.fetchval(sql, *params)
        except InvalidDatetimeFormatError as e:
            raise InvalidQueryParameter(
                f"Datetime parameter {search_request.datetime} is invalid."
            ) from e

        # Starting in pgstac 0.9.0, the `next` and `prev` tokens are returned in spec-compliant links with method GET
        next_from_link: Optional[str] = None
        prev_from_link: Optional[str] = None
        for link in items.get("links", []):
            rel = link.get("rel")
            if rel == "next":
                next_from_link = self._token_from_link(link, "next")
            elif rel == "prev":
                prev_from_link = self._token_from_link(link, "prev")

        # Top-level `next`/`prev` keys, when present, take precedence over the
        # tokens parsed from the links.
        next_token: Optional[str] = items.pop("next", next_from_link)
        prev_token: Optional[str] = items.pop("prev", prev_from_link)
        collection = ItemCollection(**items)

        fields = getattr(search_request, "fields", None)
        include: Set[str] = fields.include if fields and fields.include else set()
        exclude: Set[str] = fields.exclude if fields and fields.exclude else set()

        async def _add_item_links(
            feature: Item,
            collection_id: Optional[str] = None,
            item_id: Optional[str] = None,
        ) -> None:
            """Add ItemLinks to the Item.

            If the fields extension is excluding links, then don't add them.
            Also skip links if the item doesn't provide collection and item ids.
            """
            collection_id = feature.get("collection") or collection_id
            item_id = feature.get("id") or item_id

            # Both conditions must hold. The previous `A or B and C` form
            # skipped the id check whenever the exclude set was empty.
            if "links" in exclude or not (collection_id and item_id):
                return

            feature["links"] = await ItemLinks(
                collection_id=collection_id,
                item_id=item_id,
                request=request,
            ).get_links(extra_links=feature.get("links"))

        cleaned_features: List[Item] = []

        if settings.use_api_hydrate:

            async def _get_base_item(collection_id: str) -> Dict[str, Any]:
                return await self._get_base_item(collection_id, request=request)

            base_item_cache = settings.base_item_cache(
                fetch_base_item=_get_base_item, request=request
            )

            for feature in collection.get("features") or []:
                base_item = await base_item_cache.get(feature.get("collection"))
                # Exclude None values
                base_item = {k: v for k, v in base_item.items() if v is not None}

                feature = hydrate(base_item, feature)

                # Grab ids needed for links that may be removed by the fields extension.
                collection_id = feature.get("collection")
                item_id = feature.get("id")

                feature = filter_fields(feature, include, exclude)
                await _add_item_links(feature, collection_id, item_id)

                cleaned_features.append(feature)
        else:
            for feature in collection.get("features") or []:
                await _add_item_links(feature)
                cleaned_features.append(feature)

        collection["features"] = cleaned_features
        collection["links"] = await PagingLinks(
            request=request,
            next=next_token,
            prev=prev_token,
        ).get_links()

        return collection

    async def item_collection(
        self,
        collection_id: str,
        request: Request,
        bbox: Optional[BBox] = None,
        datetime: Optional[str] = None,
        limit: Optional[int] = None,
        # Extensions
        query: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[str] = None,
        filter_expr: Optional[str] = None,
        filter_lang: Optional[str] = None,
        token: Optional[str] = None,
        **kwargs,
    ) -> ItemCollection:
        """Get all items from a specific collection.

        Called with `GET /collections/{collection_id}/items`

        Args:
            collection_id: id of the collection.
            request: incoming request.
            bbox: bounding box filter.
            datetime: datetime filter.
            limit: number of items to return.
            query: URL-encoded JSON string for the query extension.
            fields: field names, optionally prefixed with ``+`` or ``-``.
            sortby: sort fields, optionally prefixed with ``+`` or ``-``.
            filter_expr: filter expression (cql2-text or JSON).
            filter_lang: language of ``filter_expr``.
            token: pagination token.

        Returns:
            An ItemCollection, or a raw ``JSONResponse`` when fields are
            included/excluded (items may then not validate as STAC Items).

        Raises:
            NotFoundError: if the collection does not exist.
            HTTPException: status 400 if the parameters fail model validation.
        """
        # If collection does not exist, NotFoundError will be raised
        await self.get_collection(collection_id, request=request)

        base_args = {
            "collections": [collection_id],
            "bbox": bbox,
            "datetime": datetime,
            "limit": limit,
            "token": token,
            "query": orjson.loads(unquote_plus(query)) if query else query,
        }

        clean = self._clean_search_args(
            base_args=base_args,
            filter_query=filter_expr,
            filter_lang=filter_lang,
            fields=fields,
            sortby=sortby,
        )

        try:
            search_request = self.pgstac_search_model(**clean)
        except ValidationError as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid parameters provided {e}"
            ) from e

        item_collection = await self._search_base(search_request, request=request)

        links = await ItemCollectionLinks(
            collection_id=collection_id, request=request
        ).get_links(extra_links=item_collection["links"])
        item_collection["links"] = links

        # If we have the `fields` extension enabled
        # we need to avoid Pydantic validation because the
        # Items might not be a valid STAC Item objects
        if fields := getattr(search_request, "fields", None):
            if fields.include or fields.exclude:
                return JSONResponse(item_collection)  # type: ignore

        return ItemCollection(**item_collection)

    async def get_item(
        self, item_id: str, collection_id: str, request: Request, **kwargs
    ) -> Item:
        """Get item by id.

        Called with `GET /collections/{collection_id}/items/{item_id}`.

        Args:
            item_id: ID of the item.
            collection_id: ID of the collection the item is in.
            request: incoming request.

        Returns:
            Item.

        Raises:
            NotFoundError: if the collection or the item does not exist.
        """
        # If collection does not exist, NotFoundError will be raised
        await self.get_collection(collection_id, request=request)

        search_request = self.pgstac_search_model(
            ids=[item_id], collections=[collection_id], limit=1
        )
        item_collection = await self._search_base(search_request, request=request)
        if not item_collection["features"]:
            raise NotFoundError(
                f"Item {item_id} in Collection {collection_id} does not exist."
            )

        return Item(**item_collection["features"][0])

    async def post_search(
        self, search_request: PgstacSearch, request: Request, **kwargs
    ) -> ItemCollection:
        """Cross catalog search (POST).

        Called with `POST /search`.

        Args:
            search_request: search request parameters.
            request: incoming request.

        Returns:
            ItemCollection containing items which match the search criteria,
            or a raw ``JSONResponse`` when fields are included/excluded.
        """
        item_collection = await self._search_base(search_request, request=request)

        # Attach the search links before the `fields` short-circuit so both
        # response shapes carry them, as in `get_search`.
        links = await SearchLinks(request=request).get_links(
            extra_links=item_collection["links"]
        )
        item_collection["links"] = links

        # If we have the `fields` extension enabled
        # we need to avoid Pydantic validation because the
        # Items might not be a valid STAC Item objects
        if fields := getattr(search_request, "fields", None):
            if fields.include or fields.exclude:
                return JSONResponse(item_collection)  # type: ignore

        return ItemCollection(**item_collection)

    async def get_search(
        self,
        request: Request,
        collections: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        bbox: Optional[BBox] = None,
        intersects: Optional[str] = None,
        datetime: Optional[str] = None,
        limit: Optional[int] = None,
        # Extensions
        query: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[str] = None,
        filter_expr: Optional[str] = None,
        filter_lang: Optional[str] = None,
        token: Optional[str] = None,
        **kwargs,
    ) -> ItemCollection:
        """Cross catalog search (GET).

        Called with `GET /search`.

        Args:
            request: incoming request.
            collections: collection ids to search in.
            ids: item ids to return.
            bbox: bounding box filter.
            intersects: URL-encoded JSON geometry.
            datetime: datetime filter.
            limit: number of items to return.
            query: URL-encoded JSON string for the query extension.
            fields: field names, optionally prefixed with ``+`` or ``-``.
            sortby: sort fields, optionally prefixed with ``+`` or ``-``.
            filter_expr: filter expression (cql2-text or JSON).
            filter_lang: language of ``filter_expr``.
            token: pagination token.

        Returns:
            ItemCollection containing items which match the search criteria,
            or a raw ``JSONResponse`` when fields are included/excluded.

        Raises:
            HTTPException: status 400 if the parameters fail model validation.
        """
        # Parse request parameters
        base_args = {
            "collections": collections,
            "ids": ids,
            "bbox": bbox,
            "limit": limit,
            "token": token,
            "query": orjson.loads(unquote_plus(query)) if query else query,
        }

        clean = self._clean_search_args(
            base_args=base_args,
            intersects=intersects,
            datetime=datetime,
            fields=fields,
            sortby=sortby,
            filter_query=filter_expr,
            filter_lang=filter_lang,
        )

        try:
            search_request = self.pgstac_search_model(**clean)
        except ValidationError as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid parameters provided {e}"
            ) from e

        item_collection = await self._search_base(search_request, request=request)

        links = await SearchLinks(request=request).get_links(
            extra_links=item_collection["links"]
        )
        item_collection["links"] = links

        # If we have the `fields` extension enabled
        # we need to avoid Pydantic validation because the
        # Items might not be a valid STAC Item objects
        if fields := getattr(search_request, "fields", None):
            if fields.include or fields.exclude:
                return JSONResponse(item_collection)  # type: ignore

        return ItemCollection(**item_collection)

    def _clean_search_args(  # noqa: C901
        self,
        base_args: Dict[str, Any],
        intersects: Optional[str] = None,
        datetime: Optional[str] = None,
        fields: Optional[List[str]] = None,
        sortby: Optional[Union[str, List[str]]] = None,
        filter_query: Optional[str] = None,
        filter_lang: Optional[str] = None,
        q: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Clean up search arguments to match format expected by pgstac

        Note that ``base_args`` is updated in place with the parsed values;
        the returned dict is a new mapping without ``None``/``[]`` entries.

        Args:
            base_args: arguments that need no conversion.
            intersects: URL-encoded JSON geometry.
            datetime: datetime filter, passed through unchanged.
            fields: field names; a leading ``-`` excludes, a leading ``+`` or
                no prefix includes. Empty entries are ignored.
            sortby: sort fields, each optionally prefixed with ``+`` or ``-``
                (``-`` means descending). A single string is split on commas.
            filter_query: filter expression; parsed as cql2-text when
                ``filter_lang`` is ``"cql2-text"``, as JSON otherwise.
            filter_lang: language of ``filter_query``.
            q: free-text terms, joined with ``" OR "``.

        Returns:
            The search arguments with ``None`` and empty-list values removed.
        """
        if filter_query:
            if filter_lang == "cql2-text":
                # Convert text filters to cql2-json before sending them on.
                e = Expr(filter_query)
                base_args["filter"] = e.to_json()
                base_args["filter_lang"] = "cql2-json"
            else:
                base_args["filter"] = orjson.loads(filter_query)
                base_args["filter_lang"] = filter_lang

        if datetime:
            base_args["datetime"] = datetime

        if intersects:
            base_args["intersects"] = orjson.loads(unquote_plus(intersects))

        if sortby:
            # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
            # A bare string would otherwise be iterated character by character.
            sort_terms = sortby.split(",") if isinstance(sortby, str) else sortby
            sort_param = []
            for sort in sort_terms:
                sortparts = re.match(r"^([+-]?)(.*)$", sort)
                if sortparts:
                    sort_param.append(
                        {
                            "field": sortparts.group(2).strip(),
                            "direction": "desc" if sortparts.group(1) == "-" else "asc",
                        }
                    )
            base_args["sortby"] = sort_param

        if fields:
            includes = set()
            excludes = set()
            for field in fields:
                if not field:
                    # Empty entry: nothing to include or exclude.
                    continue
                if field[0] == "-":
                    excludes.add(field[1:])
                elif field[0] == "+":
                    includes.add(field[1:])
                else:
                    includes.add(field)

            base_args["fields"] = {"include": includes, "exclude": excludes}

        if q:
            base_args["q"] = " OR ".join(q)

        # Remove None values from dict
        return {k: v for k, v in base_args.items() if v is not None and v != []}
|
|
|
|
|
|
async def health_check(request: Request) -> Union[Dict, JSONResponse]:
    """PgSTAC HealthCheck.

    Args:
        request: incoming request; its ``app.state`` is expected to expose
            ``get_connection`` once the application lifespan has run.

    Returns:
        A plain dict with the PgSTAC version when the database answers, or a
        ``JSONResponse`` with status 503 when the lifespan was not run or the
        version query fails.
    """
    app_state = request.app.state

    # No connection factory on the app state: report everything as down.
    if not hasattr(app_state, "get_connection"):
        lifespan_down = {
            "status": "DOWN",
            "message": "application lifespan wasn't run",
        }
        pgstac_down = {
            "status": "DOWN",
            "message": "Could not connect to database",
        }
        return JSONResponse(
            status_code=503,
            content={
                "status": "DOWN",
                "lifespan": lifespan_down,
                "pgstac": pgstac_down,
            },
        )

    try:
        async with app_state.get_connection(request, "r") as conn:
            sql, params = render(
                """SELECT pgstac.get_version();""",
            )
            pgstac_version = await conn.fetchval(sql, *params)
    except Exception as exc:
        # Any failure while querying is reported as an unavailable database;
        # the lifespan itself did run, so it stays "UP".
        return JSONResponse(
            status_code=503,
            content={
                "status": "DOWN",
                "lifespan": {
                    "status": "UP",
                },
                "pgstac": {
                    "status": "DOWN",
                    "message": str(exc),
                },
            },
        )

    return {
        "status": "UP",
        "lifespan": {
            "status": "UP",
        },
        "pgstac": {
            "status": "UP",
            "pgstac_version": pgstac_version,
        },
    }
|