diff --git a/pygeoapi/provider/postgresql.py b/pygeoapi/provider/postgresql.py index bc3a53d..564255d 100644 --- a/pygeoapi/provider/postgresql.py +++ b/pygeoapi/provider/postgresql.py @@ -174,6 +174,35 @@ class PostgreSQLProvider(BaseProvider): self.fields = db.fields return self.fields + def __get_where_clauses(self, properties=[], bbox=[]): + """ + Generarates WHERE conditions to be implemented in query. + Private method mainly associated with query method + :param properties: list of tuples (name, value) + :param bbox: bounding box [minx,miny,maxx,maxy] + + :returns: psycopg2.sql.Composed or psycopg2.sql.SQL + """ + + where_conditions = [] + if properties: + property_clauses = [SQL('{} = {}').format( + Identifier(k), Literal(v)) for k, v in properties] + where_conditions += property_clauses + if bbox: + bbox_clause = SQL('{} && ST_MakeEnvelope({})').format( + Identifier(self.geom), SQL(', ').join( + [Literal(bbox_coord) for bbox_coord in bbox])) + where_conditions.append(bbox_clause) + + if where_conditions: + where_clause = SQL(' WHERE {}').format( + SQL(' AND ').join(where_conditions)) + else: + where_clause = SQL('') + + return where_clause + def query(self, startindex=0, limit=10, resulttype='results', bbox=[], datetime=None, properties=[], sortby=[]): """ @@ -198,8 +227,11 @@ class PostgreSQLProvider(BaseProvider): with DatabaseConnection(self.conn_dic, self.table, context="hits") as db: cursor = db.conn.cursor(cursor_factory=RealDictCursor) - sql_query = SQL("select count(*) as hits from {}").\ - format(Identifier(self.table)) + + where_clause = self.__get_where_clauses( + properties=properties, bbox=bbox) + sql_query = SQL("select count(*) as hits from {} {}").\ + format(Identifier(self.table), where_clause) try: cursor.execute(sql_query) except Exception as err: @@ -215,27 +247,10 @@ class PostgreSQLProvider(BaseProvider): with DatabaseConnection(self.conn_dic, self.table) as db: cursor = db.conn.cursor(cursor_factory=RealDictCursor) - where_conditions = [] - if properties: - property_clauses = \ - [SQL('{} = {}').format( - Identifier(k), Literal(v)) for k, v in properties] - where_conditions += property_clauses - if bbox: - bbox_clause = SQL('{} && ST_MakeEnvelope({})').format( - Identifier(self.geom), - SQL(', ').join( - [Literal(bbox_coord) for bbox_coord in bbox] - ) - ) - where_conditions.append(bbox_clause) - if where_conditions: - where_clause = SQL(' WHERE {}').format( - SQL(' AND ').join(where_conditions) - ) - else: - where_clause = SQL('') + where_clause = self.__get_where_clauses( + properties=properties, bbox=bbox) + sql_query = SQL("DECLARE \"geo_cursor\" CURSOR FOR \ SELECT {},ST_AsGeoJSON({}) FROM {}{}").\ format(db.columns, diff --git a/tests/test_postgresql_provider.py b/tests/test_postgresql_provider.py index f5097ca..94a9e05 100644 --- a/tests/test_postgresql_provider.py +++ b/tests/test_postgresql_provider.py @@ -88,6 +88,20 @@ def test_query_with_property_filter(config): assert (len(other_features) != 0) +def test_query_hits(config): + """Test query resulttype=hits with properties""" + psp = PostgreSQLProvider(config) + results = psp.query(resulttype="hits") + assert results["numberMatched"] == 14776 + + results = psp.query( + bbox=[29.3373, -3.4099, 29.3761, -3.3924], resulttype="hits") + assert results["numberMatched"] == 5 + + results = psp.query(properties=[("waterway", "stream")], resulttype="hits") + assert results["numberMatched"] == 13930 + + def test_query_bbox(config): """Test query with a specified bounding box""" psp = PostgreSQLProvider(config)