"""FastAPI application using PGStac. Enables the extensions specified as a comma-delimited list in the ENABLED_EXTENSIONS environment variable (e.g. `transactions,sort,query`). If the variable is not set, enables all extensions. """ import os from contextlib import asynccontextmanager from brotli_asgi import BrotliMiddleware from fastapi import FastAPI from stac_fastapi.api.app import StacApi from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware from stac_fastapi.api.models import ( EmptyRequest, ItemCollectionUri, JSONResponse, create_get_request_model, create_post_request_model, create_request_model, ) from stac_fastapi.extensions.core import ( CollectionSearchExtension, CollectionSearchFilterExtension, FieldsExtension, FreeTextExtension, ItemCollectionFilterExtension, OffsetPaginationExtension, SearchFilterExtension, SortExtension, TokenPaginationExtension, TransactionExtension, ) from stac_fastapi.extensions.core.fields import FieldsConformanceClasses from stac_fastapi.extensions.core.free_text import FreeTextConformanceClasses from stac_fastapi.extensions.core.query import QueryConformanceClasses from stac_fastapi.extensions.core.sort import SortConformanceClasses from stac_fastapi.extensions.third_party import BulkTransactionExtension from starlette.middleware import Middleware from stac_fastapi.pgstac.config import Settings from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from stac_fastapi.pgstac.extensions import QueryExtension from stac_fastapi.pgstac.extensions.filter import FiltersClient from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch settings = Settings() # search extensions search_extensions_map = { "query": QueryExtension(), "sort": SortExtension(), "fields": FieldsExtension(), "filter": SearchFilterExtension(client=FiltersClient()), "pagination": TokenPaginationExtension(), } # collection_search extensions cs_extensions_map = { "query": QueryExtension(conformance_classes=[QueryConformanceClasses.COLLECTIONS]), "sort": SortExtension(conformance_classes=[SortConformanceClasses.COLLECTIONS]), "fields": FieldsExtension(conformance_classes=[FieldsConformanceClasses.COLLECTIONS]), "filter": CollectionSearchFilterExtension(client=FiltersClient()), "free_text": FreeTextExtension( conformance_classes=[FreeTextConformanceClasses.COLLECTIONS], ), "pagination": OffsetPaginationExtension(), } # item_collection extensions itm_col_extensions_map = { "query": QueryExtension( conformance_classes=[QueryConformanceClasses.ITEMS], ), "sort": SortExtension( conformance_classes=[SortConformanceClasses.ITEMS], ), "fields": FieldsExtension(conformance_classes=[FieldsConformanceClasses.ITEMS]), "filter": ItemCollectionFilterExtension(client=FiltersClient()), "pagination": TokenPaginationExtension(), } enabled_extensions = { *search_extensions_map.keys(), *cs_extensions_map.keys(), *itm_col_extensions_map.keys(), "collection_search", } if ext := os.environ.get("ENABLED_EXTENSIONS"): enabled_extensions = set(ext.split(",")) application_extensions = [] with_transactions = os.environ.get("ENABLE_TRANSACTIONS_EXTENSIONS", "").lower() in [ "yes", "true", "1", ] if with_transactions: application_extensions.append( TransactionExtension( client=TransactionsClient(), settings=settings, response_class=JSONResponse, ), ) application_extensions.append( BulkTransactionExtension(client=BulkTransactionsClient()), ) # /search models search_extensions = [ extension for key, extension in search_extensions_map.items() if key in enabled_extensions ] post_request_model = create_post_request_model(search_extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(search_extensions) application_extensions.extend(search_extensions) # /collections/{collectionId}/items model items_get_request_model = ItemCollectionUri itm_col_extensions = [ extension for key, extension in itm_col_extensions_map.items() if key in enabled_extensions ] if itm_col_extensions: items_get_request_model = create_request_model( model_name="ItemCollectionUri", base_model=ItemCollectionUri, extensions=itm_col_extensions, request_type="GET", ) application_extensions.extend(itm_col_extensions) # /collections model collections_get_request_model = EmptyRequest if "collection_search" in enabled_extensions: cs_extensions = [ extension for key, extension in cs_extensions_map.items() if key in enabled_extensions ] collection_search_extension = CollectionSearchExtension.from_extensions(cs_extensions) collections_get_request_model = collection_search_extension.GET application_extensions.append(collection_search_extension) @asynccontextmanager async def lifespan(app: FastAPI): """FastAPI Lifespan.""" await connect_to_db(app, add_write_connection_pool=with_transactions) yield await close_db_connection(app) api = StacApi( app=FastAPI( openapi_url=settings.openapi_url, docs_url=settings.docs_url, redoc_url=None, root_path=settings.root_path, title=settings.stac_fastapi_title, version=settings.stac_fastapi_version, description=settings.stac_fastapi_description, lifespan=lifespan, ), settings=settings, extensions=application_extensions, client=CoreCrudClient(pgstac_search_model=post_request_model), response_class=JSONResponse, items_get_request_model=items_get_request_model, search_get_request_model=get_request_model, search_post_request_model=post_request_model, collections_get_request_model=collections_get_request_model, middlewares=[ Middleware(BrotliMiddleware), Middleware(ProxyHeaderMiddleware), Middleware( CORSMiddleware, allow_origins=settings.cors_origins, allow_methods=settings.cors_methods, ), ], health_check=health_check, ) app = api.app def run(): """Run app from command line using uvicorn if available.""" try: import uvicorn uvicorn.run( "stac_fastapi.pgstac.app:app", host=settings.app_host, port=settings.app_port, log_level="info", reload=settings.reload, root_path=os.getenv("UVICORN_ROOT_PATH", ""), ) except ImportError as e: raise RuntimeError("Uvicorn must be installed in order to use command") from e if __name__ == "__main__": run() def create_handler(app): """Create a handler to use with AWS Lambda if mangum available.""" try: from mangum import Mangum return Mangum(app) except ImportError: return None handler = create_handler(app)