diff --git a/pygeoapi/api.py b/pygeoapi/api.py index edd7135..a394149 100644 --- a/pygeoapi/api.py +++ b/pygeoapi/api.py @@ -370,7 +370,7 @@ class APIRequest: # has been implemented, with_data() can become async too loop = asyncio.get_event_loop() api_req._data = asyncio.run_coroutine_threadsafe( - request.body(), loop) + request.body(), loop).result(1) return api_req @staticmethod diff --git a/pygeoapi/starlette_app.py b/pygeoapi/starlette_app.py index 4251747..29db43e 100644 --- a/pygeoapi/starlette_app.py +++ b/pygeoapi/starlette_app.py @@ -32,8 +32,9 @@ # ================================================================= """ Starlette module providing the route paths to the api""" +import asyncio import os -from typing import Union +from typing import Callable, Union from pathlib import Path import click @@ -79,16 +80,42 @@ API_RULES = get_api_rules(CONFIG) api_ = API(CONFIG, OPENAPI) -def get_response(result: tuple) -> Union[Response, JSONResponse, HTMLResponse]: +def call_api_threadsafe( + loop: asyncio.AbstractEventLoop, api_call: Callable, *args +) -> tuple: + """ + The api call needs a running loop. This method is meant to be called + from a thread that has no loop running. + + :param loop: The loop to use. + :param api_call: The API method to call. + :param args: Arguments to pass to the API method. + :returns: The api call result tuple. + """ + asyncio.set_event_loop(loop) + return api_call(*args) + + +async def get_response( + api_call, + *args, +) -> Union[Response, JSONResponse, HTMLResponse]: """ Creates a Starlette Response object and updates matching headers. + Runs the core api handler in a separate thread in order to avoid + blocking the main event loop. + :param result: The result of the API call. This should be a tuple of (headers, status, content). :returns: A Response instance. """ + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + None, call_api_threadsafe, loop, api_call, *args) + headers, status, content = result if headers['Content-Type'] == 'text/html': response = HTMLResponse(content=content, status_code=status) @@ -111,7 +138,7 @@ async def landing_page(request: Request): :returns: Starlette HTTP Response """ - return get_response(api_.landing_page(request)) + return await get_response(api_.landing_page, request) async def openapi(request: Request): @@ -122,7 +149,7 @@ async def openapi(request: Request): :returns: Starlette HTTP Response """ - return get_response(api_.openapi_(request)) + return await get_response(api_.openapi_, request) async def conformance(request: Request): @@ -133,7 +160,7 @@ async def conformance(request: Request): :returns: Starlette HTTP Response """ - return get_response(api_.conformance(request)) + return await get_response(api_.conformance, request) async def get_tilematrix_set(request: Request, tileMatrixSetId=None): @@ -146,7 +173,8 @@ async def get_tilematrix_set(request: Request, tileMatrixSetId=None): if 'tileMatrixSetId' in request.path_params: tileMatrixSetId = request.path_params['tileMatrixSetId'] - return get_response(api_.tilematrixset(request, tileMatrixSetId)) + return await get_response( + api_.tilematrixset, request, tileMatrixSetId) async def get_tilematrix_sets(request: Request): @@ -155,7 +183,7 @@ async def get_tilematrix_sets(request: Request): :returns: HTTP response """ - return get_response(api_.tilematrixsets(request)) + return await get_response(api_.tilematrixsets, request) async def collection_queryables(request: Request, collection_id=None): @@ -169,7 +197,8 @@ async def collection_queryables(request: Request, collection_id=None): """ if 'collection_id' in request.path_params: collection_id = request.path_params['collection_id'] - return get_response(api_.get_collection_queryables(request, collection_id)) + return await get_response( + api_.get_collection_queryables, request, collection_id) async def get_collection_tiles(request: Request, collection_id=None): @@ -183,8 +212,8 @@ async def get_collection_tiles(request: Request, collection_id=None): """ if 'collection_id' in request.path_params: collection_id = request.path_params['collection_id'] - return get_response(api_.get_collection_tiles( - request, collection_id)) + return await get_response( + api_.get_collection_tiles, request, collection_id) async def get_collection_tiles_metadata(request: Request, collection_id=None, @@ -201,8 +230,10 @@ async def get_collection_tiles_metadata(request: Request, collection_id=None, collection_id = request.path_params['collection_id'] if 'tileMatrixSetId' in request.path_params: tileMatrixSetId = request.path_params['tileMatrixSetId'] - return get_response(api_.get_collection_tiles_metadata( - request, collection_id, tileMatrixSetId)) + return await get_response( + api_.get_collection_tiles_metadata, request, + collection_id, tileMatrixSetId + ) async def get_collection_items_tiles(request: Request, collection_id=None, @@ -230,9 +261,10 @@ async def get_collection_items_tiles(request: Request, collection_id=None, tileRow = request.path_params['tileRow'] if 'tileCol' in request.path_params: tileCol = request.path_params['tileCol'] - return get_response(api_.get_collection_tiles_data( - request, collection_id, tileMatrixSetId, - tile_matrix, tileRow, tileCol)) + return await get_response( + api_.get_collection_tiles_data, request, collection_id, + tileMatrixSetId, tile_matrix, tileRow, tileCol + ) async def collection_items(request: Request, collection_id=None, item_id=None): @@ -252,38 +284,45 @@ async def collection_items(request: Request, collection_id=None, item_id=None): item_id = request.path_params['item_id'] if item_id is None: if request.method == 'GET': # list items - return get_response( - api_.get_collection_items( - request, collection_id)) + return await get_response( + api_.get_collection_items, request, collection_id) elif request.method == 'POST': # filter or manage items content_type = request.headers.get('content-type') if content_type is not None: if content_type == 'application/geo+json': - return get_response( - api_.manage_collection_item(request, 'create', - collection_id)) + return await get_response( + api_.manage_collection_item, request, + 'create', collection_id) else: - return get_response( - api_.post_collection_items(request, collection_id)) + return await get_response( + api_.post_collection_items, + request, + collection_id + ) elif request.method == 'OPTIONS': - return get_response( - api_.manage_collection_item(request, 'options', collection_id)) + return await get_response( + api_.manage_collection_item, request, + 'options', collection_id + ) elif request.method == 'DELETE': - return get_response( - api_.manage_collection_item(request, 'delete', - collection_id, item_id)) + return await get_response( + api_.manage_collection_item, request, 'delete', + collection_id, item_id + ) elif request.method == 'PUT': - return get_response( - api_.manage_collection_item(request, 'update', - collection_id, item_id)) + return await get_response( + api_.manage_collection_item, request, 'update', + collection_id, item_id + ) elif request.method == 'OPTIONS': - return get_response( - api_.manage_collection_item(request, 'options', - collection_id, item_id)) + return await get_response( + api_.manage_collection_item, request, 'options', + collection_id, item_id + ) else: - return get_response(api_.get_collection_item( - request, collection_id, item_id)) + return await get_response( + api_.get_collection_item, request, collection_id, item_id) async def collection_coverage(request: Request, collection_id=None): @@ -298,7 +337,8 @@ async def collection_coverage(request: Request, collection_id=None): if 'collection_id' in request.path_params: collection_id = request.path_params['collection_id'] - return get_response(api_.get_collection_coverage(request, collection_id)) + return await get_response( + api_.get_collection_coverage, request, collection_id) async def collection_coverage_domainset(request: Request, collection_id=None): @@ -313,8 +353,8 @@ async def collection_coverage_domainset(request: Request, collection_id=None): if 'collection_id' in request.path_params: collection_id = request.path_params['collection_id'] - return get_response(api_.get_collection_coverage_domainset( - request, collection_id)) + return await get_response( + api_.get_collection_coverage_domainset, request, collection_id) async def collection_coverage_rangetype(request: Request, collection_id=None): @@ -330,8 +370,8 @@ async def collection_coverage_rangetype(request: Request, collection_id=None): if 'collection_id' in request.path_params: collection_id = request.path_params['collection_id'] - return get_response(api_.get_collection_coverage_rangetype( - request, collection_id)) + return await get_response( + api_.get_collection_coverage_rangetype, request, collection_id) async def collection_map(request: Request, collection_id, style_id=None): @@ -349,8 +389,8 @@ async def collection_map(request: Request, collection_id, style_id=None): if 'style_id' in request.path_params: style_id = request.path_params['style_id'] - return get_response(api_.get_collection_map( - request, collection_id, style_id)) + return await get_response( + api_.get_collection_map, request, collection_id, style_id) async def get_processes(request: Request, process_id=None): @@ -365,7 +405,7 @@ async def get_processes(request: Request, process_id=None): if 'process_id' in request.path_params: process_id = request.path_params['process_id'] - return get_response(api_.describe_processes(request, process_id)) + return await get_response(api_.describe_processes, request, process_id) async def get_jobs(request: Request, job_id=None): @@ -382,12 +422,12 @@ async def get_jobs(request: Request, job_id=None): job_id = request.path_params['job_id'] if job_id is None: # list of submit job - return get_response(api_.get_jobs(request)) + return await get_response(api_.get_jobs, request) else: # get or delete job if request.method == 'DELETE': - return get_response(api_.delete_job(job_id)) + return await get_response(api_.delete_job, job_id) else: # Return status of a specific job - return get_response(api_.get_jobs(request, job_id)) + return await get_response(api_.get_jobs, request, job_id) async def execute_process_jobs(request: Request, process_id=None): @@ -403,7 +443,7 @@ async def execute_process_jobs(request: Request, process_id=None): if 'process_id' in request.path_params: process_id = request.path_params['process_id'] - return get_response(api_.execute_process(request, process_id)) + return await get_response(api_.execute_process, request, process_id) async def get_job_result(request: Request, job_id=None): @@ -419,7 +459,7 @@ async def get_job_result(request: Request, job_id=None): if 'job_id' in request.path_params: job_id = request.path_params['job_id'] - return get_response(api_.get_job_result(request, job_id)) + return await get_response(api_.get_job_result, request, job_id) async def get_job_result_resource(request: Request, @@ -439,8 +479,8 @@ async def get_job_result_resource(request: Request, if 'resource' in request.path_params: resource = request.path_params['resource'] - return get_response(api_.get_job_result_resource( - request, job_id, resource)) + return await get_response( + api_.get_job_result_resource, request, job_id, resource) async def get_collection_edr_query(request: Request, collection_id=None, instance_id=None): # noqa @@ -460,8 +500,10 @@ async def get_collection_edr_query(request: Request, collection_id=None, instanc instance_id = request.path_params['instance_id'] query_type = request["path"].split('/')[-1] # noqa - return get_response(api_.get_collection_edr_query(request, collection_id, - instance_id, query_type)) + return await get_response( + api_.get_collection_edr_query, request, collection_id, + instance_id, query_type + ) async def collections(request: Request, collection_id=None): @@ -475,7 +517,8 @@ async def collections(request: Request, collection_id=None): """ if 'collection_id' in request.path_params: collection_id = request.path_params['collection_id'] - return get_response(api_.describe_collections(request, collection_id)) + return await get_response( + api_.describe_collections, request, collection_id) async def stac_catalog_root(request: Request): @@ -486,7 +529,7 @@ async def stac_catalog_root(request: Request): :returns: Starlette HTTP response """ - return get_response(api_.get_stac_root(request)) + return await get_response(api_.get_stac_root, request) async def stac_catalog_path(request: Request): @@ -498,7 +541,7 @@ async def stac_catalog_path(request: Request): :returns: Starlette HTTP response """ path = request.path_params["path"] - return get_response(api_.get_stac_path(request, path)) + return await get_response(api_.get_stac_path, request, path) async def admin_config(request: Request): @@ -509,11 +552,11 @@ async def admin_config(request: Request): """ if request.method == 'GET': - return get_response(ADMIN.get_config(request)) + return await get_response(ADMIN.get_config, request) elif request.method == 'PUT': - return get_response(ADMIN.put_config(request)) + return await get_response(ADMIN.put_config, request) elif request.method == 'PATCH': - return get_response(ADMIN.patch_config(request)) + return await get_response(ADMIN.patch_config, request) async def admin_config_resources(request: Request): @@ -524,9 +567,9 @@ async def admin_config_resources(request: Request): """ if request.method == 'GET': - return get_response(ADMIN.get_resources(request)) + return await get_response(ADMIN.get_resources, request) elif request.method == 'POST': - return get_response(ADMIN.put_resource(request)) + return await get_response(ADMIN.put_resource, request) async def admin_config_resource(request: Request, resource_id: str): @@ -542,13 +585,17 @@ async def admin_config_resource(request: Request, resource_id: str): resource_id = request.path_params['resource_id'] if request.method == 'GET': - return get_response(ADMIN.get_resource(request, resource_id)) + return await get_response( + ADMIN.get_resource, request, resource_id) elif request.method == 'PUT': - return get_response(ADMIN.put_resource(request, resource_id)) + return await get_response( + ADMIN.put_resource, request, resource_id) elif request.method == 'PATCH': - return get_response(ADMIN.patch_resource(request, resource_id)) + return await get_response( + ADMIN.patch_resource, request, resource_id) elif request.method == 'DELETE': - return get_response(ADMIN.delete_resource(request, resource_id)) + return await get_response( + ADMIN.delete_resource, request, resource_id) class ApiRulesMiddleware: