run api in another thread when using starlette (#1533)
This commit is contained in:
committed by
GitHub
parent
ffd33fafc1
commit
88ae474627
+1
-1
@@ -370,7 +370,7 @@ class APIRequest:
|
||||
# has been implemented, with_data() can become async too
|
||||
loop = asyncio.get_event_loop()
|
||||
api_req._data = asyncio.run_coroutine_threadsafe(
|
||||
request.body(), loop)
|
||||
request.body(), loop).result(1)
|
||||
return api_req
|
||||
|
||||
@staticmethod
|
||||
|
||||
+112
-65
@@ -32,8 +32,9 @@
|
||||
# =================================================================
|
||||
""" Starlette module providing the route paths to the api"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Union
|
||||
from typing import Callable, Union
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
@@ -79,16 +80,42 @@ API_RULES = get_api_rules(CONFIG)
|
||||
api_ = API(CONFIG, OPENAPI)
|
||||
|
||||
|
||||
def get_response(result: tuple) -> Union[Response, JSONResponse, HTMLResponse]:
|
||||
def call_api_threadsafe(
|
||||
loop: asyncio.AbstractEventLoop, api_call: Callable, *args
|
||||
) -> tuple:
|
||||
"""
|
||||
The api call needs a running loop. This method is meant to be called
|
||||
from a thread that has no loop running.
|
||||
|
||||
:param loop: The loop to use.
|
||||
:param api_call: The API method to call.
|
||||
:param args: Arguments to pass to the API method.
|
||||
:returns: The api call result tuple.
|
||||
"""
|
||||
asyncio.set_event_loop(loop)
|
||||
return api_call(*args)
|
||||
|
||||
|
||||
async def get_response(
|
||||
api_call,
|
||||
*args,
|
||||
) -> Union[Response, JSONResponse, HTMLResponse]:
|
||||
"""
|
||||
Creates a Starlette Response object and updates matching headers.
|
||||
|
||||
Runs the core api handler in a separate thread in order to avoid
|
||||
blocking the main event loop.
|
||||
|
||||
:param result: The result of the API call.
|
||||
This should be a tuple of (headers, status, content).
|
||||
|
||||
:returns: A Response instance.
|
||||
"""
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, call_api_threadsafe, loop, api_call, *args)
|
||||
|
||||
headers, status, content = result
|
||||
if headers['Content-Type'] == 'text/html':
|
||||
response = HTMLResponse(content=content, status_code=status)
|
||||
@@ -111,7 +138,7 @@ async def landing_page(request: Request):
|
||||
|
||||
:returns: Starlette HTTP Response
|
||||
"""
|
||||
return get_response(api_.landing_page(request))
|
||||
return await get_response(api_.landing_page, request)
|
||||
|
||||
|
||||
async def openapi(request: Request):
|
||||
@@ -122,7 +149,7 @@ async def openapi(request: Request):
|
||||
|
||||
:returns: Starlette HTTP Response
|
||||
"""
|
||||
return get_response(api_.openapi_(request))
|
||||
return await get_response(api_.openapi_, request)
|
||||
|
||||
|
||||
async def conformance(request: Request):
|
||||
@@ -133,7 +160,7 @@ async def conformance(request: Request):
|
||||
|
||||
:returns: Starlette HTTP Response
|
||||
"""
|
||||
return get_response(api_.conformance(request))
|
||||
return await get_response(api_.conformance, request)
|
||||
|
||||
|
||||
async def get_tilematrix_set(request: Request, tileMatrixSetId=None):
|
||||
@@ -146,7 +173,8 @@ async def get_tilematrix_set(request: Request, tileMatrixSetId=None):
|
||||
if 'tileMatrixSetId' in request.path_params:
|
||||
tileMatrixSetId = request.path_params['tileMatrixSetId']
|
||||
|
||||
return get_response(api_.tilematrixset(request, tileMatrixSetId))
|
||||
return await get_response(
|
||||
api_.tilematrixset, request, tileMatrixSetId)
|
||||
|
||||
|
||||
async def get_tilematrix_sets(request: Request):
|
||||
@@ -155,7 +183,7 @@ async def get_tilematrix_sets(request: Request):
|
||||
|
||||
:returns: HTTP response
|
||||
"""
|
||||
return get_response(api_.tilematrixsets(request))
|
||||
return await get_response(api_.tilematrixsets, request)
|
||||
|
||||
|
||||
async def collection_queryables(request: Request, collection_id=None):
|
||||
@@ -169,7 +197,8 @@ async def collection_queryables(request: Request, collection_id=None):
|
||||
"""
|
||||
if 'collection_id' in request.path_params:
|
||||
collection_id = request.path_params['collection_id']
|
||||
return get_response(api_.get_collection_queryables(request, collection_id))
|
||||
return await get_response(
|
||||
api_.get_collection_queryables, request, collection_id)
|
||||
|
||||
|
||||
async def get_collection_tiles(request: Request, collection_id=None):
|
||||
@@ -183,8 +212,8 @@ async def get_collection_tiles(request: Request, collection_id=None):
|
||||
"""
|
||||
if 'collection_id' in request.path_params:
|
||||
collection_id = request.path_params['collection_id']
|
||||
return get_response(api_.get_collection_tiles(
|
||||
request, collection_id))
|
||||
return await get_response(
|
||||
api_.get_collection_tiles, request, collection_id)
|
||||
|
||||
|
||||
async def get_collection_tiles_metadata(request: Request, collection_id=None,
|
||||
@@ -201,8 +230,10 @@ async def get_collection_tiles_metadata(request: Request, collection_id=None,
|
||||
collection_id = request.path_params['collection_id']
|
||||
if 'tileMatrixSetId' in request.path_params:
|
||||
tileMatrixSetId = request.path_params['tileMatrixSetId']
|
||||
return get_response(api_.get_collection_tiles_metadata(
|
||||
request, collection_id, tileMatrixSetId))
|
||||
return await get_response(
|
||||
api_.get_collection_tiles_metadata, request,
|
||||
collection_id, tileMatrixSetId
|
||||
)
|
||||
|
||||
|
||||
async def get_collection_items_tiles(request: Request, collection_id=None,
|
||||
@@ -230,9 +261,10 @@ async def get_collection_items_tiles(request: Request, collection_id=None,
|
||||
tileRow = request.path_params['tileRow']
|
||||
if 'tileCol' in request.path_params:
|
||||
tileCol = request.path_params['tileCol']
|
||||
return get_response(api_.get_collection_tiles_data(
|
||||
request, collection_id, tileMatrixSetId,
|
||||
tile_matrix, tileRow, tileCol))
|
||||
return await get_response(
|
||||
api_.get_collection_tiles_data, request, collection_id,
|
||||
tileMatrixSetId, tile_matrix, tileRow, tileCol
|
||||
)
|
||||
|
||||
|
||||
async def collection_items(request: Request, collection_id=None, item_id=None):
|
||||
@@ -252,38 +284,45 @@ async def collection_items(request: Request, collection_id=None, item_id=None):
|
||||
item_id = request.path_params['item_id']
|
||||
if item_id is None:
|
||||
if request.method == 'GET': # list items
|
||||
return get_response(
|
||||
api_.get_collection_items(
|
||||
request, collection_id))
|
||||
return await get_response(
|
||||
api_.get_collection_items, request, collection_id)
|
||||
elif request.method == 'POST': # filter or manage items
|
||||
content_type = request.headers.get('content-type')
|
||||
if content_type is not None:
|
||||
if content_type == 'application/geo+json':
|
||||
return get_response(
|
||||
api_.manage_collection_item(request, 'create',
|
||||
collection_id))
|
||||
return await get_response(
|
||||
api_.manage_collection_item, request,
|
||||
'create', collection_id)
|
||||
else:
|
||||
return get_response(
|
||||
api_.post_collection_items(request, collection_id))
|
||||
return await get_response(
|
||||
api_.post_collection_items,
|
||||
request,
|
||||
collection_id
|
||||
)
|
||||
elif request.method == 'OPTIONS':
|
||||
return get_response(
|
||||
api_.manage_collection_item(request, 'options', collection_id))
|
||||
return await get_response(
|
||||
api_.manage_collection_item, request,
|
||||
'options', collection_id
|
||||
)
|
||||
|
||||
elif request.method == 'DELETE':
|
||||
return get_response(
|
||||
api_.manage_collection_item(request, 'delete',
|
||||
collection_id, item_id))
|
||||
return await get_response(
|
||||
api_.manage_collection_item, request, 'delete',
|
||||
collection_id, item_id
|
||||
)
|
||||
elif request.method == 'PUT':
|
||||
return get_response(
|
||||
api_.manage_collection_item(request, 'update',
|
||||
collection_id, item_id))
|
||||
return await get_response(
|
||||
api_.manage_collection_item, request, 'update',
|
||||
collection_id, item_id
|
||||
)
|
||||
elif request.method == 'OPTIONS':
|
||||
return get_response(
|
||||
api_.manage_collection_item(request, 'options',
|
||||
collection_id, item_id))
|
||||
return await get_response(
|
||||
api_.manage_collection_item, request, 'options',
|
||||
collection_id, item_id
|
||||
)
|
||||
else:
|
||||
return get_response(api_.get_collection_item(
|
||||
request, collection_id, item_id))
|
||||
return await get_response(
|
||||
api_.get_collection_item, request, collection_id, item_id)
|
||||
|
||||
|
||||
async def collection_coverage(request: Request, collection_id=None):
|
||||
@@ -298,7 +337,8 @@ async def collection_coverage(request: Request, collection_id=None):
|
||||
if 'collection_id' in request.path_params:
|
||||
collection_id = request.path_params['collection_id']
|
||||
|
||||
return get_response(api_.get_collection_coverage(request, collection_id))
|
||||
return await get_response(
|
||||
api_.get_collection_coverage, request, collection_id)
|
||||
|
||||
|
||||
async def collection_coverage_domainset(request: Request, collection_id=None):
|
||||
@@ -313,8 +353,8 @@ async def collection_coverage_domainset(request: Request, collection_id=None):
|
||||
if 'collection_id' in request.path_params:
|
||||
collection_id = request.path_params['collection_id']
|
||||
|
||||
return get_response(api_.get_collection_coverage_domainset(
|
||||
request, collection_id))
|
||||
return await get_response(
|
||||
api_.get_collection_coverage_domainset, request, collection_id)
|
||||
|
||||
|
||||
async def collection_coverage_rangetype(request: Request, collection_id=None):
|
||||
@@ -330,8 +370,8 @@ async def collection_coverage_rangetype(request: Request, collection_id=None):
|
||||
if 'collection_id' in request.path_params:
|
||||
collection_id = request.path_params['collection_id']
|
||||
|
||||
return get_response(api_.get_collection_coverage_rangetype(
|
||||
request, collection_id))
|
||||
return await get_response(
|
||||
api_.get_collection_coverage_rangetype, request, collection_id)
|
||||
|
||||
|
||||
async def collection_map(request: Request, collection_id, style_id=None):
|
||||
@@ -349,8 +389,8 @@ async def collection_map(request: Request, collection_id, style_id=None):
|
||||
if 'style_id' in request.path_params:
|
||||
style_id = request.path_params['style_id']
|
||||
|
||||
return get_response(api_.get_collection_map(
|
||||
request, collection_id, style_id))
|
||||
return await get_response(
|
||||
api_.get_collection_map, request, collection_id, style_id)
|
||||
|
||||
|
||||
async def get_processes(request: Request, process_id=None):
|
||||
@@ -365,7 +405,7 @@ async def get_processes(request: Request, process_id=None):
|
||||
if 'process_id' in request.path_params:
|
||||
process_id = request.path_params['process_id']
|
||||
|
||||
return get_response(api_.describe_processes(request, process_id))
|
||||
return await get_response(api_.describe_processes, request, process_id)
|
||||
|
||||
|
||||
async def get_jobs(request: Request, job_id=None):
|
||||
@@ -382,12 +422,12 @@ async def get_jobs(request: Request, job_id=None):
|
||||
job_id = request.path_params['job_id']
|
||||
|
||||
if job_id is None: # list of submit job
|
||||
return get_response(api_.get_jobs(request))
|
||||
return await get_response(api_.get_jobs, request)
|
||||
else: # get or delete job
|
||||
if request.method == 'DELETE':
|
||||
return get_response(api_.delete_job(job_id))
|
||||
return await get_response(api_.delete_job, job_id)
|
||||
else: # Return status of a specific job
|
||||
return get_response(api_.get_jobs(request, job_id))
|
||||
return await get_response(api_.get_jobs, request, job_id)
|
||||
|
||||
|
||||
async def execute_process_jobs(request: Request, process_id=None):
|
||||
@@ -403,7 +443,7 @@ async def execute_process_jobs(request: Request, process_id=None):
|
||||
if 'process_id' in request.path_params:
|
||||
process_id = request.path_params['process_id']
|
||||
|
||||
return get_response(api_.execute_process(request, process_id))
|
||||
return await get_response(api_.execute_process, request, process_id)
|
||||
|
||||
|
||||
async def get_job_result(request: Request, job_id=None):
|
||||
@@ -419,7 +459,7 @@ async def get_job_result(request: Request, job_id=None):
|
||||
if 'job_id' in request.path_params:
|
||||
job_id = request.path_params['job_id']
|
||||
|
||||
return get_response(api_.get_job_result(request, job_id))
|
||||
return await get_response(api_.get_job_result, request, job_id)
|
||||
|
||||
|
||||
async def get_job_result_resource(request: Request,
|
||||
@@ -439,8 +479,8 @@ async def get_job_result_resource(request: Request,
|
||||
if 'resource' in request.path_params:
|
||||
resource = request.path_params['resource']
|
||||
|
||||
return get_response(api_.get_job_result_resource(
|
||||
request, job_id, resource))
|
||||
return await get_response(
|
||||
api_.get_job_result_resource, request, job_id, resource)
|
||||
|
||||
|
||||
async def get_collection_edr_query(request: Request, collection_id=None, instance_id=None): # noqa
|
||||
@@ -460,8 +500,10 @@ async def get_collection_edr_query(request: Request, collection_id=None, instanc
|
||||
instance_id = request.path_params['instance_id']
|
||||
|
||||
query_type = request["path"].split('/')[-1] # noqa
|
||||
return get_response(api_.get_collection_edr_query(request, collection_id,
|
||||
instance_id, query_type))
|
||||
return await get_response(
|
||||
api_.get_collection_edr_query, request, collection_id,
|
||||
instance_id, query_type
|
||||
)
|
||||
|
||||
|
||||
async def collections(request: Request, collection_id=None):
|
||||
@@ -475,7 +517,8 @@ async def collections(request: Request, collection_id=None):
|
||||
"""
|
||||
if 'collection_id' in request.path_params:
|
||||
collection_id = request.path_params['collection_id']
|
||||
return get_response(api_.describe_collections(request, collection_id))
|
||||
return await get_response(
|
||||
api_.describe_collections, request, collection_id)
|
||||
|
||||
|
||||
async def stac_catalog_root(request: Request):
|
||||
@@ -486,7 +529,7 @@ async def stac_catalog_root(request: Request):
|
||||
|
||||
:returns: Starlette HTTP response
|
||||
"""
|
||||
return get_response(api_.get_stac_root(request))
|
||||
return await get_response(api_.get_stac_root, request)
|
||||
|
||||
|
||||
async def stac_catalog_path(request: Request):
|
||||
@@ -498,7 +541,7 @@ async def stac_catalog_path(request: Request):
|
||||
:returns: Starlette HTTP response
|
||||
"""
|
||||
path = request.path_params["path"]
|
||||
return get_response(api_.get_stac_path(request, path))
|
||||
return await get_response(api_.get_stac_path, request, path)
|
||||
|
||||
|
||||
async def admin_config(request: Request):
|
||||
@@ -509,11 +552,11 @@ async def admin_config(request: Request):
|
||||
"""
|
||||
|
||||
if request.method == 'GET':
|
||||
return get_response(ADMIN.get_config(request))
|
||||
return await get_response(ADMIN.get_config, request)
|
||||
elif request.method == 'PUT':
|
||||
return get_response(ADMIN.put_config(request))
|
||||
return await get_response(ADMIN.put_config, request)
|
||||
elif request.method == 'PATCH':
|
||||
return get_response(ADMIN.patch_config(request))
|
||||
return await get_response(ADMIN.patch_config, request)
|
||||
|
||||
|
||||
async def admin_config_resources(request: Request):
|
||||
@@ -524,9 +567,9 @@ async def admin_config_resources(request: Request):
|
||||
"""
|
||||
|
||||
if request.method == 'GET':
|
||||
return get_response(ADMIN.get_resources(request))
|
||||
return await get_response(ADMIN.get_resources, request)
|
||||
elif request.method == 'POST':
|
||||
return get_response(ADMIN.put_resource(request))
|
||||
return await get_response(ADMIN.put_resource, request)
|
||||
|
||||
|
||||
async def admin_config_resource(request: Request, resource_id: str):
|
||||
@@ -542,13 +585,17 @@ async def admin_config_resource(request: Request, resource_id: str):
|
||||
resource_id = request.path_params['resource_id']
|
||||
|
||||
if request.method == 'GET':
|
||||
return get_response(ADMIN.get_resource(request, resource_id))
|
||||
return await get_response(
|
||||
ADMIN.get_resource, request, resource_id)
|
||||
elif request.method == 'PUT':
|
||||
return get_response(ADMIN.put_resource(request, resource_id))
|
||||
return await get_response(
|
||||
ADMIN.put_resource, request, resource_id)
|
||||
elif request.method == 'PATCH':
|
||||
return get_response(ADMIN.patch_resource(request, resource_id))
|
||||
return await get_response(
|
||||
ADMIN.patch_resource, request, resource_id)
|
||||
elif request.method == 'DELETE':
|
||||
return get_response(ADMIN.delete_resource(request, resource_id))
|
||||
return await get_response(
|
||||
ADMIN.delete_resource, request, resource_id)
|
||||
|
||||
|
||||
class ApiRulesMiddleware:
|
||||
|
||||
Reference in New Issue
Block a user