run api in another thread when using starlette (#1533)

This commit is contained in:
Ricardo Garcia Silva
2024-02-02 19:34:09 +00:00
committed by GitHub
parent ffd33fafc1
commit 88ae474627
2 changed files with 113 additions and 66 deletions
+1 -1
View File
@@ -370,7 +370,7 @@ class APIRequest:
# has been implemented, with_data() can become async too
loop = asyncio.get_event_loop()
api_req._data = asyncio.run_coroutine_threadsafe(
request.body(), loop)
request.body(), loop).result(1)
return api_req
@staticmethod
+112 -65
View File
@@ -32,8 +32,9 @@
# =================================================================
""" Starlette module providing the route paths to the api"""
import asyncio
import os
from typing import Union
from typing import Callable, Union
from pathlib import Path
import click
@@ -79,16 +80,42 @@ API_RULES = get_api_rules(CONFIG)
api_ = API(CONFIG, OPENAPI)
def get_response(result: tuple) -> Union[Response, JSONResponse, HTMLResponse]:
def call_api_threadsafe(
loop: asyncio.AbstractEventLoop, api_call: Callable, *args
) -> tuple:
"""
The api call needs a running loop. This method is meant to be called
from a thread that has no loop running.
:param loop: The loop to use.
:param api_call: The API method to call.
:param args: Arguments to pass to the API method.
:returns: The api call result tuple.
"""
asyncio.set_event_loop(loop)
return api_call(*args)
async def get_response(
api_call,
*args,
) -> Union[Response, JSONResponse, HTMLResponse]:
"""
Creates a Starlette Response object and updates matching headers.
Runs the core api handler in a separate thread in order to avoid
blocking the main event loop.
:param result: The result of the API call.
This should be a tuple of (headers, status, content).
:returns: A Response instance.
"""
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None, call_api_threadsafe, loop, api_call, *args)
headers, status, content = result
if headers['Content-Type'] == 'text/html':
response = HTMLResponse(content=content, status_code=status)
@@ -111,7 +138,7 @@ async def landing_page(request: Request):
:returns: Starlette HTTP Response
"""
return get_response(api_.landing_page(request))
return await get_response(api_.landing_page, request)
async def openapi(request: Request):
@@ -122,7 +149,7 @@ async def openapi(request: Request):
:returns: Starlette HTTP Response
"""
return get_response(api_.openapi_(request))
return await get_response(api_.openapi_, request)
async def conformance(request: Request):
@@ -133,7 +160,7 @@ async def conformance(request: Request):
:returns: Starlette HTTP Response
"""
return get_response(api_.conformance(request))
return await get_response(api_.conformance, request)
async def get_tilematrix_set(request: Request, tileMatrixSetId=None):
@@ -146,7 +173,8 @@ async def get_tilematrix_set(request: Request, tileMatrixSetId=None):
if 'tileMatrixSetId' in request.path_params:
tileMatrixSetId = request.path_params['tileMatrixSetId']
return get_response(api_.tilematrixset(request, tileMatrixSetId))
return await get_response(
api_.tilematrixset, request, tileMatrixSetId)
async def get_tilematrix_sets(request: Request):
@@ -155,7 +183,7 @@ async def get_tilematrix_sets(request: Request):
:returns: HTTP response
"""
return get_response(api_.tilematrixsets(request))
return await get_response(api_.tilematrixsets, request)
async def collection_queryables(request: Request, collection_id=None):
@@ -169,7 +197,8 @@ async def collection_queryables(request: Request, collection_id=None):
"""
if 'collection_id' in request.path_params:
collection_id = request.path_params['collection_id']
return get_response(api_.get_collection_queryables(request, collection_id))
return await get_response(
api_.get_collection_queryables, request, collection_id)
async def get_collection_tiles(request: Request, collection_id=None):
@@ -183,8 +212,8 @@ async def get_collection_tiles(request: Request, collection_id=None):
"""
if 'collection_id' in request.path_params:
collection_id = request.path_params['collection_id']
return get_response(api_.get_collection_tiles(
request, collection_id))
return await get_response(
api_.get_collection_tiles, request, collection_id)
async def get_collection_tiles_metadata(request: Request, collection_id=None,
@@ -201,8 +230,10 @@ async def get_collection_tiles_metadata(request: Request, collection_id=None,
collection_id = request.path_params['collection_id']
if 'tileMatrixSetId' in request.path_params:
tileMatrixSetId = request.path_params['tileMatrixSetId']
return get_response(api_.get_collection_tiles_metadata(
request, collection_id, tileMatrixSetId))
return await get_response(
api_.get_collection_tiles_metadata, request,
collection_id, tileMatrixSetId
)
async def get_collection_items_tiles(request: Request, collection_id=None,
@@ -230,9 +261,10 @@ async def get_collection_items_tiles(request: Request, collection_id=None,
tileRow = request.path_params['tileRow']
if 'tileCol' in request.path_params:
tileCol = request.path_params['tileCol']
return get_response(api_.get_collection_tiles_data(
request, collection_id, tileMatrixSetId,
tile_matrix, tileRow, tileCol))
return await get_response(
api_.get_collection_tiles_data, request, collection_id,
tileMatrixSetId, tile_matrix, tileRow, tileCol
)
async def collection_items(request: Request, collection_id=None, item_id=None):
@@ -252,38 +284,45 @@ async def collection_items(request: Request, collection_id=None, item_id=None):
item_id = request.path_params['item_id']
if item_id is None:
if request.method == 'GET': # list items
return get_response(
api_.get_collection_items(
request, collection_id))
return await get_response(
api_.get_collection_items, request, collection_id)
elif request.method == 'POST': # filter or manage items
content_type = request.headers.get('content-type')
if content_type is not None:
if content_type == 'application/geo+json':
return get_response(
api_.manage_collection_item(request, 'create',
collection_id))
return await get_response(
api_.manage_collection_item, request,
'create', collection_id)
else:
return get_response(
api_.post_collection_items(request, collection_id))
return await get_response(
api_.post_collection_items,
request,
collection_id
)
elif request.method == 'OPTIONS':
return get_response(
api_.manage_collection_item(request, 'options', collection_id))
return await get_response(
api_.manage_collection_item, request,
'options', collection_id
)
elif request.method == 'DELETE':
return get_response(
api_.manage_collection_item(request, 'delete',
collection_id, item_id))
return await get_response(
api_.manage_collection_item, request, 'delete',
collection_id, item_id
)
elif request.method == 'PUT':
return get_response(
api_.manage_collection_item(request, 'update',
collection_id, item_id))
return await get_response(
api_.manage_collection_item, request, 'update',
collection_id, item_id
)
elif request.method == 'OPTIONS':
return get_response(
api_.manage_collection_item(request, 'options',
collection_id, item_id))
return await get_response(
api_.manage_collection_item, request, 'options',
collection_id, item_id
)
else:
return get_response(api_.get_collection_item(
request, collection_id, item_id))
return await get_response(
api_.get_collection_item, request, collection_id, item_id)
async def collection_coverage(request: Request, collection_id=None):
@@ -298,7 +337,8 @@ async def collection_coverage(request: Request, collection_id=None):
if 'collection_id' in request.path_params:
collection_id = request.path_params['collection_id']
return get_response(api_.get_collection_coverage(request, collection_id))
return await get_response(
api_.get_collection_coverage, request, collection_id)
async def collection_coverage_domainset(request: Request, collection_id=None):
@@ -313,8 +353,8 @@ async def collection_coverage_domainset(request: Request, collection_id=None):
if 'collection_id' in request.path_params:
collection_id = request.path_params['collection_id']
return get_response(api_.get_collection_coverage_domainset(
request, collection_id))
return await get_response(
api_.get_collection_coverage_domainset, request, collection_id)
async def collection_coverage_rangetype(request: Request, collection_id=None):
@@ -330,8 +370,8 @@ async def collection_coverage_rangetype(request: Request, collection_id=None):
if 'collection_id' in request.path_params:
collection_id = request.path_params['collection_id']
return get_response(api_.get_collection_coverage_rangetype(
request, collection_id))
return await get_response(
api_.get_collection_coverage_rangetype, request, collection_id)
async def collection_map(request: Request, collection_id, style_id=None):
@@ -349,8 +389,8 @@ async def collection_map(request: Request, collection_id, style_id=None):
if 'style_id' in request.path_params:
style_id = request.path_params['style_id']
return get_response(api_.get_collection_map(
request, collection_id, style_id))
return await get_response(
api_.get_collection_map, request, collection_id, style_id)
async def get_processes(request: Request, process_id=None):
@@ -365,7 +405,7 @@ async def get_processes(request: Request, process_id=None):
if 'process_id' in request.path_params:
process_id = request.path_params['process_id']
return get_response(api_.describe_processes(request, process_id))
return await get_response(api_.describe_processes, request, process_id)
async def get_jobs(request: Request, job_id=None):
@@ -382,12 +422,12 @@ async def get_jobs(request: Request, job_id=None):
job_id = request.path_params['job_id']
if job_id is None: # list of submit job
return get_response(api_.get_jobs(request))
return await get_response(api_.get_jobs, request)
else: # get or delete job
if request.method == 'DELETE':
return get_response(api_.delete_job(job_id))
return await get_response(api_.delete_job, job_id)
else: # Return status of a specific job
return get_response(api_.get_jobs(request, job_id))
return await get_response(api_.get_jobs, request, job_id)
async def execute_process_jobs(request: Request, process_id=None):
@@ -403,7 +443,7 @@ async def execute_process_jobs(request: Request, process_id=None):
if 'process_id' in request.path_params:
process_id = request.path_params['process_id']
return get_response(api_.execute_process(request, process_id))
return await get_response(api_.execute_process, request, process_id)
async def get_job_result(request: Request, job_id=None):
@@ -419,7 +459,7 @@ async def get_job_result(request: Request, job_id=None):
if 'job_id' in request.path_params:
job_id = request.path_params['job_id']
return get_response(api_.get_job_result(request, job_id))
return await get_response(api_.get_job_result, request, job_id)
async def get_job_result_resource(request: Request,
@@ -439,8 +479,8 @@ async def get_job_result_resource(request: Request,
if 'resource' in request.path_params:
resource = request.path_params['resource']
return get_response(api_.get_job_result_resource(
request, job_id, resource))
return await get_response(
api_.get_job_result_resource, request, job_id, resource)
async def get_collection_edr_query(request: Request, collection_id=None, instance_id=None): # noqa
@@ -460,8 +500,10 @@ async def get_collection_edr_query(request: Request, collection_id=None, instanc
instance_id = request.path_params['instance_id']
query_type = request["path"].split('/')[-1] # noqa
return get_response(api_.get_collection_edr_query(request, collection_id,
instance_id, query_type))
return await get_response(
api_.get_collection_edr_query, request, collection_id,
instance_id, query_type
)
async def collections(request: Request, collection_id=None):
@@ -475,7 +517,8 @@ async def collections(request: Request, collection_id=None):
"""
if 'collection_id' in request.path_params:
collection_id = request.path_params['collection_id']
return get_response(api_.describe_collections(request, collection_id))
return await get_response(
api_.describe_collections, request, collection_id)
async def stac_catalog_root(request: Request):
@@ -486,7 +529,7 @@ async def stac_catalog_root(request: Request):
:returns: Starlette HTTP response
"""
return get_response(api_.get_stac_root(request))
return await get_response(api_.get_stac_root, request)
async def stac_catalog_path(request: Request):
@@ -498,7 +541,7 @@ async def stac_catalog_path(request: Request):
:returns: Starlette HTTP response
"""
path = request.path_params["path"]
return get_response(api_.get_stac_path(request, path))
return await get_response(api_.get_stac_path, request, path)
async def admin_config(request: Request):
@@ -509,11 +552,11 @@ async def admin_config(request: Request):
"""
if request.method == 'GET':
return get_response(ADMIN.get_config(request))
return await get_response(ADMIN.get_config, request)
elif request.method == 'PUT':
return get_response(ADMIN.put_config(request))
return await get_response(ADMIN.put_config, request)
elif request.method == 'PATCH':
return get_response(ADMIN.patch_config(request))
return await get_response(ADMIN.patch_config, request)
async def admin_config_resources(request: Request):
@@ -524,9 +567,9 @@ async def admin_config_resources(request: Request):
"""
if request.method == 'GET':
return get_response(ADMIN.get_resources(request))
return await get_response(ADMIN.get_resources, request)
elif request.method == 'POST':
return get_response(ADMIN.put_resource(request))
return await get_response(ADMIN.put_resource, request)
async def admin_config_resource(request: Request, resource_id: str):
@@ -542,13 +585,17 @@ async def admin_config_resource(request: Request, resource_id: str):
resource_id = request.path_params['resource_id']
if request.method == 'GET':
return get_response(ADMIN.get_resource(request, resource_id))
return await get_response(
ADMIN.get_resource, request, resource_id)
elif request.method == 'PUT':
return get_response(ADMIN.put_resource(request, resource_id))
return await get_response(
ADMIN.put_resource, request, resource_id)
elif request.method == 'PATCH':
return get_response(ADMIN.patch_resource(request, resource_id))
return await get_response(
ADMIN.patch_resource, request, resource_id)
elif request.method == 'DELETE':
return get_response(ADMIN.delete_resource(request, resource_id))
return await get_response(
ADMIN.delete_resource, request, resource_id)
class ApiRulesMiddleware: