Added method for translating between generic geometry name and the name of the actual geometry column in the postgres DB (#1453)
This commit is contained in:
committed by
GitHub
parent
cfa21f627c
commit
22ac69aa75
@@ -55,6 +55,7 @@ from geoalchemy2 import Geometry # noqa - this isn't used explicitly but is nee
|
||||
from geoalchemy2.functions import ST_MakeEnvelope
|
||||
from geoalchemy2.shape import to_shape
|
||||
from pygeofilter.backends.sqlalchemy.evaluate import to_filter
|
||||
import pygeofilter.ast
|
||||
import pyproj
|
||||
import shapely
|
||||
from sqlalchemy import create_engine, MetaData, PrimaryKeyConstraint, asc, desc
|
||||
@@ -138,7 +139,8 @@ class PostgreSQLProvider(BaseProvider):
|
||||
|
||||
LOGGER.debug('Preparing filters')
|
||||
property_filters = self._get_property_filters(properties)
|
||||
cql_filters = self._get_cql_filters(filterq)
|
||||
modified_filterq = self._modify_pygeofilter(filterq)
|
||||
cql_filters = self._get_cql_filters(modified_filterq)
|
||||
bbox_filter = self._get_bbox_filter(bbox)
|
||||
order_by_clauses = self._get_order_by_clauses(sortby, self.table_model)
|
||||
selected_properties = self._select_properties_clause(select_properties,
|
||||
@@ -495,3 +497,40 @@ class PostgreSQLProvider(BaseProvider):
|
||||
else:
|
||||
crs_transform = None
|
||||
return crs_transform
|
||||
|
||||
def _modify_pygeofilter(
|
||||
self,
|
||||
ast_tree: pygeofilter.ast.Node,
|
||||
) -> pygeofilter.ast.Node:
|
||||
"""
|
||||
Prepare the input pygeofilter for querying the database.
|
||||
|
||||
Returns a new ``pygeofilter.ast.Node`` object that can be used for
|
||||
querying the database.
|
||||
"""
|
||||
new_tree = deepcopy(ast_tree)
|
||||
_inplace_replace_geometry_filter_name(new_tree, self.geom)
|
||||
return new_tree
|
||||
|
||||
|
||||
def _inplace_replace_geometry_filter_name(
|
||||
node: pygeofilter.ast.Node,
|
||||
geometry_column_name: str
|
||||
):
|
||||
"""Recursively traverse node tree and rename nodes of type ``Attribute``.
|
||||
|
||||
Nodes of type ``Attribute`` named ``geometry`` are renamed to the value of
|
||||
the ``geometry_column_name`` parameter.
|
||||
"""
|
||||
try:
|
||||
sub_nodes = node.get_sub_nodes()
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
for sub_node in sub_nodes:
|
||||
is_attribute_node = isinstance(sub_node, pygeofilter.ast.Attribute)
|
||||
if is_attribute_node and sub_node.name == "geometry":
|
||||
sub_node.name = geometry_column_name
|
||||
else:
|
||||
_inplace_replace_geometry_filter_name(
|
||||
sub_node, geometry_column_name)
|
||||
|
||||
@@ -44,7 +44,9 @@ import pytest
|
||||
import pyproj
|
||||
from http import HTTPStatus
|
||||
|
||||
import pygeofilter.ast
|
||||
from pygeofilter.parsers.ecql import parse
|
||||
from pygeofilter.values import Geometry
|
||||
|
||||
from pygeoapi.api import API
|
||||
|
||||
@@ -748,3 +750,69 @@ def test_get_collection_items_postgresql_automap_naming_conflicts(pg_api_):
|
||||
assert code == HTTPStatus.OK
|
||||
features = json.loads(response).get('features')
|
||||
assert len(features) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize('original_filter, expected', [
|
||||
pytest.param(
|
||||
"INTERSECTS(geometry, POINT(1 1))",
|
||||
pygeofilter.ast.GeometryIntersects(
|
||||
pygeofilter.ast.Attribute(name='custom_geom_name'),
|
||||
Geometry({'type': 'Point', 'coordinates': (1, 1)})
|
||||
),
|
||||
id='unnested-geometry'
|
||||
),
|
||||
pytest.param(
|
||||
"some_attribute = 10 AND INTERSECTS(geometry, POINT(1 1))",
|
||||
pygeofilter.ast.And(
|
||||
pygeofilter.ast.Equal(
|
||||
pygeofilter.ast.Attribute(name='some_attribute'), 10),
|
||||
pygeofilter.ast.GeometryIntersects(
|
||||
pygeofilter.ast.Attribute(name='custom_geom_name'),
|
||||
Geometry({'type': 'Point', 'coordinates': (1, 1)})
|
||||
),
|
||||
),
|
||||
id='nested-geometry'
|
||||
),
|
||||
pytest.param(
|
||||
"(some_attribute = 10 AND INTERSECTS(geometry, POINT(1 1))) OR "
|
||||
"DWITHIN(geometry, POINT(2 2), 10, meters)",
|
||||
pygeofilter.ast.Or(
|
||||
pygeofilter.ast.And(
|
||||
pygeofilter.ast.Equal(
|
||||
pygeofilter.ast.Attribute(name='some_attribute'), 10),
|
||||
pygeofilter.ast.GeometryIntersects(
|
||||
pygeofilter.ast.Attribute(name='custom_geom_name'),
|
||||
Geometry({'type': 'Point', 'coordinates': (1, 1)})
|
||||
),
|
||||
),
|
||||
pygeofilter.ast.DistanceWithin(
|
||||
pygeofilter.ast.Attribute(name='custom_geom_name'),
|
||||
Geometry({'type': 'Point', 'coordinates': (2, 2)}),
|
||||
distance=10,
|
||||
units='meters',
|
||||
)
|
||||
),
|
||||
id='complex-filter'
|
||||
),
|
||||
])
|
||||
def test_modify_pygeofilter(original_filter, expected):
|
||||
|
||||
class _CustomPostgreSqlProvider(PostgreSQLProvider):
|
||||
"""This is a subclass of the original PostgreSQLProvider.
|
||||
|
||||
The current test is only interested in verifying the correctness of
|
||||
the logic that modifies the parsed filter. As such, in order
|
||||
to simplify instantiating the postgresql pygeoapi provider, and
|
||||
in order to avoid dealing with mocking out the sqlalchemy table
|
||||
reflection mechanism, this class overrides the __init__() method
|
||||
and can be used to test the implementation of the base class'
|
||||
`self._modify_pygeofilter()` method, which is really all we want
|
||||
to test here.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.geom = 'custom_geom_name'
|
||||
|
||||
provider = _CustomPostgreSqlProvider()
|
||||
parsed_filter = parse(original_filter)
|
||||
result = provider._modify_pygeofilter(parsed_filter)
|
||||
assert result == expected
|
||||
|
||||
Reference in New Issue
Block a user