Improvements for xarray provider (#1800)
* Manage non-cf-compliant time dimension
* Manage datasets without a time dimension
* Allow reversed slices also for axes
* Convert also metadata to float64 for json output
* Use named temporary file to enable netcdf4 engine
* Make float64 conversion faster
* Add netcdf output to xarray provider
* Flake8 fixes
* Fix bug when no time axis in data
* Use new xarray interface
* Add test for zarr dataset without time dimension
* Avoid errors if missing long_name
* Manage zarr and netcdf output in the same way
* Revert "Manage zarr and netcdf output in the same way"
  This reverts commit 0b09281b608da95221951d05004f213379da168d.
* Revert "Add netcdf output to xarray provider"
  This reverts commit 9f72bf7614775b418f53f4808fcaeab567c7024a.
This commit is contained in:
+103
-55
@@ -85,13 +85,19 @@ class XarrayProvider(BaseProvider):
|
||||
else:
|
||||
data_to_open = self.data
|
||||
|
||||
self._data = open_func(data_to_open)
|
||||
try:
|
||||
self._data = open_func(data_to_open)
|
||||
except ValueError as err:
|
||||
# Manage non-cf-compliant time dimensions
|
||||
if 'time' in str(err):
|
||||
self._data = open_func(self.data, decode_times=False)
|
||||
else:
|
||||
raise err
|
||||
|
||||
self.storage_crs = self._parse_storage_crs(provider_def)
|
||||
self._coverage_properties = self._get_coverage_properties()
|
||||
|
||||
self.axes = [self._coverage_properties['x_axis_label'],
|
||||
self._coverage_properties['y_axis_label'],
|
||||
self._coverage_properties['time_axis_label']]
|
||||
self.axes = self._coverage_properties['axes']
|
||||
|
||||
self.get_fields()
|
||||
except Exception as err:
|
||||
@@ -101,7 +107,7 @@ class XarrayProvider(BaseProvider):
|
||||
def get_fields(self):
|
||||
if not self._fields:
|
||||
for key, value in self._data.variables.items():
|
||||
if len(value.shape) >= 3:
|
||||
if key not in self._data.coords:
|
||||
LOGGER.debug('Adding variable')
|
||||
dtype = value.dtype
|
||||
if dtype.name.startswith('float'):
|
||||
@@ -109,7 +115,7 @@ class XarrayProvider(BaseProvider):
|
||||
|
||||
self._fields[key] = {
|
||||
'type': dtype,
|
||||
'title': value.attrs['long_name'],
|
||||
'title': value.attrs.get('long_name'),
|
||||
'x-ogc-unit': value.attrs.get('units')
|
||||
}
|
||||
|
||||
@@ -142,9 +148,9 @@ class XarrayProvider(BaseProvider):
|
||||
|
||||
data = self._data[[*properties]]
|
||||
|
||||
if any([self._coverage_properties['x_axis_label'] in subsets,
|
||||
self._coverage_properties['y_axis_label'] in subsets,
|
||||
self._coverage_properties['time_axis_label'] in subsets,
|
||||
if any([self._coverage_properties.get('x_axis_label') in subsets,
|
||||
self._coverage_properties.get('y_axis_label') in subsets,
|
||||
self._coverage_properties.get('time_axis_label') in subsets,
|
||||
datetime_ is not None]):
|
||||
|
||||
LOGGER.debug('Creating spatio-temporal subset')
|
||||
@@ -163,18 +169,36 @@ class XarrayProvider(BaseProvider):
|
||||
self._coverage_properties['y_axis_label'] in subsets,
|
||||
len(bbox) > 0]):
|
||||
msg = 'bbox and subsetting by coordinates are exclusive'
|
||||
LOGGER.warning(msg)
|
||||
LOGGER.error(msg)
|
||||
raise ProviderQueryError(msg)
|
||||
else:
|
||||
query_params[self._coverage_properties['x_axis_label']] = \
|
||||
slice(bbox[0], bbox[2])
|
||||
query_params[self._coverage_properties['y_axis_label']] = \
|
||||
slice(bbox[1], bbox[3])
|
||||
x_axis_label = self._coverage_properties['x_axis_label']
|
||||
x_coords = data.coords[x_axis_label]
|
||||
if x_coords.values[0] > x_coords.values[-1]:
|
||||
LOGGER.debug(
|
||||
'Reversing slicing of x axis from high to low'
|
||||
)
|
||||
query_params[x_axis_label] = slice(bbox[2], bbox[0])
|
||||
else:
|
||||
query_params[x_axis_label] = slice(bbox[0], bbox[2])
|
||||
y_axis_label = self._coverage_properties['y_axis_label']
|
||||
y_coords = data.coords[y_axis_label]
|
||||
if y_coords.values[0] > y_coords.values[-1]:
|
||||
LOGGER.debug(
|
||||
'Reversing slicing of y axis from high to low'
|
||||
)
|
||||
query_params[y_axis_label] = slice(bbox[3], bbox[1])
|
||||
else:
|
||||
query_params[y_axis_label] = slice(bbox[1], bbox[3])
|
||||
|
||||
LOGGER.debug('bbox_crs is not currently handled')
|
||||
|
||||
if datetime_ is not None:
|
||||
if self._coverage_properties['time_axis_label'] in subsets:
|
||||
if self._coverage_properties['time_axis_label'] is None:
|
||||
msg = 'Dataset does not contain a time axis'
|
||||
LOGGER.error(msg)
|
||||
raise ProviderQueryError(msg)
|
||||
elif self._coverage_properties['time_axis_label'] in subsets:
|
||||
msg = 'datetime and temporal subsetting are exclusive'
|
||||
LOGGER.error(msg)
|
||||
raise ProviderQueryError(msg)
|
||||
@@ -196,13 +220,15 @@ class XarrayProvider(BaseProvider):
|
||||
LOGGER.warning(err)
|
||||
raise ProviderQueryError(err)
|
||||
|
||||
if (any([data.coords[self.x_field].size == 0,
|
||||
data.coords[self.y_field].size == 0,
|
||||
data.coords[self.time_field].size == 0])):
|
||||
if any(size == 0 for size in data.sizes.values()):
|
||||
msg = 'No data found'
|
||||
LOGGER.warning(msg)
|
||||
raise ProviderNoDataError(msg)
|
||||
|
||||
if format_ == 'json':
|
||||
# json does not support float32
|
||||
data = _convert_float32_to_float64(data)
|
||||
|
||||
out_meta = {
|
||||
'bbox': [
|
||||
data.coords[self.x_field].values[0],
|
||||
@@ -210,18 +236,20 @@ class XarrayProvider(BaseProvider):
|
||||
data.coords[self.x_field].values[-1],
|
||||
data.coords[self.y_field].values[-1]
|
||||
],
|
||||
"time": [
|
||||
_to_datetime_string(data.coords[self.time_field].values[0]),
|
||||
_to_datetime_string(data.coords[self.time_field].values[-1])
|
||||
],
|
||||
"driver": "xarray",
|
||||
"height": data.sizes[self.y_field],
|
||||
"width": data.sizes[self.x_field],
|
||||
"time_steps": data.sizes[self.time_field],
|
||||
"variables": {var_name: var.attrs
|
||||
for var_name, var in data.variables.items()}
|
||||
}
|
||||
|
||||
if self.time_field is not None:
|
||||
out_meta['time'] = [
|
||||
_to_datetime_string(data.coords[self.time_field].values[0]),
|
||||
_to_datetime_string(data.coords[self.time_field].values[-1]),
|
||||
]
|
||||
out_meta["time_steps"] = data.sizes[self.time_field]
|
||||
|
||||
LOGGER.debug('Serializing data in memory')
|
||||
if format_ == 'json':
|
||||
LOGGER.debug('Creating output in CoverageJSON')
|
||||
@@ -230,9 +258,11 @@ class XarrayProvider(BaseProvider):
|
||||
LOGGER.debug('Returning data in native zarr format')
|
||||
return _get_zarr_data(data)
|
||||
else: # return data in native format
|
||||
with tempfile.TemporaryFile() as fp:
|
||||
with tempfile.NamedTemporaryFile() as fp:
|
||||
LOGGER.debug('Returning data in native NetCDF format')
|
||||
fp.write(data.to_netcdf())
|
||||
data.to_netcdf(
|
||||
fp.name
|
||||
) # we need to pass a string to be able to use the "netcdf4" engine # noqa
|
||||
fp.seek(0)
|
||||
return fp.read()
|
||||
|
||||
@@ -249,7 +279,6 @@ class XarrayProvider(BaseProvider):
|
||||
|
||||
LOGGER.debug('Creating CoverageJSON domain')
|
||||
minx, miny, maxx, maxy = metadata['bbox']
|
||||
mint, maxt = metadata['time']
|
||||
|
||||
selected_fields = {
|
||||
key: value for key, value in self.fields.items()
|
||||
@@ -285,11 +314,6 @@ class XarrayProvider(BaseProvider):
|
||||
'start': maxy,
|
||||
'stop': miny,
|
||||
'num': metadata['height']
|
||||
},
|
||||
self.time_field: {
|
||||
'start': mint,
|
||||
'stop': maxt,
|
||||
'num': metadata['time_steps']
|
||||
}
|
||||
},
|
||||
'referencing': [{
|
||||
@@ -304,6 +328,14 @@ class XarrayProvider(BaseProvider):
|
||||
'ranges': {}
|
||||
}
|
||||
|
||||
if self.time_field is not None:
|
||||
mint, maxt = metadata['time']
|
||||
cj['domain']['axes'][self.time_field] = {
|
||||
'start': mint,
|
||||
'stop': maxt,
|
||||
'num': metadata['time_steps'],
|
||||
}
|
||||
|
||||
for key, value in selected_fields.items():
|
||||
parameter = {
|
||||
'type': 'Parameter',
|
||||
@@ -322,7 +354,6 @@ class XarrayProvider(BaseProvider):
|
||||
cj['parameters'][key] = parameter
|
||||
|
||||
data = data.fillna(None)
|
||||
data = _convert_float32_to_float64(data)
|
||||
|
||||
try:
|
||||
for key, value in selected_fields.items():
|
||||
@@ -330,13 +361,18 @@ class XarrayProvider(BaseProvider):
|
||||
'type': 'NdArray',
|
||||
'dataType': value['type'],
|
||||
'axisNames': [
|
||||
'y', 'x', self._coverage_properties['time_axis_label']
|
||||
'y', 'x'
|
||||
],
|
||||
'shape': [metadata['height'],
|
||||
metadata['width'],
|
||||
metadata['time_steps']]
|
||||
metadata['width']]
|
||||
}
|
||||
cj['ranges'][key]['values'] = data[key].values.flatten().tolist() # noqa
|
||||
|
||||
if self.time_field is not None:
|
||||
cj['ranges'][key]['axisNames'].append(
|
||||
self._coverage_properties['time_axis_label']
|
||||
)
|
||||
cj['ranges'][key]['shape'].append(metadata['time_steps'])
|
||||
except IndexError as err:
|
||||
LOGGER.warning(err)
|
||||
raise ProviderQueryError('Invalid query parameter')
|
||||
@@ -382,31 +418,37 @@ class XarrayProvider(BaseProvider):
|
||||
self._data.coords[self.x_field].values[-1],
|
||||
self._data.coords[self.y_field].values[-1],
|
||||
],
|
||||
'time_range': [
|
||||
_to_datetime_string(
|
||||
self._data.coords[self.time_field].values[0]
|
||||
),
|
||||
_to_datetime_string(
|
||||
self._data.coords[self.time_field].values[-1]
|
||||
)
|
||||
],
|
||||
'bbox_crs': 'http://www.opengis.net/def/crs/OGC/1.3/CRS84',
|
||||
'crs_type': 'GeographicCRS',
|
||||
'x_axis_label': self.x_field,
|
||||
'y_axis_label': self.y_field,
|
||||
'time_axis_label': self.time_field,
|
||||
'width': self._data.sizes[self.x_field],
|
||||
'height': self._data.sizes[self.y_field],
|
||||
'time': self._data.sizes[self.time_field],
|
||||
'time_duration': self.get_time_coverage_duration(),
|
||||
'bbox_units': 'degrees',
|
||||
'resx': np.abs(self._data.coords[self.x_field].values[1]
|
||||
- self._data.coords[self.x_field].values[0]),
|
||||
'resy': np.abs(self._data.coords[self.y_field].values[1]
|
||||
- self._data.coords[self.y_field].values[0]),
|
||||
'restime': self.get_time_resolution()
|
||||
'resx': np.abs(
|
||||
self._data.coords[self.x_field].values[1]
|
||||
- self._data.coords[self.x_field].values[0]
|
||||
),
|
||||
'resy': np.abs(
|
||||
self._data.coords[self.y_field].values[1]
|
||||
- self._data.coords[self.y_field].values[0]
|
||||
),
|
||||
}
|
||||
|
||||
if self.time_field is not None:
|
||||
properties['time_axis_label'] = self.time_field
|
||||
properties['time_range'] = [
|
||||
_to_datetime_string(
|
||||
self._data.coords[self.time_field].values[0]
|
||||
),
|
||||
_to_datetime_string(
|
||||
self._data.coords[self.time_field].values[-1]
|
||||
),
|
||||
]
|
||||
properties['time'] = self._data.sizes[self.time_field]
|
||||
properties['time_duration'] = self.get_time_coverage_duration()
|
||||
properties['restime'] = self.get_time_resolution()
|
||||
|
||||
# Update properties based on the xarray's CRS
|
||||
epsg_code = self.storage_crs.to_epsg()
|
||||
LOGGER.debug(f'{epsg_code}')
|
||||
@@ -425,10 +467,12 @@ class XarrayProvider(BaseProvider):
|
||||
|
||||
properties['axes'] = [
|
||||
properties['x_axis_label'],
|
||||
properties['y_axis_label'],
|
||||
properties['time_axis_label']
|
||||
properties['y_axis_label']
|
||||
]
|
||||
|
||||
if self.time_field is not None:
|
||||
properties['axes'].append(properties['time_axis_label'])
|
||||
|
||||
return properties
|
||||
|
||||
@staticmethod
|
||||
@@ -455,7 +499,8 @@ class XarrayProvider(BaseProvider):
|
||||
:returns: time resolution string
|
||||
"""
|
||||
|
||||
if self._data[self.time_field].size > 1:
|
||||
if self.time_field is not None \
|
||||
and self._data[self.time_field].size > 1:
|
||||
time_diff = (self._data[self.time_field][1] -
|
||||
self._data[self.time_field][0])
|
||||
|
||||
@@ -472,6 +517,9 @@ class XarrayProvider(BaseProvider):
|
||||
:returns: time coverage duration string
|
||||
"""
|
||||
|
||||
if self.time_field is None:
|
||||
return None
|
||||
|
||||
dur = self._data[self.time_field][-1] - self._data[self.time_field][0]
|
||||
ms_difference = dur.values.astype('timedelta64[ms]').astype(np.double)
|
||||
|
||||
@@ -634,7 +682,7 @@ def _convert_float32_to_float64(data):
|
||||
for var_name in data.variables:
|
||||
if data[var_name].dtype == 'float32':
|
||||
og_attrs = data[var_name].attrs
|
||||
data[var_name] = data[var_name].astype('float64')
|
||||
data[var_name] = data[var_name].astype('float64', copy=False)
|
||||
data[var_name].attrs = og_attrs
|
||||
|
||||
return data
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
from numpy import float64, int64
|
||||
|
||||
import pytest
|
||||
import xarray as xr
|
||||
|
||||
from pygeoapi.provider.xarray_ import XarrayProvider
|
||||
from pygeoapi.util import json_serial
|
||||
@@ -53,6 +54,20 @@ def config():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture()
def config_no_time(tmp_path):
    """Provider configuration pointing at a zarr store without a time axis.

    Opens the module-level ``path`` dataset, keeps only its first time
    step, drops the ``time`` variable and writes the result to
    ``no_time.zarr`` inside pytest's per-test ``tmp_path`` directory.

    :param tmp_path: temporary directory supplied by pytest

    :returns: `dict` of provider configuration for the written store
    """

    store = tmp_path / 'no_time.zarr'

    source = xr.open_zarr(path)
    # Select the first time step, then remove the remaining 'time'
    # variable so the written store carries no time information.
    first_step = source.sel(time=source.time[0])
    first_step.drop_vars('time').to_zarr(store)

    return {
        'name': 'zarr',
        'type': 'coverage',
        'data': str(store),
        'format': {'name': 'zarr', 'mimetype': 'application/zip'},
    }
|
||||
|
||||
|
||||
def test_provider(config):
|
||||
p = XarrayProvider(config)
|
||||
|
||||
@@ -85,3 +100,14 @@ def test_numpy_json_serial():
|
||||
|
||||
d = float64(500.00000005)
|
||||
assert json_serial(d) == 500.00000005
|
||||
|
||||
|
||||
def test_no_time(config_no_time):
    """Check the provider against a dataset that has no time dimension."""

    provider = XarrayProvider(config_no_time)

    assert len(provider.fields) == 4
    # Only the two spatial axes are expected; no time axis is appended.
    assert provider.axes == ['lon', 'lat']

    # NOTE(review): the query code visible elsewhere in this diff reads a
    # variable named `format_` -- confirm `format=` is the intended keyword.
    domain_axes = provider.query(format='json')['domain']['axes']

    assert sorted(domain_axes.keys()) == ['x', 'y']
|
||||
|
||||
Reference in New Issue
Block a user