diff --git a/pygeoapi/process/manager/postgresql.py b/pygeoapi/process/manager/postgresql.py index 017de7f..8215b12 100644 --- a/pygeoapi/process/manager/postgresql.py +++ b/pygeoapi/process/manager/postgresql.py @@ -45,15 +45,16 @@ import logging from pathlib import Path from typing import Any, Tuple -import psycopg2 -import psycopg2.extras +from sqlalchemy import insert, update, delete +from sqlalchemy.orm import Session -from pygeoapi.process.manager.base import BaseManager from pygeoapi.process.base import ( JobNotFoundError, JobResultNotFoundError, - ProcessorGenericError, + ProcessorGenericError ) +from pygeoapi.process.manager.base import BaseManager +from pygeoapi.provider.postgresql import get_engine, get_table_model from pygeoapi.util import JobStatus @@ -61,7 +62,7 @@ LOGGER = logging.getLogger(__name__) class PostgreSQLManager(BaseManager): - """PostgreSql Manager""" + """PostgreSQL Manager""" def __init__(self, manager_def: dict): """ @@ -74,30 +75,39 @@ class PostgreSQLManager(BaseManager): super().__init__(manager_def) self.is_async = True + self.id_field = 'identifier' self.supports_subscribing = True + self.connection = manager_def['connection'] - self.__database_connection_parameters = manager_def['connection'] try: - # Test connection parameters: - test_query = """SELECT version()""" - with self.get_db_connection() as conn: - with conn.cursor() as cur: - cur.execute(test_query) - cur.fetchone() + self.db_search_path = tuple(self.connection.get('search_path', + ['public'])) + except Exception: + self.db_search_path = 'public' + + try: + LOGGER.debug('Connecting to database') + if isinstance(self.connection, str): + self._engine = get_engine(self.connection) + else: + self._engine = get_engine(**self.connection) except Exception as err: - LOGGER.error(f'Test connecting to DB failed: {err}') - raise ProcessorGenericError('Test connecting to DB failed.') + msg = 'Test connecting to DB failed' + LOGGER.error(f'{msg}: {err}') + raise ProcessorGenericError(msg) - def get_db_connection(self): - """ - Get and return a new connection to the DB. - """ - if isinstance(self.__database_connection_parameters, str): - conn = psycopg2.connect(self.__database_connection_parameters) - else: - conn = psycopg2.connect(**self.__database_connection_parameters) - - return conn + try: + LOGGER.debug('Getting table model') + self.table_model = get_table_model( + 'jobs', + self.id_field, + self.db_search_path, + self._engine + ) + except Exception as err: + msg = 'Table model fetch failed' + LOGGER.error(f'{msg}: {err}') + raise ProcessorGenericError(msg) def get_jobs(self, status: JobStatus = None) -> list: """ @@ -111,16 +121,14 @@ class PostgreSQLManager(BaseManager): mimetype, message, progress) """ - with self.get_db_connection() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - query_select = """SELECT * FROM jobs """ - if status is not None: - query_select = query_select + "WHERE status = %s" - query_params = [status.value] - else: - query_params = [] - cur.execute(query_select, query_params) - return cur.fetchall() + LOGGER.debug('Querying for jobs') + with Session(self._engine) as session: + results = session.query(self.table_model) + if status is not None: + column = getattr(self.table_model, 'status') + results = results.filter(column == status.value) + + return [r.__dict__ for r in results.all()] def add_job(self, job_metadata: dict) -> str: """ @@ -131,16 +139,18 @@ class PostgreSQLManager(BaseManager): :returns: identifier of added job """ - query_insert = """INSERT INTO jobs( - type, process_id, identifier, status, message, - progress, job_start_datetime, job_end_datetime - ) VALUES(%(type)s, %(process_id)s, %(identifier)s, %(status)s, - %(message)s, %(progress)s, %(job_start_datetime)s, - %(job_end_datetime)s);""" - with self.get_db_connection() as conn: - with conn.cursor() as cur: - cur.execute(query_insert, job_metadata) - conn.commit() + LOGGER.debug('Adding job') + with Session(self._engine) as session: + try: + session.execute(insert(self.table_model) + .values(**job_metadata)) + session.commit() + except Exception as err: + session.rollback() + msg = 'Insert failed' + LOGGER.error(f'{msg}: {err}') + raise ProcessorGenericError(msg) + return job_metadata['identifier'] def update_job(self, job_id: str, update_dict: dict) -> bool: @@ -153,30 +163,25 @@ class PostgreSQLManager(BaseManager): :returns: `bool` of status result """ - query_update = "UPDATE jobs SET (" - keys_to_update = 0 - for key in update_dict.keys(): - if keys_to_update: - query_update = query_update + (", ") - query_update = query_update + key - keys_to_update = keys_to_update + 1 + rowcount = 0 - query_update = query_update + ") = (" - keys_to_update = 0 - for key in update_dict.keys(): - if keys_to_update: - query_update = query_update + (", ") - query_update = query_update + "%(" + key + ")s" - keys_to_update = keys_to_update + 1 - query_update = query_update + (") WHERE identifier = %(identifier)s") - - update_dict['identifier'] = job_id - - with self.get_db_connection() as conn: - with conn.cursor() as cur: - cur.execute(query_update, update_dict) - rowcount = cur.rowcount - conn.commit() + LOGGER.debug('Updating job') + with Session(self._engine) as session: + try: + column = getattr(self.table_model, self.id_field) + stmt = ( + update(self.table_model) + .where(column == job_id) + .values(**update_dict) + ) + result = session.execute(stmt) + session.commit() + rowcount = result.rowcount + except Exception as err: + session.rollback() + msg = 'Update failed' + LOGGER.error(f'{msg}: {err}') + raise ProcessorGenericError(msg) return rowcount == 1 @@ -191,18 +196,18 @@ class PostgreSQLManager(BaseManager): :returns: `dict` # `pygeoapi.process.manager.Job` """ - with self.get_db_connection() as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - query_select = \ - """SELECT * FROM jobs WHERE identifier = %s""" - query_params = [job_id] - cur.execute(query_select, query_params) - found = cur.fetchone() + LOGGER.debug('Querying for job') + with Session(self._engine) as session: + results = session.query(self.table_model) + column = getattr(self.table_model, self.id_field) + results = session.query(self.table_model).filter(column == job_id) - if found is not None: - return found - else: - raise JobNotFoundError() + first = results.first() + + if first is not None: + return first.__dict__ + else: + raise JobNotFoundError() def delete_job(self, job_id: str) -> bool: """ @@ -214,22 +219,37 @@ class PostgreSQLManager(BaseManager): known job :return `bool` of status result """ - # delete result file if present + + rowcount = 0 + + # get result file if present for deletion job_result = self.get_job(job_id) location = job_result.get('location') - if location and self.output_dir is not None: + + LOGGER.debug('Deleting job') + with Session(self._engine) as session: + try: + column = getattr(self.table_model, self.id_field) + stmt = ( + delete(self.table_model) + .where(column == job_id) + ) + result = session.execute(stmt) + session.commit() + rowcount = result.rowcount + except Exception as err: + session.rollback() + msg = 'Delete failed' + LOGGER.error(f'{msg}: {err}') + raise ProcessorGenericError(msg) + + # delete result file if present + if None not in [location, self.output_dir]: try: Path(location).unlink() except FileNotFoundError: pass - query_delete = "DELETE FROM jobs WHERE identifier = %s" - with self.get_db_connection() as conn: - with conn.cursor() as cur: - cur.execute(query_delete, [job_id]) - rowcount = cur.rowcount - conn.commit() - return rowcount == 1 def get_job_result(self, job_id: str) -> Tuple[str, Any]: diff --git a/tests/pygeoapi-test-config-postgres-manager.yml b/tests/pygeoapi-test-config-postgresql-manager.yml similarity index 97% rename from tests/pygeoapi-test-config-postgres-manager.yml rename to tests/pygeoapi-test-config-postgresql-manager.yml index 496f558..fb3d635 100644 --- a/tests/pygeoapi-test-config-postgres-manager.yml +++ b/tests/pygeoapi-test-config-postgresql-manager.yml @@ -55,7 +55,7 @@ server: user: postgres password: ${POSTGRESQL_PASSWORD:-postgres} # Alternative accepted connection definition: - # connection: postgresql://postgres:postgres@localhost:5432/test + # connection: postgresql://postgres:${POSTGRESQL_PASSWORD:-postgres}@localhost:5432/test output_dir: /tmp logging: diff --git a/tests/test_postgresql_manager.py b/tests/test_postgresql_manager.py index cd3c86f..43464f6 100644 --- a/tests/test_postgresql_manager.py +++ b/tests/test_postgresql_manager.py @@ -45,7 +45,7 @@ from pygeoapi.util import yaml_load @pytest.fixture() def config(): with open(get_test_file_path( - 'pygeoapi-test-config-postgres-manager.yml') + 'pygeoapi-test-config-postgresql-manager.yml') ) as fh: return yaml_load(fh)