port PostgreSQL process manager to SQLAlchemy (#1745)
This commit is contained in:
@@ -45,15 +45,16 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from sqlalchemy import insert, update, delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from pygeoapi.process.manager.base import BaseManager
|
||||
from pygeoapi.process.base import (
|
||||
JobNotFoundError,
|
||||
JobResultNotFoundError,
|
||||
ProcessorGenericError,
|
||||
ProcessorGenericError
|
||||
)
|
||||
from pygeoapi.process.manager.base import BaseManager
|
||||
from pygeoapi.provider.postgresql import get_engine, get_table_model
|
||||
from pygeoapi.util import JobStatus
|
||||
|
||||
|
||||
@@ -61,7 +62,7 @@ LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostgreSQLManager(BaseManager):
|
||||
"""PostgreSql Manager"""
|
||||
"""PostgreSQL Manager"""
|
||||
|
||||
def __init__(self, manager_def: dict):
|
||||
"""
|
||||
@@ -74,30 +75,39 @@ class PostgreSQLManager(BaseManager):
|
||||
|
||||
super().__init__(manager_def)
|
||||
self.is_async = True
|
||||
self.id_field = 'identifier'
|
||||
self.supports_subscribing = True
|
||||
self.connection = manager_def['connection']
|
||||
|
||||
self.__database_connection_parameters = manager_def['connection']
|
||||
try:
|
||||
# Test connection parameters:
|
||||
test_query = """SELECT version()"""
|
||||
with self.get_db_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(test_query)
|
||||
cur.fetchone()
|
||||
self.db_search_path = tuple(self.connection.get('search_path',
|
||||
['public']))
|
||||
except Exception:
|
||||
self.db_search_path = 'public'
|
||||
|
||||
try:
|
||||
LOGGER.debug('Connecting to database')
|
||||
if isinstance(self.connection, str):
|
||||
self._engine = get_engine(self.connection)
|
||||
else:
|
||||
self._engine = get_engine(**self.connection)
|
||||
except Exception as err:
|
||||
LOGGER.error(f'Test connecting to DB failed: {err}')
|
||||
raise ProcessorGenericError('Test connecting to DB failed.')
|
||||
msg = 'Test connecting to DB failed'
|
||||
LOGGER.error(f'{msg}: {err}')
|
||||
raise ProcessorGenericError(msg)
|
||||
|
||||
def get_db_connection(self):
|
||||
"""
|
||||
Get and return a new connection to the DB.
|
||||
"""
|
||||
if isinstance(self.__database_connection_parameters, str):
|
||||
conn = psycopg2.connect(self.__database_connection_parameters)
|
||||
else:
|
||||
conn = psycopg2.connect(**self.__database_connection_parameters)
|
||||
|
||||
return conn
|
||||
try:
|
||||
LOGGER.debug('Getting table model')
|
||||
self.table_model = get_table_model(
|
||||
'jobs',
|
||||
self.id_field,
|
||||
self.db_search_path,
|
||||
self._engine
|
||||
)
|
||||
except Exception as err:
|
||||
msg = 'Table model fetch failed'
|
||||
LOGGER.error(f'{msg}: {err}')
|
||||
raise ProcessorGenericError(msg)
|
||||
|
||||
def get_jobs(self, status: JobStatus = None) -> list:
|
||||
"""
|
||||
@@ -111,16 +121,14 @@ class PostgreSQLManager(BaseManager):
|
||||
mimetype, message, progress)
|
||||
"""
|
||||
|
||||
with self.get_db_connection() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
query_select = """SELECT * FROM jobs """
|
||||
if status is not None:
|
||||
query_select = query_select + "WHERE status = %s"
|
||||
query_params = [status.value]
|
||||
else:
|
||||
query_params = []
|
||||
cur.execute(query_select, query_params)
|
||||
return cur.fetchall()
|
||||
LOGGER.debug('Querying for jobs')
|
||||
with Session(self._engine) as session:
|
||||
results = session.query(self.table_model)
|
||||
if status is not None:
|
||||
column = getattr(self.table_model, 'status')
|
||||
results = results.filter(column == status.value)
|
||||
|
||||
return [r.__dict__ for r in results.all()]
|
||||
|
||||
def add_job(self, job_metadata: dict) -> str:
|
||||
"""
|
||||
@@ -131,16 +139,18 @@ class PostgreSQLManager(BaseManager):
|
||||
:returns: identifier of added job
|
||||
"""
|
||||
|
||||
query_insert = """INSERT INTO jobs(
|
||||
type, process_id, identifier, status, message,
|
||||
progress, job_start_datetime, job_end_datetime
|
||||
) VALUES(%(type)s, %(process_id)s, %(identifier)s, %(status)s,
|
||||
%(message)s, %(progress)s, %(job_start_datetime)s,
|
||||
%(job_end_datetime)s);"""
|
||||
with self.get_db_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(query_insert, job_metadata)
|
||||
conn.commit()
|
||||
LOGGER.debug('Adding job')
|
||||
with Session(self._engine) as session:
|
||||
try:
|
||||
session.execute(insert(self.table_model)
|
||||
.values(**job_metadata))
|
||||
session.commit()
|
||||
except Exception as err:
|
||||
session.rollback()
|
||||
msg = 'Insert failed'
|
||||
LOGGER.error(f'{msg}: {err}')
|
||||
raise ProcessorGenericError(msg)
|
||||
|
||||
return job_metadata['identifier']
|
||||
|
||||
def update_job(self, job_id: str, update_dict: dict) -> bool:
|
||||
@@ -153,30 +163,25 @@ class PostgreSQLManager(BaseManager):
|
||||
:returns: `bool` of status result
|
||||
"""
|
||||
|
||||
query_update = "UPDATE jobs SET ("
|
||||
keys_to_update = 0
|
||||
for key in update_dict.keys():
|
||||
if keys_to_update:
|
||||
query_update = query_update + (", ")
|
||||
query_update = query_update + key
|
||||
keys_to_update = keys_to_update + 1
|
||||
rowcount = 0
|
||||
|
||||
query_update = query_update + ") = ("
|
||||
keys_to_update = 0
|
||||
for key in update_dict.keys():
|
||||
if keys_to_update:
|
||||
query_update = query_update + (", ")
|
||||
query_update = query_update + "%(" + key + ")s"
|
||||
keys_to_update = keys_to_update + 1
|
||||
query_update = query_update + (") WHERE identifier = %(identifier)s")
|
||||
|
||||
update_dict['identifier'] = job_id
|
||||
|
||||
with self.get_db_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(query_update, update_dict)
|
||||
rowcount = cur.rowcount
|
||||
conn.commit()
|
||||
LOGGER.debug('Updating job')
|
||||
with Session(self._engine) as session:
|
||||
try:
|
||||
column = getattr(self.table_model, self.id_field)
|
||||
stmt = (
|
||||
update(self.table_model)
|
||||
.where(column == job_id)
|
||||
.values(**update_dict)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
rowcount = result.rowcount
|
||||
except Exception as err:
|
||||
session.rollback()
|
||||
msg = 'Update failed'
|
||||
LOGGER.error(f'{msg}: {err}')
|
||||
raise ProcessorGenericError(msg)
|
||||
|
||||
return rowcount == 1
|
||||
|
||||
@@ -191,18 +196,18 @@ class PostgreSQLManager(BaseManager):
|
||||
:returns: `dict` # `pygeoapi.process.manager.Job`
|
||||
"""
|
||||
|
||||
with self.get_db_connection() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
query_select = \
|
||||
"""SELECT * FROM jobs WHERE identifier = %s"""
|
||||
query_params = [job_id]
|
||||
cur.execute(query_select, query_params)
|
||||
found = cur.fetchone()
|
||||
LOGGER.debug('Querying for job')
|
||||
with Session(self._engine) as session:
|
||||
results = session.query(self.table_model)
|
||||
column = getattr(self.table_model, self.id_field)
|
||||
results = session.query(self.table_model).filter(column == job_id)
|
||||
|
||||
if found is not None:
|
||||
return found
|
||||
else:
|
||||
raise JobNotFoundError()
|
||||
first = results.first()
|
||||
|
||||
if first is not None:
|
||||
return first.__dict__
|
||||
else:
|
||||
raise JobNotFoundError()
|
||||
|
||||
def delete_job(self, job_id: str) -> bool:
|
||||
"""
|
||||
@@ -214,22 +219,37 @@ class PostgreSQLManager(BaseManager):
|
||||
known job
|
||||
:return `bool` of status result
|
||||
"""
|
||||
# delete result file if present
|
||||
|
||||
rowcount = 0
|
||||
|
||||
# get result file if present for deletion
|
||||
job_result = self.get_job(job_id)
|
||||
location = job_result.get('location')
|
||||
if location and self.output_dir is not None:
|
||||
|
||||
LOGGER.debug('Deleting job')
|
||||
with Session(self._engine) as session:
|
||||
try:
|
||||
column = getattr(self.table_model, self.id_field)
|
||||
stmt = (
|
||||
delete(self.table_model)
|
||||
.where(column == job_id)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
rowcount = result.rowcount
|
||||
except Exception as err:
|
||||
session.rollback()
|
||||
msg = 'Delete failed'
|
||||
LOGGER.error(f'{msg}: {err}')
|
||||
raise ProcessorGenericError(msg)
|
||||
|
||||
# delete result file if present
|
||||
if None not in [location, self.output_dir]:
|
||||
try:
|
||||
Path(location).unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
query_delete = "DELETE FROM jobs WHERE identifier = %s"
|
||||
with self.get_db_connection() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(query_delete, [job_id])
|
||||
rowcount = cur.rowcount
|
||||
conn.commit()
|
||||
|
||||
return rowcount == 1
|
||||
|
||||
def get_job_result(self, job_id: str) -> Tuple[str, Any]:
|
||||
|
||||
+1
-1
@@ -55,7 +55,7 @@ server:
|
||||
user: postgres
|
||||
password: ${POSTGRESQL_PASSWORD:-postgres}
|
||||
# Alternative accepted connection definition:
|
||||
# connection: postgresql://postgres:postgres@localhost:5432/test
|
||||
# connection: postgresql://postgres:${POSTGRESQL_PASSWORD:-postgres}@localhost:5432/test
|
||||
output_dir: /tmp
|
||||
|
||||
logging:
|
||||
@@ -45,7 +45,7 @@ from pygeoapi.util import yaml_load
|
||||
@pytest.fixture()
|
||||
def config():
|
||||
with open(get_test_file_path(
|
||||
'pygeoapi-test-config-postgres-manager.yml')
|
||||
'pygeoapi-test-config-postgresql-manager.yml')
|
||||
) as fh:
|
||||
return yaml_load(fh)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user