diff --git a/pygeoapi/provider/postgresql.py b/pygeoapi/provider/postgresql.py index 5b3ef28..6cdf631 100644 --- a/pygeoapi/provider/postgresql.py +++ b/pygeoapi/provider/postgresql.py @@ -55,6 +55,7 @@ from geoalchemy2 import Geometry # noqa - this isn't used explicitly but is nee from geoalchemy2.functions import ST_MakeEnvelope from geoalchemy2.shape import to_shape from pygeofilter.backends.sqlalchemy.evaluate import to_filter +import pygeofilter.ast import pyproj import shapely from sqlalchemy import create_engine, MetaData, PrimaryKeyConstraint, asc, desc @@ -138,7 +139,8 @@ class PostgreSQLProvider(BaseProvider): LOGGER.debug('Preparing filters') property_filters = self._get_property_filters(properties) - cql_filters = self._get_cql_filters(filterq) + modified_filterq = self._modify_pygeofilter(filterq) + cql_filters = self._get_cql_filters(modified_filterq) bbox_filter = self._get_bbox_filter(bbox) order_by_clauses = self._get_order_by_clauses(sortby, self.table_model) selected_properties = self._select_properties_clause(select_properties, @@ -495,3 +497,40 @@ class PostgreSQLProvider(BaseProvider): else: crs_transform = None return crs_transform + + def _modify_pygeofilter( + self, + ast_tree: pygeofilter.ast.Node, + ) -> pygeofilter.ast.Node: + """ + Prepare the input pygeofilter for querying the database. + + Returns a new ``pygeofilter.ast.Node`` object that can be used for + querying the database. + """ + new_tree = deepcopy(ast_tree) + _inplace_replace_geometry_filter_name(new_tree, self.geom) + return new_tree + + +def _inplace_replace_geometry_filter_name( + node: pygeofilter.ast.Node, + geometry_column_name: str +): + """Recursively traverse node tree and rename nodes of type ``Attribute``. + + Nodes of type ``Attribute`` named ``geometry`` are renamed to the value of + the ``geometry_column_name`` parameter. + """ + try: + sub_nodes = node.get_sub_nodes() + except AttributeError: + pass + else: + for sub_node in sub_nodes: + is_attribute_node = isinstance(sub_node, pygeofilter.ast.Attribute) + if is_attribute_node and sub_node.name == "geometry": + sub_node.name = geometry_column_name + else: + _inplace_replace_geometry_filter_name( + sub_node, geometry_column_name) diff --git a/tests/test_postgresql_provider.py b/tests/test_postgresql_provider.py index 650a609..1d8512c 100644 --- a/tests/test_postgresql_provider.py +++ b/tests/test_postgresql_provider.py @@ -44,7 +44,9 @@ import pytest import pyproj from http import HTTPStatus +import pygeofilter.ast from pygeofilter.parsers.ecql import parse +from pygeofilter.values import Geometry from pygeoapi.api import API @@ -748,3 +750,69 @@ def test_get_collection_items_postgresql_automap_naming_conflicts(pg_api_): assert code == HTTPStatus.OK features = json.loads(response).get('features') assert len(features) == 0 + + +@pytest.mark.parametrize('original_filter, expected', [ + pytest.param( + "INTERSECTS(geometry, POINT(1 1))", + pygeofilter.ast.GeometryIntersects( + pygeofilter.ast.Attribute(name='custom_geom_name'), + Geometry({'type': 'Point', 'coordinates': (1, 1)}) + ), + id='unnested-geometry' + ), + pytest.param( + "some_attribute = 10 AND INTERSECTS(geometry, POINT(1 1))", + pygeofilter.ast.And( + pygeofilter.ast.Equal( + pygeofilter.ast.Attribute(name='some_attribute'), 10), + pygeofilter.ast.GeometryIntersects( + pygeofilter.ast.Attribute(name='custom_geom_name'), + Geometry({'type': 'Point', 'coordinates': (1, 1)}) + ), + ), + id='nested-geometry' + ), + pytest.param( + "(some_attribute = 10 AND INTERSECTS(geometry, POINT(1 1))) OR " + "DWITHIN(geometry, POINT(2 2), 10, meters)", + pygeofilter.ast.Or( + pygeofilter.ast.And( + pygeofilter.ast.Equal( + pygeofilter.ast.Attribute(name='some_attribute'), 10), + pygeofilter.ast.GeometryIntersects( + pygeofilter.ast.Attribute(name='custom_geom_name'), + Geometry({'type': 'Point', 'coordinates': (1, 1)}) + ), + ), + pygeofilter.ast.DistanceWithin( + pygeofilter.ast.Attribute(name='custom_geom_name'), + Geometry({'type': 'Point', 'coordinates': (2, 2)}), + distance=10, + units='meters', + ) + ), + id='complex-filter' + ), +]) +def test_modify_pygeofilter(original_filter, expected): + + class _CustomPostgreSqlProvider(PostgreSQLProvider): + """This is a subclass of the original PostgreSQLProvider. + + The current test is only interested in verifying the correctness of + the logic that modifies the parsed filter. As such, in order + to simplify instantiating the postgresql pygeoapi provider, and + in order to avoid dealing with mocking out the sqlalchemy table + reflection mechanism, this class overrides the __init__() method + and can be used to test the implementation of the base class' + `self._modify_pygeofilter()` method, which is really all we want + to test here. + """ + def __init__(self): + self.geom = 'custom_geom_name' + + provider = _CustomPostgreSqlProvider() + parsed_filter = parse(original_filter) + result = provider._modify_pygeofilter(parsed_filter) + assert result == expected