From 1b6043eed3a84072896ba7063ae782abd389b949 Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Tue, 7 Jan 2025 15:59:46 -0700 Subject: [PATCH 01/10] feat(reader): use vtk implementation if available --- pan3d/ui/preview.py | 4 + pan3d/viewers/preview.py | 10 +- pan3d/xarray/vtk.py | 536 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 549 insertions(+), 1 deletion(-) create mode 100644 pan3d/xarray/vtk.py diff --git a/pan3d/ui/preview.py b/pan3d/ui/preview.py index b87b7db..37a33fe 100644 --- a/pan3d/ui/preview.py +++ b/pan3d/ui/preview.py @@ -776,6 +776,10 @@ def on_change(self, slice_t, **_): self.source.slices = slices ds = self.source() + + if ds.GetClassName() == "vtkDataObject": + return # no mesh produced yet + self.state.dataset_bounds = ds.bounds self.ctrl.view_reset_clipping_range() diff --git a/pan3d/viewers/preview.py b/pan3d/viewers/preview.py index aa12e6c..d71aed6 100644 --- a/pan3d/viewers/preview.py +++ b/pan3d/viewers/preview.py @@ -129,7 +129,12 @@ def _setup_vtk(self): self.interactor.GetInteractorStyle().SetCurrentStyleToTrackballCamera() self.lut = vtkLookupTable() - self.source = vtkXArrayRectilinearSource(input=self.xarray) + try: + from pan3d.xarray.vtk import vtkXArraySource + + self.source = vtkXArraySource(input=self.xarray) + except ImportError: + self.source = vtkXArrayRectilinearSource(input=self.xarray) # Need explicit geometry extraction when used with WASM self.geometry = vtkDataSetSurfaceFilter( @@ -267,6 +272,9 @@ def _on_color_by(self, color_by, **__): return ds = self.source() + print("=" * 60) + print(ds) + print("=" * 60) if color_by in ds.point_data.keys(): array = ds.point_data[color_by] min_value, max_value = array.GetRange() diff --git a/pan3d/xarray/vtk.py b/pan3d/xarray/vtk.py new file mode 100644 index 0000000..9a92563 --- /dev/null +++ b/pan3d/xarray/vtk.py @@ -0,0 +1,536 @@ +from typing import List, Optional + +import xarray as xr +import numpy as np +import pandas as pd +import traceback + +from 
vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase +from vtkmodules.vtkCommonCore import vtkVariant +from vtkmodules.vtkCommonDataModel import vtkImageData +from vtkmodules.vtkFiltersCore import vtkArrayCalculator +from vtkmodules.util import numpy_support, xarray_support + +VTK_DATASETS = {"vtkImageData": vtkImageData} + +# ----------------------------------------------------------------------------- +# Helper functions +# ----------------------------------------------------------------------------- + + +def is_time_array(xarray, name, values): + if values.dtype.type == np.datetime64 or values.dtype.type == np.timedelta64: + un = np.datetime_data(values.dtype) + # unit = ns and 1 base unit in a spep + if un[0] == "ns" and un[1] == 1 and name in xarray.coords.keys(): + return True + return False + + +# ----------------------------------------------------------------------------- + + +def attr_value_to_vtk(value): + if np.issubdtype(type(value), np.integer): + return vtkVariant(int(value)) + + if np.issubdtype(type(value), np.floating): + return vtkVariant(float(value)) + + if isinstance(value, np.ndarray): + return vtkVariant(numpy_support.numpy_to_vtk(value)) + + return vtkVariant(value) + + +# ----------------------------------------------------------------------------- + + +def get_time_labels(times): + return [pd.to_datetime(time).strftime("%Y-%m-%d %H:%M:%S") for time in times] + + +# ----------------------------------------------------------------------------- +# VTK Algorithms +# ----------------------------------------------------------------------------- + + +class vtkXArraySource(VTKPythonAlgorithmBase): + """vtk source for converting XArray into a VTK mesh""" + + def __init__( + self, + input: Optional[xr.Dataset] = None, + arrays: Optional[List[str]] = None, + ): + """ + Create vtkXArraySource + + Parameters: + input (xr.Dataset): Provide an XArray to use as input. The load() method will replace it. 
+ arrays (list[str]): List of field to load onto the generated VTK mesh. + """ + VTKPythonAlgorithmBase.__init__( + self, + nInputPorts=0, + nOutputPorts=1, + ) + # Data source + self._input = input + self._pipeline = None + self._computed = {} + self._data_origin = None + + # Array name selectors + self._x = None + self._y = None + self._z = None + self._t = None + + # Data sub-selection + self._array_names = set(arrays or []) + self._t_index = 0 + self._slices = None + + # vtk internal vars + self._arrays = {} + self._accessor = None + self._reader = None + + # Create reader if xarray available + if self._input: + self._build_reader() + + # ------------------------------------------------------------------------- + # Reader Setup + # ------------------------------------------------------------------------- + + def _build_reader(self): + # reset + self._array_names.clear() + self._arrays = {} # to prevent garbage collection + self._x = None + self._y = None + self._z = None + self._t = None + + # vtk reader + self._accessor = xarray_support.vtkXArrayAccessor() + self._reader = xarray_support.vtkNetCDFCFReader(accessor=self._accessor) + + # XArray binding + xarray = self._input + accessor = self._accessor + reader = self._reader + time_names = [] + + # map dimensions + dim_keys = list(xarray.sizes.keys()) + dim_values = [xarray.sizes[k] for k in dim_keys] + dim_name_to_index = {k: i for i, k in enumerate(dim_keys)} + + accessor.SetDim(dim_keys) + accessor.SetDimLen(dim_values) + + # map variables + var_keys = [*xarray.data_vars.keys(), *xarray.coords.keys()] + var_is_coord = [0] * len(xarray.data_vars) + [1] * len(xarray.coords) + var_name_to_index = {k: i for i, k in enumerate(var_keys)} + + accessor.SetVar(var_keys, var_is_coord) + for i, k in enumerate(var_keys): + # need contiguous array (noop if true, otherwise copy) + v_data = np.ascontiguousarray(xarray[k].values) + + # Capture time array names + if is_time_array(xarray, k, v_data): + time_names.append(k) + 
+ # Convert time data + if v_data.dtype.char == "O": + # object array, assume cftime (copy as double array) + v_data = xarray_support.ndarray_cftime_toordinal(v_data).astype( + np.float64 + ) + time_names.append(k) + + # save array + self._arrays[k] = v_data + + # register data to accessor + accessor.SetVarValue(i, v_data) + accessor.SetVarType(i, xarray_support.get_nc_type(v_data.dtype)) + accessor.SetVarDims(i, [dim_name_to_index[name] for name in xarray[k].dims]) + accessor.SetVarCoords( + i, [var_name_to_index[name] for name in xarray[k].coords] + ) + + # handle attributes + for attr_k, attr_v in xarray[k].attrs.items(): + accessor.SetAtt(i, attr_k, attr_value_to_vtk(attr_v)) + + # map time array name to reader + if len(time_names) >= 1: + for name in time_names: + if accessor.IsCOARDSCoordinate(name): + reader.SetTimeDimensionName(name) + self._t = name + break + + # Extract coordinate mapping + reader.UpdateInformation() + vtk_str_array = reader.GetAllDimensions() + coords = [] + coords_names = ["_x", "_y", "_z"] + if vtk_str_array.GetNumberOfValues() == 1 and ", " in vtk_str_array.GetValue(0): + print("vtk reader.GetAllDimensions() is not properly working...") + coords = vtk_str_array.GetValue(0)[1:-1].split(", ") + else: + for i in range(vtk_str_array.GetNumberOfValues()): + coords.append(vtk_str_array.GetValue(i)) + while len(coords): + coord_name = coords.pop() # (Z, Y, X) + attr_name = coords_names.pop(0) # (X, Y, Z) + setattr(self, attr_name, coord_name) + + # ------------------------------------------------------------------------- + # Information + # ------------------------------------------------------------------------- + + def __str__(self): + return """VTK NetCDF/XArray reader""" + + # ------------------------------------------------------------------------- + # Data input + # ------------------------------------------------------------------------- + + @property + def input(self): + """return current input XArray""" + return self._input + + 
@input.setter + def input(self, xarray_dataset: xr.Dataset): + """update input with a new XArray""" + self._input = xarray_dataset + self._build_reader() + self.Modified() + + # ------------------------------------------------------------------------- + # Array selectors + # ------------------------------------------------------------------------- + + @property + def x(self): + """return the name that is currently mapped to the X axis""" + return self._x + + @property + def x_size(self): + """return the size of the coordinate used for the X axis""" + if self._x is None: + return 0 + return int(self._input[self._x].size) + + @property + def y(self): + """return the name that is currently mapped to the Y axis""" + return self._y + + @property + def y_size(self): + """return the size of the coordinate used for the Y axis""" + if self._y is None: + return 0 + return int(self._input[self._y].size) + + @property + def z(self): + """return the name that is currently mapped to the Z axis""" + return self._z + + @property + def z_size(self): + """return the size of the coordinate used for the Z axis""" + if self._z is None: + return 0 + return int(self._input[self._z].size) + + @property + def t(self): + """return the name that is currently mapped to the time axis""" + return self._t + + @property + def slice_extents(self): + """return a dictionary for the X, Y, Z dimensions with the corresponding extent [0, size-1]""" + return { + coord_name: [0, self.input[coord_name].size - 1] + for coord_name in [self.x, self.y, self.z] + if coord_name is not None + } + + @property + def available_coords(self): + """List available coordinates arrays that have are 1D""" + if self._input is None: + return [] + + return [k for k, v in self._input.coords.items() if len(v.shape) == 1] + + # ------------------------------------------------------------------------- + # Data sub-selection + # ------------------------------------------------------------------------- + + @property + def 
t_index(self): + """return the current selected time index""" + return self._t_index + + @t_index.setter + def t_index(self, t_index: int): + """update the current selected time index""" + if t_index != self._t_index: + self._t_index = t_index + self.Modified() + + @property + def t_size(self): + """return the size of the coordinate used for the time""" + if self._t is None: + return 0 + return int(self._input[self._t].size) + + @property + def t_labels(self): + """return a list of string that match the various time values available""" + if self._t is None: + return [] + + t_array = self._input[self._t] + t_type = t_array.dtype + if np.issubdtype(t_type, np.datetime64): + return get_time_labels(t_array.values) + return [str(t) for t in t_array.values] + + @property + def arrays(self): + """return the list of arrays that are currently selected to be added to the generated VTK mesh""" + return list(self._array_names) + + @arrays.setter + def arrays(self, array_names: List[str]): + """update the list of arrays to load on the generated VTK mesh""" + new_names = set(array_names or []) + if new_names != self._array_names: + self._array_names = new_names + self.Modified() + + @property + def available_arrays(self): + """List all available data fields for the `arrays` option""" + if self._input is None or self._reader is None: + return [] + + vtk_str_array = self._reader.GetAllVariableArrayNames() + return [ + vtk_str_array.GetValue(i) for i in range(vtk_str_array.GetNumberOfValues()) + ] + + @property + def slices(self): + """return the current slicing information which include axes crop/cut and time selection""" + result = dict(self._slices or {}) + if self.t is not None: + result[self.t] = self.t_index + return result + + @slices.setter + def slices(self, v): + """update the slicing of the data along axes""" + if v != self._slices: + self._slices = v + # FIXME !!! 
update accessor + # Ask Dan + # self.Modified() + # raise NotImplementedError() + print("set slices not implemented", v) + if "time" in v: + self.t_index = v.get("time", 0) + + # ------------------------------------------------------------------------- + # add-on logic + # ------------------------------------------------------------------------- + + @property + def computed(self): + """return the current description of the computed/derived fields on the VTK mesh""" + return self._computed + + @computed.setter + def computed(self, v): + """ + update the computed/derived fields to add on the VTK mesh + + The layout of the dictionary provided should be as follow: + - key: name of the field to be added + - value: formula to apply for the given field name. The syntax is captured in the document (https://docs.paraview.org/en/latest/UsersGuide/filteringData.html#calculator) + + Then additional keys need to be provided to describe your formula dependencies: + `_use_scalars` and `_use_vectors` which should be a list of string matching the name of the fields you are using in your expression. 
+ + + Please find below an example: + + ``` + { + "_use_scalars": ["u", "v"], # (u,v) needed for "vec" and "m2" + "vec": "(u * iHat) + (v * jHat)", # 2D vector + "m2": "u*u + v*v", + } + ``` + """ + if self._computed != v: + self._computed = v or {} + self._pipeline = None + scalar_arrays = self._computed.get("_use_scalars", []) + vector_arrays = self._computed.get("_use_vectors", []) + + for output_name, func in self._computed.items(): + if output_name[0] == "_": + continue + filter = vtkArrayCalculator( + result_array_name=output_name, + function=func, + ) + + # register array dependencies + for scalar_array in scalar_arrays: + filter.AddScalarArrayName(scalar_array) + for vector_array in vector_arrays: + filter.AddVectorArrayName(vector_array) + + if self._pipeline is None: + self._pipeline = filter + else: + self._pipeline = self._pipeline >> filter + + self.Modified() + + def load(self, data_info): + """ + create a new XArray input with the `data_origin` and `dataset_config` information. + + Here is an example of the layout of the parameter + + ``` + { + "data_origin": { + "source": "url", # one of [file, url, xarray, pangeo, esgf] + "id": "https://ncsa.osn.xsede.org/Pangeo/pangeo-forge/noaa-coastwatch-geopolar-sst-feedstock/noaa-coastwatch-geopolar-sst.zarr", + "order": "C" # (optional) order to use in numpy + }, + "dataset_config": { + "x": "lon", # (optional) coord name for X + "y": "lat", # (optional) coord name for Y + "z": null, # (optional) coord name for Z + "t": "time", # (optional) coord name for time + "slices": { # (optional) array slicing + "lon": [ + 1000, + 6000, + 20 + ], + "lat": [ + 500, + 3000, + 20 + ], + "time": 5 + }, + "t_index": 5, # (optional) selected time index + "arrays": [ # (optional) names of arrays to load onto VTK mesh. + "analysed_sst" # If missing no array will be loaded + ] # onto the mesh. 
+ } + } + ``` + """ + if "data_origin" not in data_info: + raise ValueError("Only state with data_origin can be loaded") + + from pan3d import catalogs + + self._data_origin = data_info["data_origin"] + self.input = catalogs.load_dataset( + self._data_origin["source"], self._data_origin["id"] + ) + self._build_reader() + + dataset_config = data_info.get("dataset_config") + if dataset_config is None: + self.arrays = self.available_arrays + else: + # self.slices = dataset_config.get("slices") # FIXME: not implemented yet + self.t_index = dataset_config.get("t_index", 0) + self.arrays = dataset_config.get("arrays", self.available_arrays) + + @property + def state(self): + """return current state that can be reused in a load() later on""" + if self._data_origin is None: + raise RuntimeError( + "No state available without data origin. Need to use the load method to set the data origin." + ) + + return { + "data_origin": self._data_origin, + "dataset_config": { + k: getattr(self, k) + for k in ["x", "y", "z", "t", "slices", "t_index", "arrays"] + }, + } + + # ------------------------------------------------------------------------- + # Algorithm + # ------------------------------------------------------------------------- + + def RequestData(self, request, inInfo, outInfo): + """implementation of the vtk algorithm for generating the VTK mesh""" + # Use open data_array handle to fetch data at + # desired Level of Detail + if self._reader is None: + return 0 + + try: + output = None + + if self.t_size: + t = self._arrays[self.t][self.t_index] + self._reader.UpdateTimeStep(t) + # print("Update VTK time", t) + + # update arrays + for name in self.available_arrays: + active = 1 if name in self._array_names else 0 + self._reader.SetVariableArrayStatus(name, active) + + # generate the mesh + mesh = self._reader() + + # Compute derived quantity + if self._pipeline is not None: + output = self._pipeline(mesh) + else: + output = mesh + + # set it as output + print("output => ", 
output.GetClassName()) + filter_output = VTK_DATASETS[output.GetClassName()].GetData(outInfo) + print(f"{filter_output=}") # <= it is None... + filter_output.ShallowCopy(output) + + except Exception as e: + traceback.print_exc() + raise e + return 1 From 7e0dd7705bdc80d8e1b91002b8dcd7f9eec231b0 Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Wed, 8 Jan 2025 13:58:27 -0700 Subject: [PATCH 02/10] feat(xarray): enable more tutorial dataset --- pan3d/catalogs/xarray.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/pan3d/catalogs/xarray.py b/pan3d/catalogs/xarray.py index ece97c5..e28a0ff 100644 --- a/pan3d/catalogs/xarray.py +++ b/pan3d/catalogs/xarray.py @@ -14,25 +14,18 @@ "description": "Dataset with ocean basins marked using integers", }, # ------------------------------------------------------------------------- - # { - # "name": "ASE_ice_velocity", - # "description": "MEaSUREs InSAR-Based Ice Velocity of the Amundsen Sea Embayment, Antarctica, Version 1", - # }, - # ------------------------------------------------------------------------- - # { - # "name": "rasm", - # "description": "Output of the Regional Arctic System Model (RASM)", - # }, - # ------------------------------------------------------------------------- - # { - # "name": "ROMS_example", - # "description": "Regional Ocean Model System (ROMS) output", - # }, - # ------------------------------------------------------------------------- - # { - # "name": "tiny", - # "description": "small synthetic dataset with a 1D data variable", - # }, + { + "name": "ASE_ice_velocity", + "description": "MEaSUREs InSAR-Based Ice Velocity of the Amundsen Sea Embayment, Antarctica, Version 1", + }, + { + "name": "rasm", + "description": "Output of the Regional Arctic System Model (RASM)", + }, + { + "name": "ROMS_example", + "description": "Regional Ocean Model System (ROMS) output", + }, # 
------------------------------------------------------------------------- # needs pandas[xarray] # { From fc06a322d849e44297f943b0b14126bd773d69aa Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Wed, 8 Jan 2025 14:00:43 -0700 Subject: [PATCH 03/10] feat(vtkreader): add initial integration if available --- pan3d/ui/preview.py | 3 -- pan3d/viewers/preview.py | 14 +++++-- pan3d/xarray/vtk.py | 79 ++++++++++++++++++++-------------------- 3 files changed, 50 insertions(+), 46 deletions(-) diff --git a/pan3d/ui/preview.py b/pan3d/ui/preview.py index 37a33fe..17944c3 100644 --- a/pan3d/ui/preview.py +++ b/pan3d/ui/preview.py @@ -777,9 +777,6 @@ def on_change(self, slice_t, **_): self.source.slices = slices ds = self.source() - if ds.GetClassName() == "vtkDataObject": - return # no mesh produced yet - self.state.dataset_bounds = ds.bounds self.ctrl.view_reset_clipping_range() diff --git a/pan3d/viewers/preview.py b/pan3d/viewers/preview.py index d71aed6..050e722 100644 --- a/pan3d/viewers/preview.py +++ b/pan3d/viewers/preview.py @@ -272,9 +272,6 @@ def _on_color_by(self, color_by, **__): return ds = self.source() - print("=" * 60) - print(ds) - print("=" * 60) if color_by in ds.point_data.keys(): array = ds.point_data[color_by] min_value, max_value = array.GetRange() @@ -286,6 +283,17 @@ def _on_color_by(self, color_by, **__): self.mapper.SetScalarModeToUsePointFieldData() self.mapper.InterpolateScalarsBeforeMappingOn() self.mapper.SetScalarVisibility(1) + elif color_by in ds.cell_data.keys(): + array = ds.cell_data[color_by] + min_value, max_value = array.GetRange() + + self.state.color_min = min_value + self.state.color_max = max_value + + self.mapper.SelectColorArray(color_by) + self.mapper.SetScalarModeToUseCellFieldData() + self.mapper.InterpolateScalarsBeforeMappingOn() + self.mapper.SetScalarVisibility(1) else: self.mapper.SetScalarVisibility(0) self.state.color_min = 0 diff --git a/pan3d/xarray/vtk.py b/pan3d/xarray/vtk.py index 9a92563..37f1855 100644 --- 
a/pan3d/xarray/vtk.py +++ b/pan3d/xarray/vtk.py @@ -3,11 +3,11 @@ import xarray as xr import numpy as np import pandas as pd -import traceback from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase from vtkmodules.vtkCommonCore import vtkVariant -from vtkmodules.vtkCommonDataModel import vtkImageData +from vtkmodules.vtkCommonExecutionModel import vtkStreamingDemandDrivenPipeline +from vtkmodules.vtkCommonDataModel import vtkImageData, vtkDataObject from vtkmodules.vtkFiltersCore import vtkArrayCalculator from vtkmodules.util import numpy_support, xarray_support @@ -74,6 +74,7 @@ def __init__( self, nInputPorts=0, nOutputPorts=1, + outputType="vtkDataObject", ) # Data source self._input = input @@ -174,24 +175,14 @@ def _build_reader(self): for name in time_names: if accessor.IsCOARDSCoordinate(name): reader.SetTimeDimensionName(name) - self._t = name break # Extract coordinate mapping reader.UpdateInformation() - vtk_str_array = reader.GetAllDimensions() - coords = [] - coords_names = ["_x", "_y", "_z"] - if vtk_str_array.GetNumberOfValues() == 1 and ", " in vtk_str_array.GetValue(0): - print("vtk reader.GetAllDimensions() is not properly working...") - coords = vtk_str_array.GetValue(0)[1:-1].split(", ") - else: - for i in range(vtk_str_array.GetNumberOfValues()): - coords.append(vtk_str_array.GetValue(i)) - while len(coords): - coord_name = coords.pop() # (Z, Y, X) - attr_name = coords_names.pop(0) # (X, Y, Z) - setattr(self, attr_name, coord_name) + self._x = reader.GetLongitudeDimensionName() + self._y = reader.GetLatitudeDimensionName() + self._z = reader.GetVerticalDimensionName() + self._t = reader.GetTimeDimensionName() # ------------------------------------------------------------------------- # Information @@ -324,6 +315,11 @@ def arrays(self, array_names: List[str]): new_names = set(array_names or []) if new_names != self._array_names: self._array_names = new_names + + for name in self.available_arrays: + active = 1 if name in self._array_names 
else 0 + self._reader.SetVariableArrayStatus(name, active) + self.Modified() @property @@ -495,42 +491,45 @@ def state(self): # Algorithm # ------------------------------------------------------------------------- + def RequestDataObject(self, request, inInfo, outInfo): + output = vtkImageData() + if self._reader: + self._reader.UpdateDataObject() + output = self._reader.GetOutputDataObject(0).NewInstance() + + outInfo.GetInformationObject(0).Set(vtkDataObject.DATA_OBJECT(), output) + return 1 + + def RequestInformation(self, request, inInfo, outInfo): + if self._reader: + self._reader.UpdateInformation() + info = self._reader.GetOutputInformation(0) + whole_extent = info.Get(vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT()) + outInfo.GetInformationObject(0).Set( + vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), *whole_extent + ) + + return 1 + def RequestData(self, request, inInfo, outInfo): """implementation of the vtk algorithm for generating the VTK mesh""" # Use open data_array handle to fetch data at # desired Level of Detail - if self._reader is None: - return 0 - - try: - output = None + if self._reader is not None: + pdo = self.GetOutputData(outInfo, 0) if self.t_size: t = self._arrays[self.t][self.t_index] self._reader.UpdateTimeStep(t) - # print("Update VTK time", t) - - # update arrays - for name in self.available_arrays: - active = 1 if name in self._array_names else 0 - self._reader.SetVariableArrayStatus(name, active) - # generate the mesh mesh = self._reader() # Compute derived quantity if self._pipeline is not None: - output = self._pipeline(mesh) + mesh = self._pipeline(mesh) + pdo.ShallowCopy(mesh) else: - output = mesh + pdo.ShallowCopy(mesh) - # set it as output - print("output => ", output.GetClassName()) - filter_output = VTK_DATASETS[output.GetClassName()].GetData(outInfo) - print(f"{filter_output=}") # <= it is None... 
- filter_output.ShallowCopy(output) - - except Exception as e: - traceback.print_exc() - raise e - return 1 + return 1 + return 0 From 04eac6ebe3821724850f54384f9562d94bc24a20 Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Fri, 10 Jan 2025 08:56:07 -0700 Subject: [PATCH 04/10] fix(vtkreader): need to set PAN3D_USE_VTK_XARRAY to enable reader --- pan3d/viewers/preview.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pan3d/viewers/preview.py b/pan3d/viewers/preview.py index 050e722..38ed553 100644 --- a/pan3d/viewers/preview.py +++ b/pan3d/viewers/preview.py @@ -1,3 +1,4 @@ +import os from vtkmodules.vtkInteractionWidgets import vtkOrientationMarkerWidget from vtkmodules.vtkRenderingAnnotation import vtkAxesActor from vtkmodules.vtkCommonCore import vtkLookupTable @@ -129,11 +130,14 @@ def _setup_vtk(self): self.interactor.GetInteractorStyle().SetCurrentStyleToTrackballCamera() self.lut = vtkLookupTable() - try: - from pan3d.xarray.vtk import vtkXArraySource + if "PAN3D_USE_VTK_XARRAY" in os.environ: + try: + from pan3d.xarray.vtk import vtkXArraySource - self.source = vtkXArraySource(input=self.xarray) - except ImportError: + self.source = vtkXArraySource(input=self.xarray) + except ImportError: + self.source = vtkXArrayRectilinearSource(input=self.xarray) + else: self.source = vtkXArrayRectilinearSource(input=self.xarray) # Need explicit geometry extraction when used with WASM From 7159cb36f5e99bc1b38355d3eabc6b7d25abc86d Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Fri, 17 Jan 2025 17:24:41 -0700 Subject: [PATCH 05/10] wip: cf reader --- .codespellrc | 1 + .gitignore | 1 + pan3d/xarray/cf/__init__.py | 0 pan3d/xarray/cf/__main__.py | 130 +++++ pan3d/xarray/cf/constants.py | 362 ++++++++++++++ pan3d/xarray/cf/coords/__init__.py | 0 pan3d/xarray/cf/coords/convert.py | 18 + .../cf/coords/coords_mesh_rectilinear.py | 43 ++ .../xarray/cf/coords/coords_mesh_spherical.py | 64 +++ 
.../cf/coords/coords_mesh_unstructured.py | 10 + pan3d/xarray/cf/coords/index_mapping.py | 95 ++++ pan3d/xarray/cf/coords/meta.py | 444 ++++++++++++++++ pan3d/xarray/cf/coords/parametric_vertical.py | 317 ++++++++++++ pan3d/xarray/cf/mesh/__init__.py | 4 + pan3d/xarray/cf/mesh/rectilinear.py | 27 + pan3d/xarray/cf/mesh/structured.py | 166 ++++++ pan3d/xarray/cf/mesh/uniform.py | 30 ++ pan3d/xarray/cf/mesh/unstructured.py | 4 + pan3d/xarray/cf/usecases/__init__.py | 0 pan3d/xarray/cf/usecases/rom.py | 85 ++++ pan3d/xarray/cf/utils.py | 472 ++++++++++++++++++ 21 files changed, 2273 insertions(+) create mode 100644 pan3d/xarray/cf/__init__.py create mode 100644 pan3d/xarray/cf/__main__.py create mode 100644 pan3d/xarray/cf/constants.py create mode 100644 pan3d/xarray/cf/coords/__init__.py create mode 100644 pan3d/xarray/cf/coords/convert.py create mode 100644 pan3d/xarray/cf/coords/coords_mesh_rectilinear.py create mode 100644 pan3d/xarray/cf/coords/coords_mesh_spherical.py create mode 100644 pan3d/xarray/cf/coords/coords_mesh_unstructured.py create mode 100644 pan3d/xarray/cf/coords/index_mapping.py create mode 100644 pan3d/xarray/cf/coords/meta.py create mode 100644 pan3d/xarray/cf/coords/parametric_vertical.py create mode 100644 pan3d/xarray/cf/mesh/__init__.py create mode 100644 pan3d/xarray/cf/mesh/rectilinear.py create mode 100644 pan3d/xarray/cf/mesh/structured.py create mode 100644 pan3d/xarray/cf/mesh/uniform.py create mode 100644 pan3d/xarray/cf/mesh/unstructured.py create mode 100644 pan3d/xarray/cf/usecases/__init__.py create mode 100644 pan3d/xarray/cf/usecases/rom.py create mode 100644 pan3d/xarray/cf/utils.py diff --git a/.codespellrc b/.codespellrc index 5c5ba3b..84f98a9 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,2 +1,3 @@ [codespell] skip = CHANGELOG.md,tests/data/* +ignore-words-list = degreee \ No newline at end of file diff --git a/.gitignore b/.gitignore index d3213a4..62ac0f0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ 
node_modules .venv* *.nc *.data-origin.json +*.vtk examples/jupyter/*.gif site/* diff --git a/pan3d/xarray/cf/__init__.py b/pan3d/xarray/cf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pan3d/xarray/cf/__main__.py b/pan3d/xarray/cf/__main__.py new file mode 100644 index 0000000..b6008d6 --- /dev/null +++ b/pan3d/xarray/cf/__main__.py @@ -0,0 +1,130 @@ +import sys +from xarray.tutorial import open_dataset +from .utils import XArrayDataSetCFHelper +from pathlib import Path +import xarray as xr +from vtkmodules.vtkIOLegacy import vtkDataSetWriter + +from pan3d.catalogs.xarray import ALL_ENTRIES +from .coords.meta import MetaArrayMapping + + +def main(): + output_name = "test.vtk" + if Path(sys.argv[-1]).exists(): + input_file = Path(sys.argv[-1]).resolve() + output_name = f"{input_file.name}.vtk" + ds = xr.open_dataset(str(input_file)) + else: + output_name = f"{sys.argv[-1]}.vtk" + ds = open_dataset(sys.argv[-1]) + + helper = XArrayDataSetCFHelper(ds) + print("Arrays:", helper.available_arrays) + + helper.array_selection = [helper.available_arrays[0]] + + print("=" * 60) + print("Coordinates:") + print("=" * 60) + for name in ds.coords: + print(helper.get_info(name)) + print("=" * 60) + print("Data:") + print("=" * 60) + for name in ds.data_vars: + print(helper.get_info(name)) + print("=" * 60) + + if helper.x_info: + print(f"{'-'*60}\nX\n{'-'*60}", helper.x_info) + if helper.y_info: + print(f"{'-'*60}\nY\n{'-'*60}", helper.y_info) + if helper.z_info: + print(f"{'-'*60}\nZ\n{'-'*60}", helper.z_info) + if helper.t_info: + print(f"{'-'*60}\nT\n{'-'*60}", helper.t_info) + + mesh = helper.mesh + print(mesh) + + writer = vtkDataSetWriter() + writer.SetInputData(mesh) + writer.SetFileName(output_name) + writer.Write() + + +# ----------------------------------------------------------------------------- +def main2(): + DATASETS = [item.get("name") for item in ALL_ENTRIES] + FILES = ["/Users/sebastien.jourdain/Downloads/sampleGenGrid3.nc"] + for ds_name 
in DATASETS: + ds = open_dataset(ds_name) + meta = MetaArrayMapping(ds) + + print("-" * 60) + print(ds_name) + print("-" * 60) + print(meta) + + for f_path in FILES: + input_file = Path(f_path).resolve() + ds = xr.open_dataset(str(input_file)) + meta = MetaArrayMapping(ds) + + print("-" * 60) + print(input_file.name) + print("-" * 60) + print(meta) + + +# ----------------------------------------------------------------------------- +def save_dataset(ds_name, field): + ds = open_dataset(ds_name) + meta = MetaArrayMapping(ds) + + print(meta) + + print("-" * 60) + print("spherical=ON") + salt_spherical = meta.get_mesh(fields=[field], spherical=True) + + print("-" * 60) + print("spherical=OFF") + salt_euclidian = meta.get_mesh(fields=[field], spherical=False) + print("-" * 60) + + writer = vtkDataSetWriter() + + if salt_spherical: + writer.SetInputData(salt_spherical) + writer.SetFileName(f"{ds_name}-{field}-spherical.vtk") + writer.Write() + + if salt_euclidian: + writer.SetInputData(salt_euclidian) + writer.SetFileName(f"{ds_name}-{field}-euclidian.vtk") + writer.Write() + + +def main3(): + write = [ + # ("ROMS_example","salt"), # ok + # ("ROMS_example","zeta"), # ok + # ("air_temperature", "air"), # ok + # ("air_temperature_gradient", "Tair"), # ok + # ("basin_mask", "basin"), # ok + # ("rasm", "Tair"), # not following convention + ("eraint_uvz", "v"), + # ("ersstv5", "sst"), # bounds + ] + + for ds, field in write: + print("#" * 80) + print(ds) + print("#" * 80) + save_dataset(ds, field) + + +if __name__ == "__main__": + main3() diff --git a/pan3d/xarray/cf/constants.py b/pan3d/xarray/cf/constants.py new file mode 100644 index 0000000..cf6468b --- /dev/null +++ b/pan3d/xarray/cf/constants.py @@ -0,0 +1,362 @@ +from enum import Enum +import re +import numpy as np +from vtkmodules.vtkCommonDataModel import ( + vtkImageData, + vtkRectilinearGrid, + vtkStructuredGrid, + vtkUnstructuredGrid, +) + +LATITUDE_REGEXP = "degrees?_?n" # degrees_north +LONGITUDE_REGEXP = 
"degrees?_?e" # degrees_east + + +class DimensionInformation: + def __init__(self, xr_dataset, name): + coord = xr_dataset[name] + + self._xr_dataset = xr_dataset + self.name = name + self.attrs = coord.attrs + self.origin = 0 + self.spacing = 1 + self.has_regular_spacing = True # will be updated later + self.unit = Units.UNDEFINED_UNITS + self.dims = coord.dims + self.has_bounds = False + self.dim_size = -1 + + # unit handling + self.unit = Units.extract(coord) + + # is as direct mapping + if coord.dims == (name,): + # coordinates handling (origin, spacing, has_regular_spacing) + self.dim_size = 1 + self.coords = coord.values + self.origin = float(coord.values[0]) + self.spacing = (float(coord.values[-1]) - self.origin) / (coord.size - 1) + tolerance = 0.01 * self.spacing + + for i in range(coord.size): + expected = self.origin + i * self.spacing + truth = float(coord.values[i]) + if not np.isclose(expected, truth, atol=tolerance): + self.has_regular_spacing = False + break + + # direction + if coord.attrs.get("positive", "") == "down": + self.spacing *= -1 + + # bounds + bounds = coord.attrs.get("bounds") + if bounds: + # TODO ? 
only min/max + self.bounds = None + self.has_bounds = bounds + elif np.issubdtype(coord.dtype, np.number): + self.coords = coord.values + self.bounds = np.zeros(coord.size + 1, dtype=np.double) + self.bounds[0] = float(coord.values[0]) - 0.5 * self.spacing + for i in range(1, coord.size): + center = 0.5 * float(coord.values[i - 1] + coord.values[i]) + self.bounds[i] = center + self.bounds[-1] = float(coord.values[-1]) + 0.5 * self.spacing + else: # Fake coordinates + xr_dataset + # TODO should do a better job + self.origin = 0 + self.spacing = 1 + self.coords = np.linspace( + 0, coord.size - 1, num=coord.size, dtype=np.double + ) + self.bounds = np.linspace( + -0.5, coord.size - 0.5, num=coord.size + 1, dtype=np.double + ) + + def __repr__(self): + return f""" +{self.name}: + - origin : {self.origin} + - spacing : {self.spacing} + - uniform : {self.has_regular_spacing} + - unit : {self.unit} + - dimensions : {self.dims} + - bounds : {self.has_bounds} + - attributes : {self.attrs} + """ + + +class Scale(Enum): + da = (1e1, {"deca", "deka"}) + h = (1e2, {"hecto"}) + k = (1e3, {"kilo"}) + M = (1e6, {"mega"}) + G = (1e9, {"giga"}) + T = (1e12, {"tera"}) + P = (1e15, {"peta"}) + E = (1e18, {"exa"}) + Z = (1e21, {"zetta"}) + Y = (1e24, {"yotta"}) + d = (1e-1, {"deci"}) + c = (1e-2, {"centi"}) + m = (1e-3, {"milli"}) + u = (1e-6, {"micro"}) + n = (1e-9, {"nano"}) + p = (1e-12, {"pico"}) + f = (1e-15, {"femto"}) + a = (1e-18, {"atto"}) + z = (1e-21, {"zepto"}) + y = (1e-24, {"yocto"}) + IDENTITY = (1, {}) + + def __float__(self): + return self.value[0] + + @classmethod + def from_name(cls, name): + for scale in cls: + if name in scale.value[1]: + return scale + return cls.IDENTITY + + +class Units(Enum): + UNDEFINED_UNITS = "undefined" + TIME_UNITS = "time" + LATITUDE_UNITS = "latitude" + LONGITUDE_UNITS = "longitude" + VERTICAL_UNITS = "vertical" + NUMBER_OF_UNITS = "number" + + @classmethod + def extract(cls, xr_array): + # 1. 
check "units" + units = xr_array.attrs.get("units", "").lower() + if units: + # time + if units in KNOWN_TIME_UNITS: + return cls.TIME_UNITS + for time_str in KNOWN_TIME_CONTAINS: + if time_str in units: + return cls.TIME_UNITS + + # latitude + if re.search(LATITUDE_REGEXP, units) is not None: + return cls.LATITUDE_UNITS + + # longitude + if re.search(LONGITUDE_REGEXP, units) is not None: + return cls.LONGITUDE_UNITS + + # 2. check "axis" + axis = xr_array.attrs.get("axis", "").lower() + if axis: + if axis == "x": + return cls.LONGITUDE_UNITS + if axis == "y": + return cls.LATITUDE_UNITS + if axis == "z": + return cls.VERTICAL_UNITS + if axis == "t": + return cls.TIME_UNITS + + # 3. use field name + field_name = xr_array.name.lower() + if "time" in field_name: + return cls.TIME_UNITS + if "lat" in field_name: + return cls.LATITUDE_UNITS + if "lon" in field_name: + return cls.LONGITUDE_UNITS + + # 4. data type + if np.issubdtype(xr_array.dtype, np.datetime64): + print("time because of type datetime64") + return cls.TIME_UNITS + if np.issubdtype(xr_array.dtype, np.dtype("O")): + print("time because of type O") + return cls.TIME_UNITS + + # Don't know + return cls.UNDEFINED_UNITS + + +MESH_3D_UNITS = { + Units.LATITUDE_UNITS, + Units.LONGITUDE_UNITS, + Units.VERTICAL_UNITS, +} +MESH_2D_UNITS = { + Units.LATITUDE_UNITS, + Units.LONGITUDE_UNITS, +} + + +class MeshTypes(Enum): + VTK_IMAGE_DATA = vtkImageData + VTK_RECTILINEAR_GRID = vtkRectilinearGrid + VTK_STRUCTURED_GRID = vtkStructuredGrid + VTK_UNSTRUCTURED_GRID = vtkUnstructuredGrid + + @classmethod + def from_coord_type(cls, coord_type): + if coord_type == CoordinateTypes.UNIFORM_RECTILINEAR: + return cls.VTK_IMAGE_DATA + if coord_type == CoordinateTypes.NONUNIFORM_RECTILINEAR: + return cls.VTK_RECTILINEAR_GRID + if coord_type in { + CoordinateTypes.REGULAR_SPHERICAL, + CoordinateTypes.EUCLIDEAN_2D, + CoordinateTypes.SPHERICAL_2D, + CoordinateTypes.EUCLIDEAN_4SIDED_CELLS, + CoordinateTypes.SPHERICAL_4SIDED_CELLS, + 
}: + return cls.VTK_STRUCTURED_GRID + if coord_type in { + CoordinateTypes.EUCLIDEAN_PSIDED_CELLS, + CoordinateTypes.SPHERICAL_PSIDED_CELLS, + }: + return cls.VTK_STRUCTURED_GRID + + msg = f"Don't have a matching mesh for {coord_type}" + raise ValueError(msg) + + def new(self): + return self.value() + + +class CoordinateTypes(Enum): + UNIFORM_RECTILINEAR = "uniform rectilinear" + NONUNIFORM_RECTILINEAR = "non-uniform rectilinear" + REGULAR_SPHERICAL = "regular spherical" + EUCLIDEAN_2D = "2d euclidean" + SPHERICAL_2D = "2d spherical" + EUCLIDEAN_4SIDED_CELLS = "euclidean 4 sided cells" + SPHERICAL_4SIDED_CELLS = "spherical 4 sided cells" + EUCLIDEAN_PSIDED_CELLS = "euclidean p-sided cells" + SPHERICAL_PSIDED_CELLS = "spherical p-sided cells" + + @classmethod + def get_coordinate_type(cls, xr_dataset, field_name, use_spherical=False): + print(f"{use_spherical=}") + # Remove time axis + dims = [ + array_name + for array_name in xr_dataset[field_name].dims + if Units.extract(xr_dataset[array_name]) != Units.TIME_UNITS + ] + coords = xr_dataset[field_name].coords + cells_unstructured = len(coords) != len(dims) and len(dims) == 1 + has_bounds_count = 0 + + # Check bounds + for coord_name in coords: + if Units.extract(xr_dataset[coord_name]) in MESH_2D_UNITS: + if xr_dataset[coord_name].attrs.get("bounds"): + has_bounds_count += 1 + + if cells_unstructured: + return ( + cls.SPHERICAL_PSIDED_CELLS + if use_spherical + else cls.EUCLIDEAN_PSIDED_CELLS + ) + + if has_bounds_count == 2: + return cls.SPHERICAL_2D if use_spherical else cls.EUCLIDEAN_2D + + name_to_unit = { + dim_name: Units.extract(xr_dataset[dim_name]) for dim_name in dims + } + dim_units = set(name_to_unit.values()) + if use_spherical: + if dim_units >= MESH_2D_UNITS or dim_units >= MESH_3D_UNITS: + return cls.REGULAR_SPHERICAL + + # Check irregular spacing + for name, unit in name_to_unit.items(): + if unit in MESH_3D_UNITS: + info = DimensionInformation(xr_dataset, name) + if not info.has_regular_spacing: + 
return cls.NONUNIFORM_RECTILINEAR + + return cls.UNIFORM_RECTILINEAR + + @property + def use_point_data(self): + # point_data for only the following types + return self in { + CoordinateTypes.UNIFORM_RECTILINEAR, + CoordinateTypes.NONUNIFORM_RECTILINEAR, + CoordinateTypes.EUCLIDEAN_2D, + CoordinateTypes.SPHERICAL_2D, + } + + +KNOWN_TIME_CONTAINS = { + " since ", + " after ", +} +KNOWN_TIME_UNITS = { + "second", + "seconds", + "day", + "days", + "hour", + "hours", + "minute", + "minutes", + "s", + "sec", + "secs", + "shake", + "shakes", + "sidereal_day", + "sidereal_days", + "sidereal_hour", + "sidereal_hours", + "sidereal_minute", + "sidereal_minutes", + "sidereal_second", + "sidereal_seconds", + "sidereal_year", + "sidereal_years", + "tropical_year", + "tropical_years", + "lunar_month", + "lunar_months", + "common_year", + "common_years", + "leap_year", + "leap_years", + "Julian_year", + "Julian_years", + "Gregorian_year", + "Gregorian_years", + "sidereal_month", + "sidereal_months", + "tropical_month", + "tropical_months", + "d", + "min", + "mins", + "hrs", + "h", + "fortnight", + "fortnights", + "week", + "jiffy", + "jiffies", + "year", + "years", + "yr", + "yrs", + "a", + "eon", + "eons", + "month", + "months", +} diff --git a/pan3d/xarray/cf/coords/__init__.py b/pan3d/xarray/cf/coords/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pan3d/xarray/cf/coords/convert.py b/pan3d/xarray/cf/coords/convert.py new file mode 100644 index 0000000..ed082e6 --- /dev/null +++ b/pan3d/xarray/cf/coords/convert.py @@ -0,0 +1,18 @@ +import math + + +def point_insert(vtk_point, spherical, longitude, latitude, vertical): + if spherical: + longitude = math.pi * longitude / 180 + latitude = math.pi * latitude / 180 + vtk_point.InsertNextPoint( + vertical * math.cos(longitude) * math.cos(latitude), + vertical * math.sin(longitude) * math.cos(latitude), + vertical * math.sin(latitude), + ) + else: + vtk_point.InsertNextPoint( + longitude, + latitude, + vertical, + 
def add_1d_points(vtk_point, coords):
    """Fill *vtk_point* with spherical-projected points from 1-D coordinate bounds.

    Longitude/latitude bounds (degrees) are projected on a sphere whose radius
    is the (scaled, biased) vertical bound for each k level. Points are
    inserted in k-outer / j / i-inner order matching ``coords.extent``.
    """
    extent = coords.extent
    vtk_point.SetDataTypeToDouble()
    vtk_point.Allocate(
        (extent[1] - extent[0] + 1)
        * (extent[3] - extent[2] + 1)
        * (extent[5] - extent[4] + 1)
    )

    # Adjust the height scale and bias so all radii stay positive.
    z_scale = coords.vertical_scale
    z_bias = coords.vertical_bias
    if coords.vertical:
        z_min = float(coords.vertical_array.min())
        z_max = float(coords.vertical_array.max())
        if z_min * z_scale + z_bias < 0 or z_max * z_scale + z_bias < 0:
            # builtin min(): "math.min" does not exist and raised
            # AttributeError whenever this branch was taken.
            z_bias = -min(z_min, z_max) * z_scale
    elif (z_scale + z_bias) <= 0:
        z_scale = 1
        z_bias = 0

    # Fill points
    longitude = coords.longitude_bounds
    latitude = coords.latitude_bounds
    vertical = coords.vertical_bounds
    for k in range(extent[4], extent[5] + 1):
        h = vertical[k] if vertical else 1
        h = h * z_scale + z_bias
        for j in range(extent[2], extent[3] + 1):
            lat = vtkMath.RadiansFromDegrees(latitude[j])
            for i in range(extent[0], extent[1] + 1):
                lon = vtkMath.RadiansFromDegrees(longitude[i])
                vtk_point.InsertNextPoint(
                    h * math.cos(lon) * math.cos(lat),
                    h * math.sin(lon) * math.cos(lat),
                    h * math.sin(lat),
                )
class IndexMapper:
    """Dispatch (i, j, k) lookups onto a coordinate array's storage order.

    At construction time the mesh dimension names (*in_dims*, slowest to
    fastest varying) are mapped to the letters i/j/k, and the accessor method
    (provided by subclasses) whose name matches the target array's dimension
    order -- e.g. ``"kji"`` -- is selected once.
    """

    def __init__(self, xr_dataset, in_dims, out_name):
        # Raw values of the array being indexed into.
        self.out_array = xr_dataset[out_name].values

        # Last mesh dimension -> "i" (fastest varying), and so on backwards.
        name_to_ijk = {in_dims[-(n + 1)]: "ijk"[n] for n in range(len(in_dims))}
        out_dims = xr_dataset[out_name].dims
        map_method_name = "".join(name_to_ijk[name] for name in out_dims)
        # Resolve the accessor once instead of per lookup
        # (a debug print of the selection was removed here).
        self.fn = getattr(self, map_method_name)

    def __call__(self, **kwargs):
        """Forward ``i=/j=/k=`` keyword indices to the selected accessor."""
        return self.fn(**kwargs)
class IndexMapper1(IndexMapper):
    """Accessors for coordinate arrays stored along a single dimension.

    The method name states which logical axis (i, j or k) indexes the
    backing array; the other indices are accepted and ignored so every
    accessor shares the same calling convention.
    """

    def i(self, i=0, j=0, k=0, **_):
        return self.out_array[i]

    def j(self, i=0, j=0, k=0, **_):
        return self.out_array[j]

    def k(self, i=0, j=0, k=0, **_):
        return self.out_array[k]
def is_uniform(array):
    """Return True when *array* values are (approximately) evenly spaced.

    Spacing is inferred from the first and last values; each entry is then
    compared against the implied linear ramp with an absolute tolerance of
    1% of the spacing. Arrays with fewer than two entries are trivially
    uniform (and would otherwise divide by zero).
    """
    if array.size < 2:
        return True

    origin = float(array[0])
    spacing = (float(array[-1]) - origin) / (array.size - 1)
    # abs(): a descending array has negative spacing, and a negative atol
    # made np.isclose reject even exact matches in the original code.
    tolerance = 0.01 * abs(spacing)

    for i in range(array.size):
        expected = origin + i * spacing
        if not np.isclose(float(array[i]), expected, atol=tolerance):
            return False

    return True
name + if std_name in COORDINATES_DETECTION["longitude"]["std_name"]: + return cls.LONGITUDE + if std_name in COORDINATES_DETECTION["latitude"]["std_name"]: + return cls.LATITUDE + if std_name in COORDINATES_DETECTION["time"]["std_name"]: + return cls.TIME + if std_name in COORDINATES_DETECTION["vertical"]["std_name"]: + return cls.VERTICAL + + # based on data type + if np.issubdtype(xr_array.dtype, np.datetime64): + return cls.TIME + if np.issubdtype(xr_array.dtype, np.dtype("O")): + return cls.TIME + + return cls.UNKNOWN + + @classmethod + def can_be_vertical(cls, xr_array): + units = xr_array.attrs.get("units", "").lower() + + if xr_array.dims != (xr_array.name,): + return False + + # vertical tends to be harder to detect + for unit in PRESSURE_UNITS: + if unit in units: + return True + + for unit in LENGTH_UNITS: + if unit in units: + return True + + return False + + @classmethod + def can_be_time(cls, xr_array): + if xr_array.dims != (xr_array.name,): + return False + + return xr_array.name in COORDINATES_DETECTION["time"]["units"] + + +class MetaArrayMapping: + def __init__(self, xr_dataset): + self.xr_dataset = xr_dataset + self.conventions = ( + xr_dataset.Conventions if hasattr(xr_dataset, "Conventions") else None + ) + self.data_arrays = {} + self.longitude = None + self.latitude = None + self.vertical = None + self.time = None + self.valid = False + self.vertical_bias = 6378137 + self.vertical_scale = 100 + + if self.conventions is not None: + for convention in {"COARDS", "CF-1"}: + if convention in self.conventions: + self.valid = True + break + + # start extracting coordinates from data dimensions + for array_name in xr_dataset: + # skip bounds array as data array + if "bnd" in array_name or "bound" in array_name: + continue + + dims = xr_dataset[array_name].dims + if dims not in self.data_arrays: + self.data_arrays[dims] = [array_name] + for coord_name in dims: + coord_type = CoordinateType.from_array(xr_dataset[coord_name]) + if coord_type != 
CoordinateType.UNKNOWN: + setattr(self, coord_type.value, coord_name) + + # Extended binding if not found + if self.vertical is None and CoordinateType.can_be_vertical( + xr_dataset[array_name] + ): + self.vertical = array_name + if self.time is None and CoordinateType.can_be_time( + xr_dataset[coord_name] + ): + self.time = coord_name + else: + self.data_arrays[dims].append(array_name) + + # inspect coordinates if not already filled + if any( + coord is None for coord in [self.longitude, self.latitude, self.vertical] + ): + for coord_name in xr_dataset.coords: + coord_type = CoordinateType.from_array(xr_dataset[coord_name]) + if ( + coord_type != CoordinateType.UNKNOWN + and getattr(self, coord_type.value) is None + ): + setattr(self, coord_type.value, coord_name) + + # Extended binding if not found + if self.vertical is None and CoordinateType.can_be_vertical( + xr_dataset[coord_name] + ): + self.vertical = coord_name + + def __repr__(self): + data_lines = [] + for dims, array_names in self.data_arrays.items(): + data_lines.append(f" - {dims}:") + for name in array_names: + data_lines.append(f" - {name}") + data_str = "\n".join(data_lines) + return f"""Conventions: {self.conventions} {'✅' if self.valid else '❌'} +Coordinates: + - longitude : {self.longitude} + - latitude : {self.latitude} + - vertical : {self.vertical} + - time : {self.time} +Computed: + - has_bound : {self.coords_has_bounds} + - uniform : {self.uniform_spacing} +Data: +{data_str} +""" + + @property + def coords_has_bounds(self): + if self.latitude is None or self.longitude is None: + return None + + lon_bnd = self.xr_dataset[self.longitude].attrs.get("bounds") + lat_bnd = self.xr_dataset[self.latitude].attrs.get("bounds") + return lon_bnd is not None and lat_bnd is not None + + @property + def uniform_spacing(self): + uniform = ( + self.coords_1d + and is_uniform(self.xr_dataset[self.longitude].values) + and is_uniform(self.xr_dataset[self.latitude].values) + ) + + if self.vertical is not None 
and uniform: + uniform = len(self.xr_dataset[self.vertical].dims) == 1 and is_uniform( + self.xr_dataset[self.vertical].values + ) + + return uniform + + @property + def coords_1d(self): + vertical_ok = True + if self.vertical is not None: + vertical_ok = len( + self.xr_dataset[self.vertical].dims + ) == 1 and self.xr_dataset[self.vertical].dims == (self.vertical,) + + return ( + vertical_ok + and self.longitude is not None + and len(self.xr_dataset[self.longitude].dims) == 1 + and self.xr_dataset[self.longitude].dims == (self.longitude,) + and self.latitude is not None + and len(self.xr_dataset[self.latitude].dims) == 1 + and self.xr_dataset[self.latitude].dims == (self.latitude,) + ) + + def use_coords(self, dims): + if self.longitude not in dims: + return False + if self.latitude not in dims: + return False + if self.vertical is not None and self.vertical not in dims: + return False + return True + + def get_mesh(self, time_index=0, spherical=True, fields=None): + vtk_mesh, data_location = None, None + + # ensure similar dimension across array names + data_dims = self.xr_dataset[fields[0]].dims + data_dims_no_time = data_dims[1:] if data_dims[0] == self.time else data_dims + valid_data_array_names = [ + n for n in fields if self.xr_dataset[n].dims == data_dims + ] + + # No mesh if no lon/lat + if self.longitude is None or self.latitude is None: + return vtk_mesh + + # Unstructured + if len(data_dims_no_time) == 1: + vtk_mesh, data_location = mesh.unstructured.generate_mesh( + self, data_dims_no_time, time_index, spherical + ) + + # Structured + if vtk_mesh is None and ( + self.coords_has_bounds or spherical or not self.coords_1d + ): + vtk_mesh, data_location = mesh.structured.generate_mesh( + self, data_dims_no_time, time_index, spherical + ) + + # This should only happen if we don't want spherical + if vtk_mesh is None: + assert not spherical + + # Rectilinear + if vtk_mesh is None and not self.uniform_spacing: + vtk_mesh, data_location = 
def extract_formula_terms(formula_terms: str):
    """Parse a CF ``formula_terms`` attribute into a ``{term: variable}`` dict.

    The attribute is a whitespace-separated list of ``term:`` / variable
    pairs, e.g. ``"sigma: lev ps: PS ptop: PTOP"``.

    Raises:
        ValueError: when the tokens do not form complete key/value pairs.
    """
    # split() with no argument collapses runs of whitespace, so doubled
    # spaces in the attribute no longer produce empty tokens (the original
    # split(" ") rejected otherwise-valid attributes).
    tokens = formula_terms.split()
    if len(tokens) % 2 != 0:
        msg = f"Invalid key/value pairing: {tokens}"
        raise ValueError(msg)

    # Keys carry a trailing ":" in the attribute syntax; strip it.
    return {
        key.rstrip(":"): value for key, value in zip(tokens[::2], tokens[1::2])
    }
class AtmosphereHybridSigmaPressureCoordinate(AbstractFormula):
    """CF atmosphere hybrid sigma pressure coordinate.

    Computes ``p(n,k,j,i) = a(k)*p0 + b(k)*ps(n,j,i)`` or, when the
    formula terms provide ``ap`` instead of ``a``/``p0``,
    ``p(n,k,j,i) = ap(k) + b(k)*ps(n,j,i)``.
    """

    url = f"{CONVENTION_BASE_URL}#_atmosphere_hybrid_sigma_pressure_coordinate"
    name = "atmosphere_hybrid_sigma_pressure_coordinate"
    name_computed = "air_pressure"
    std_p0 = "reference_air_pressure_for_atmosphere_vertical_coordinate"
    std_ps = "surface_air_pressure"

    def select_formula(self, key_mapping):
        # The original did setattr(self, "__call__", ...), which never makes
        # the instance callable: Python looks special methods up on the
        # *type*, so calling the formula raised TypeError. Store the chosen
        # implementation and dispatch through a real __call__ instead.
        if "p0" in key_mapping:
            self._formula = self._a_p0
        else:
            self._formula = self._ap

    def __call__(self, n=0, k=0, j=0, i=0):
        return self._formula(n, k, j, i)

    def _a_p0(self, n, k, j, i):
        return self.a[k] * self.p0 + self.b[k] * self.ps[n, j, i]

    def _ap(self, n, k, j, i):
        return self.ap[k] + self.b[k] * self.ps[n, j, i]

    def __len__(self):
        # Formula consumes four indices (n, k, j, i).
        return 4
----------------------------------------------------------------------------- +class AtmosphereSmoothLevelVerticalCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_atmosphere_smooth_level_vertical_sleve_coordinate" + name = "atmosphere_sleve_coordinate" + name_computed = ("altitude", "height_above_geopotential_datum") + std_ztop = ( + "altitude_at_top_of_atmosphere_model", + "height_above_geopotential_datum_at_top_of_atmosphere_model", + ) + + def __call__(self, n, k, j, i): + return ( + self.a[k] * self.ztop + + self.b1[k] * self.zsurf1[n, j, i] + + self.b2[k] * self.zsurf2[n, j, i] + ) + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Ocean sigma coordinate +# ----------------------------------------------------------------------------- +class OceanSigmaCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_ocean_sigma_coordinate" + name = "ocean_sigma_coordinate" + + def __call__(self, n, k, j, i): + return self.eta[n, j, i] + self.sigma[k] * ( + self.depth[j, i] + self.eta[n, j, i] + ) + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Ocean s-coordinate +# ----------------------------------------------------------------------------- +class OceanSCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_ocean_s_coordinate" + name = "ocean_s_coordinate" + + def __call__(self, n, k, j, i): + c_k = (1 - self.b) * math.sinh(self.a * self.s[k]) / math.sinh( + self.a + ) + self.b * [ + math.tanh(self.a * (self.s[k] + 0.5)) / (2 * math.tanh(0.5 * self.a)) - 0.5 + ] + return ( + self.eta[n, j, i] * (1 + self.s[k]) + + self.depth_c * self.s[k] + + (self.depth[j, i] - self.depth_c) * c_k + ) + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Ocean s-coordinate, generic form 1 +# 
----------------------------------------------------------------------------- +class OceanSCoordinateGenericForm1(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_ocean_s_coordinate_generic_form_1" + name = "ocean_s_coordinate_g1" + + def __call__(self, n, k, j, i): + s_kji = self.depth_c * self.s[k] + (self.depth[j, i] - self.depth_c) * self.C[k] + return s_kji + self.eta[n, j, i] * (1 + self.S[k, j, i] / self.depth[j, i]) + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Ocean s-coordinate, generic form 2 +# ----------------------------------------------------------------------------- +class OceanSCoordinateGenericForm2(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_ocean_s_coordinate_generic_form_2" + name = "ocean_s_coordinate_g2" + std_names = { + "altitude": { + "zlev": "altitude", + "eta": "sea_surface_height_above_geoid", + "depth": "sea_floor_depth_below_geoid", + }, + "height_above_geopotential_ datum": { + "zlev": "height_above_geopotential_datum", + "eta": "sea_surface_height_above_geopotential_datum", + "depth": "sea_floor_depth_below_geopotential_datum", + }, + "height_above_reference_ ellipsoid": { + "zlev": "height_above_reference_ellipsoid", + "eta": "sea_surface_height_above_reference_ellipsoid", + "depth": "sea_floor_depth_below_reference_ellipsoid", + }, + "height_above_mean_sea_ level": { + "zlev": "height_above_mean_sea_level", + "eta": "sea_surface_height_above_mean_ sea_level", + "depth": "sea_floor_depth_below_mean_ sea_level", + }, + } + + def __call__(self, n, k, j, i): + s_kji = (self.depth_c * self.s[k] + self.depth[j, i] * self.C[k]) / ( + self.depth_c + self.depth[j, i] + ) + return self.eta[n, j, i] + (self.eta[n, j, i] + self.depth[j, i]) * s_kji + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Ocean sigma over z coordinate +# 
----------------------------------------------------------------------------- +class OceanSigmaOverZCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_ocean_sigma_over_z_coordinate" + name = "ocean_sigma_z_coordinate" + + def select_formula(self, key_mapping): + if "sigma" in key_mapping and "zlev" not in key_mapping: + setattr(self, "__call__", self._sigma) + elif "zlev" in key_mapping: + setattr(self, "__call__", self._z_lev) + else: + msg = f"No formula for 'ocean_sigma_z_coordinate' with given formula: {key_mapping}" + raise ValueError(msg) + + def _z_lev(self, n, k, j, i): + return self.zlev[k] + + def _sigma(self, n, k, j, i): + return self.eta[n, j, i] + self.sigma[k] * ( + min(self.depth_c, self.depth[j, i]) + self.eta[n, j, i] + ) + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Ocean double sigma coordinate +# ----------------------------------------------------------------------------- +class OceanDoubleSigmaCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_ocean_double_sigma_coordinate" + name = "ocean_double_sigma_coordinate" + + def __call__(self, k, j, i, **_): + f_ji = 0.5 * (self.z1 + self.z2) + 0.5 * (self.z1 - self.z2) * math.tanh( + 2 * self.a / (self.z1 - self.z2) * (self.depth[j, i] - self.href) + ) + + if k <= self.k_c: + return self.sigma[k] * f_ji + else: + return f_ji + (self.sigma[k] - 1) * (self.depth[j, i] - f_ji) + + def __len__(self): + return 3 diff --git a/pan3d/xarray/cf/mesh/__init__.py b/pan3d/xarray/cf/mesh/__init__.py new file mode 100644 index 0000000..bccb22b --- /dev/null +++ b/pan3d/xarray/cf/mesh/__init__.py @@ -0,0 +1,4 @@ +from . import unstructured # noqa: F401 +from . import structured # noqa: F401 +from . import rectilinear # noqa: F401 +from . 
import uniform # noqa: F401 diff --git a/pan3d/xarray/cf/mesh/rectilinear.py b/pan3d/xarray/cf/mesh/rectilinear.py new file mode 100644 index 0000000..cdf26e5 --- /dev/null +++ b/pan3d/xarray/cf/mesh/rectilinear.py @@ -0,0 +1,27 @@ +import numpy as np +from vtkmodules.vtkCommonDataModel import vtkRectilinearGrid + + +def generate_mesh(metadata, dimensions, time_index): + data_location = "point_data" + extent = [0, 0, 0, 0, 0, 0] + empty_coords = np.zeros((1,), dtype=np.double) + arrays = [empty_coords, empty_coords, empty_coords] + + assert metadata.coords_1d + + for idx in range(len(dimensions)): + array = metadata.xr_dataset[dimensions[-(1 + idx)]] + arrays[idx] = array.values + + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # And extent include both index so (len-1) + extent[idx * 2 + 1] = array.size - 1 + + mesh = vtkRectilinearGrid() + mesh.x_coordinates = arrays[0] + mesh.y_coordinates = arrays[1] + mesh.z_coordinates = arrays[2] + mesh.extent = extent + + return mesh, data_location diff --git a/pan3d/xarray/cf/mesh/structured.py b/pan3d/xarray/cf/mesh/structured.py new file mode 100644 index 0000000..0759cbe --- /dev/null +++ b/pan3d/xarray/cf/mesh/structured.py @@ -0,0 +1,166 @@ +from vtkmodules.vtkCommonCore import vtkPoints +from vtkmodules.vtkCommonDataModel import vtkStructuredGrid + +from ..coords.parametric_vertical import get_formula as get_z_formula +from ..coords.index_mapping import get_formula as get_coords_formula +from ..coords.convert import point_insert + + +def generate_mesh(metadata, dimensions, time_index, spherical): + # Data location and extend depend if we can extrapolate cell locations + # bounds or uniform allow to define cell bounds + if metadata.coords_has_bounds: + print(" => structured: bounds") + return generate_bound_cells(metadata, dimensions, time_index, spherical) + + if metadata.uniform_spacing: + print(" => structured: uniform spacing") + return generate_uniform_cells(metadata, dimensions, 
time_index, spherical) + + # We can only figure out the point location + print(" => structured: on points") + return generate_mesh_points(metadata, dimensions, time_index, spherical) + + +def generate_uniform_cells(metadata, dimensions, time_index, spherical): + data_location = "cell_data" + assert spherical + + # 2D or 3D + dims_size = len(dimensions) + assert dims_size == 2 or dims_size == 3 + + # extract extent, origin, spacing + origin = [0, 0, 0] + spacing = [1, 1, 1] + extent = [0, 0, 0, 0, 0, 0] + n_points = 1 + dims_origin_spacing = [] + for idx in range(len(dimensions)): + array = metadata.xr_dataset[dimensions[-(1 + idx)]] + + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # Use size as end of extent as we are adding 1 point for map data on cell + extent[idx * 2 + 1] = array.size + n_points *= array.size + + # axis origin/spacing + axis_spacing = (array[-1].values - array[0].values) / (array.size - 1) + axis_origin = float(array[0].values) - axis_spacing * 0.5 + + # update global origin/spacing + origin[idx] = axis_origin + spacing[idx] = axis_spacing + + # Add (origin, spacing) for missing coords + while len(dims_origin_spacing) < 3: + dims_origin_spacing.insert(0, (0, 1)) + + # debug + # print(f"{dimensions=}") + # print(f"{extent=}") + # print(f"{n_points=}") + # print(f"{dimensions=}") + # print(f"{extent=}") + # print(f"{dims_origin_spacing=}") + + # Points + vtk_points = vtkPoints() + vtk_points.SetDataTypeToDouble() + vtk_points.Allocate(n_points) + + # Check if direct coord mapping + for k in range(extent[5] + 1): + z = origin[2] + k * spacing[2] + z = metadata.vertical_bias + metadata.vertical_scale * z + for j in range(extent[3] + 1): + lat = origin[1] + j * spacing[1] + for i in range(extent[1] + 1): + lon = origin[0] + i * spacing[0] + point_insert(vtk_points, spherical, lon, lat, z) + + # Mesh + mesh = vtkStructuredGrid() + mesh.points = vtk_points + mesh.extent = extent + + return mesh, data_location + + +def 
generate_bound_cells(metadata, dimensions, time_index, spherical): + data_location = "cell_data" + raise NotImplementedError("structured::generate_bound_cells") + return False, data_location + + +def generate_mesh_points(metadata, dimensions, time_index, spherical): + data_location = "point_data" + + # 2D or 3D + dims_size = len(dimensions) + assert dims_size == 2 or dims_size == 3 + + extent = [0, 0, 0, 0, 0, 0] + n_points = 1 + for idx in range(len(dimensions)): + array = metadata.xr_dataset[dimensions[-(1 + idx)]] + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # And extent include both index so (len-1) + extent[idx * 2 + 1] = array.size - 1 + n_points *= array.size + + # Points + vtk_points = vtkPoints() + vtk_points.SetDataTypeToDouble() + vtk_points.Allocate(n_points) + + # debug + # print(f"{dimensions=}") + # print(f"{extent=}") + # print(f"{n_points=}") + + # Check if direct coord mapping + if metadata.coords_1d and metadata.use_coords(dimensions): + if dims_size == 2: # 2D + x_array = metadata.xr_dataset[metadata.longitude].values + y_array = metadata.xr_dataset[metadata.latitude].values + z = 0 + for j in range(extent[3] + 1): + lat = y_array[j] + for i in range(extent[1] + 1): + lon = x_array[i] + point_insert(vtk_points, spherical, lon, lat, z) + else: # 3D + x_array = metadata.xr_dataset[metadata.longitude].values + y_array = metadata.xr_dataset[metadata.latitude].values + z_array = metadata.xr_dataset[metadata.vertical].values + for k in range(extent[5] + 1): + z = z_array[k] + for j in range(extent[3] + 1): + lat = y_array[j] + for i in range(extent[1] + 1): + lon = x_array[i] + point_insert(vtk_points, spherical, lon, lat, z) + else: + # need some index mapping + z_formula = get_z_formula( + metadata.xr_dataset, + metadata.vertical, + bias=metadata.vertical_bias, + scale=metadata.vertical_scale, + ) + coords_formula = get_coords_formula(metadata, dimensions) + for k in range(extent[5] + 1): + for j in range(extent[3] + 
1): + for i in range(extent[1] + 1): + lon, lat = coords_formula(i=i, j=j, k=k) + z = z_formula(i=i, j=j, k=k, n=time_index) + # print(f"{time_index=}, {k=}, {j=}, {i=} = {lon=}, {lat=}, {z=}") + point_insert(vtk_points, spherical, lon, lat, z) + + # Mesh + mesh = vtkStructuredGrid() + mesh.points = vtk_points + mesh.extent = extent + + return mesh, data_location diff --git a/pan3d/xarray/cf/mesh/uniform.py b/pan3d/xarray/cf/mesh/uniform.py new file mode 100644 index 0000000..0d070ae --- /dev/null +++ b/pan3d/xarray/cf/mesh/uniform.py @@ -0,0 +1,30 @@ +from vtkmodules.vtkCommonDataModel import vtkImageData + + +def generate_mesh(metadata, dimensions, time_index): + data_location = "cell_data" + + # data to capture + origin = [0, 0, 0] + spacing = [1, 1, 1] + extent = [0, 0, 0, 0, 0, 0] + + # extract information from dimensions + for idx in range(len(dimensions)): + array = metadata.xr_dataset[dimensions[-(1 + idx)]] + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # Use size as end of extent as we are adding 1 point for map data on cell + extent[idx * 2 + 1] = array.size + + # axis origin/spacing + axis_spacing = (array[-1].values - array[0].values) / (array.size - 1) + axis_origin = float(array[0].values) - axis_spacing * 0.5 + + # update global origin/spacing + origin[idx] = axis_origin + spacing[idx] = axis_spacing + + # Configure mesh + mesh = vtkImageData(origin=origin, spacing=spacing, extent=extent) + + return mesh, data_location diff --git a/pan3d/xarray/cf/mesh/unstructured.py b/pan3d/xarray/cf/mesh/unstructured.py new file mode 100644 index 0000000..8991645 --- /dev/null +++ b/pan3d/xarray/cf/mesh/unstructured.py @@ -0,0 +1,4 @@ +def generate_mesh(metadata, dimensions, spherical): + print(" => unstructured: cell_data") + raise NotImplementedError("unstructured::generate_mesh") + return False, "cell_data" diff --git a/pan3d/xarray/cf/usecases/__init__.py b/pan3d/xarray/cf/usecases/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/pan3d/xarray/cf/usecases/rom.py b/pan3d/xarray/cf/usecases/rom.py new file mode 100644 index 0000000..4f8d38a --- /dev/null +++ b/pan3d/xarray/cf/usecases/rom.py @@ -0,0 +1,85 @@ +import math +from xarray.tutorial import open_dataset +from vtkmodules.vtkCommonCore import vtkPoints, vtkMath +from vtkmodules.vtkCommonDataModel import vtkStructuredGrid +from vtkmodules.vtkIOLegacy import vtkDataSetWriter + +from ..coords.parametric_vertical import get_formula + + +class Formula_ocean_s_coordinate_g2: + def __init__(self, xr_dataset, s, c, eta, depth, depth_c): + self.xr_dataset = xr_dataset + self.s = self.xr_dataset[s].values + self.c = self.xr_dataset[c].values + self.eta = self.xr_dataset[eta].values + self.depth = self.xr_dataset[depth].values + self.depth_c = self.xr_dataset[depth_c].values + + def __call__(self, n, k, j, i): + return self.eta[n, j, i] + (self.eta[n, j, i] + self.depth[j, i]) * ( + self.depth_c * self.s[k] + self.depth[j, i] * self.c[k] + ) / (self.depth_c + self.depth[j, i]) + + +def main(): + ds = open_dataset("ROMS_example") + + t = 0 + salt = ds.salt # (ocean_time, s_rho, eta_rho, xi_rho) + lon_rho = ds.lon_rho.values # (eta_rho, xi_rho) + lat_rho = ds.lat_rho.values # (eta_rho, xi_rho) + + # 1D coords + eta_rho = ds.eta_rho + xi_rho = ds.xi_rho + s_rho = ( + ds.s_rho + ) # (s_rho) | ocean_s_coordinate_g2 | formula_terms= "s: s_rho C: Cs_r eta: zeta depth: h depth_c: hc" + + formula = get_formula(ds, "s_rho") + + # ocean_s_coordinate_g2 + # z(n,k,j,i) = eta(n,j,i) + (eta(n,j,i) + depth(j,i)) * S(k,j,i) + # S(k,j,i) = (depth_c * s(k) + depth(j,i) * C(k)) / (depth_c + depth(j,i)) + # formula_terms = "s: var1 C: var2 eta: var3 depth: var4 depth_c: var5" + # - s: s_rho + # - C: Cs_r + # - eta: zeta + # - depth: h + # - depth_c: hc + + earth_radius = 6378137 # in meters + bias = earth_radius + scale = 100 + + n_points = s_rho.size * eta_rho.size * xi_rho.size + points = vtkPoints() + points.SetDataTypeToDouble() + 
points.Allocate(n_points) + + for k in range(s_rho.size): + for j in range(eta_rho.size): + for i in range(xi_rho.size): + lon = vtkMath.RadiansFromDegrees(lon_rho[j, i]) + lat = vtkMath.RadiansFromDegrees(lat_rho[j, i]) + h = bias + scale * formula(t, k, j, i) + points.InsertNextPoint( + h * math.cos(lon) * math.cos(lat), + h * math.sin(lon) * math.cos(lat), + h * math.sin(lat), + ) + + mesh = vtkStructuredGrid() + mesh.SetExtent(0, xi_rho.size - 1, 0, eta_rho.size - 1, 0, s_rho.size - 1) + mesh.points = points + mesh.point_data["salt"] = salt[t].values.ravel() + + writer = vtkDataSetWriter() + writer.SetInputData(mesh) + writer.SetFileName("rom_salt.vtk") + writer.Write() + + +if __name__ == "__main__": + main() diff --git a/pan3d/xarray/cf/utils.py b/pan3d/xarray/cf/utils.py new file mode 100644 index 0000000..44ac7fb --- /dev/null +++ b/pan3d/xarray/cf/utils.py @@ -0,0 +1,472 @@ +import numpy as np +from . import constants +from .coords import coords_mesh_rectilinear, coords_mesh_spherical + + +def to_bounds(xr_dataset, coord_name): + xr_array = xr_dataset[coord_name] + if xr_array.dims == (coord_name,): + bound_array_name = xr_array.attrs.get("bounds") + if bound_array_name: + # TODO + # xr_dataset[bound_array_name] + return bound_array_name + elif np.issubdtype(xr_array.dtype, np.number): + origin = float(xr_array.values[0]) + spacing = (float(xr_array.values[-1]) - origin) / (xr_array.size - 1) + + if xr_array.attrs.get("positive", "") == "down": + spacing *= -1 + + bounds = np.zeros(xr_array.size + 1, dtype=np.double) + bounds[0] = float(xr_array.values[0]) - 0.5 * spacing + for i in range(1, xr_array.size): + center = 0.5 * float(xr_array.values[i - 1] + xr_array.values[i]) + bounds[i] = center + bounds[-1] = float(xr_array.values[-1]) + 0.5 * spacing + + return bounds + + # fake coordinates + return np.linspace( + -0.5, xr_array.size - 0.5, num=xr_array.size + 1, dtype=np.double + ) + + +class Coordinates: + def __init__(self, xr_dataset, 
vertical_scale=1, vertical_bias=0): + self.xr_dataset = xr_dataset + self.vertical_scale = vertical_scale + self.vertical_bias = vertical_bias + self.longitude = None + self.latitude = None + self.vertical = None + self.time = None + + for coord_name in xr_dataset.coords: + unit = constants.Units.extract(xr_dataset[coord_name]) + if unit == constants.Units.LONGITUDE_UNITS: + self.longitude = coord_name + elif unit == constants.Units.LATITUDE_UNITS: + self.latitude = coord_name + elif unit == constants.Units.VERTICAL_UNITS: + self.vertical = coord_name + elif unit == constants.Units.TIME_UNITS: + self.time = coord_name + else: + print(f"Skip coord {coord_name}") + + @property + def extent(self): + x_size = self.longitude_array.size if self.longitude else 0 + y_size = self.latitude_array.size if self.latitude else 0 + z_size = self.vertical_array.size if self.vertical else 0 + + return [ + 0, + x_size, + 0, + y_size, + 0, + z_size, + ] + + @property + def origin(self): + print(self.longitude, "=>", self.longitude_array.dims) + print(self.latitude, "=>", self.latitude_array.dims) + x = self.longitude_array.values[0] if self.longitude else 0 + y = self.latitude_array.values[0] if self.latitude else 0 + z = self.vertical_array.values[0] if self.vertical else 0 + return [x, y, z] + + @property + def spacing(self): + origin = self.origin + sx, sy, sz = 1, 1, 1 + if self.longitude: + coord = self.longitude_array + sx = (float(coord.values[-1]) - origin[0]) / (coord.size - 1) + if coord.attrs.get("positive", "") == "down": + sx *= -1 + if self.latitude: + coord = self.latitude_array + sy = (float(coord.values[-1]) - origin[1]) / (coord.size - 1) + if coord.attrs.get("positive", "") == "down": + sy *= -1 + if self.vertical: + coord = self.vertical_array + sz = (float(coord.values[-1]) - origin[2]) / (coord.size - 1) + if coord.attrs.get("positive", "") == "down": + sz *= -1 + + return [sx, sy, sz] + + @property + def longitude_array(self): + if self.longitude is None: + 
return None + + return self.xr_dataset[self.longitude] + + @property + def longitude_bounds(self): + if self.longitude is None: + return None + + return to_bounds(self.xr_dataset, self.longitude) + + @property + def latitude_array(self): + if self.latitude is None: + return None + + return self.xr_dataset[self.latitude] + + @property + def latitude_bounds(self): + if self.latitude is None: + return None + + return to_bounds(self.xr_dataset, self.latitude) + + @property + def vertical_array(self): + if self.vertical is None: + return None + + return self.xr_dataset[self.vertical] + + @property + def vertical_bounds(self): + if self.vertical is None: + return None + + return to_bounds(self.xr_dataset, self.vertical) + + @property + def time_array(self): + if self.time is None: + return None + + return self.xr_dataset[self.time] + + +class XArrayDataSetCFHelper: + def __init__(self, xr_dataset): + self.xr_dataset = xr_dataset + self._active_array_names = {} + self._active_dimensions = None + self._cached_dimensions_info = {} + self._x = None + self._y = None + self._z = None + self._t = None + + @property + def available_arrays(self) -> list[str]: + if self._active_dimensions is None: + return [ + n + for n in self.xr_dataset.data_vars.keys() + if "bnd" not in n and "bound" not in n + ] + + # List only arrays with the same active dimensions + return [ + n + for n in self.xr_dataset.data_vars.keys() + if self.xr_dataset[n].dims == self._active_dimensions + ] + + @property + def array_selection(self) -> set[str]: + """return the list of arrays that are currently selected to be added to the generated VTK mesh""" + return self._active_array_names + + @array_selection.setter + def array_selection(self, array_names=None): + """update the list of arrays to load on the generated VTK mesh""" + if array_names is None: + array_names = [] + + # Filter with only valid arrays + allowed_names = set(self.available_arrays) + new_names = set([n for n in array_names if n in 
allowed_names]) + + # Do we have a change + if new_names != self._active_array_names: + if len(new_names): + # Check compatibility + self._active_dimensions = None + compatible_set = set() + for name in new_names: + if self._active_dimensions is None: + self._active_dimensions = self.xr_dataset[name].dims + compatible_set.add(name) + elif self.xr_dataset[name].dims == self._active_dimensions: + compatible_set.add(name) + + self._active_array_names = compatible_set + else: + # no selection + self._active_array_names = new_names + self._active_dimensions = None + + # Update dimension mapping + self._compute_dimensions_information() + + return True + + return False + + def _compute_dimensions_information(self): + # reset + self._cached_dimensions_info.clear() + self._x = None + self._y = None + self._z = None + self._t = None + self._t_index = 0 + + # look deeper if possible + if self._active_dimensions is None: + return + + # extract field dimension information + for dim_name in self._active_dimensions: + info = constants.DimensionInformation(self.xr_dataset, dim_name) + self._cached_dimensions_info[dim_name] = info + + # track dimension orientation + if info.unit == constants.Units.LONGITUDE_UNITS: + self._x = dim_name + if info.unit == constants.Units.LATITUDE_UNITS: + self._y = dim_name + if info.unit == constants.Units.VERTICAL_UNITS: + self._z = dim_name + if info.unit == constants.Units.TIME_UNITS: + self._t = dim_name + + # extract coord dimension from dataset + for coord_name in self.xr_dataset.coords: + if coord_name not in self._cached_dimensions_info: + info = constants.DimensionInformation(self.xr_dataset, coord_name) + self._cached_dimensions_info[coord_name] = info + + # track dimension orientation + if info.unit == constants.Units.LONGITUDE_UNITS: + self._x = coord_name + if info.unit == constants.Units.LATITUDE_UNITS: + self._y = coord_name + if info.unit == constants.Units.VERTICAL_UNITS: + self._z = coord_name + + @property + def x(self): + 
"""return the name that is currently mapped to the X axis""" + return self._x + + @property + def x_size(self): + """return the size of the coordinate used for the X axis""" + if self._x is None: + return 0 + return int(self.xr_dataset[self._x].size) + + @property + def x_info(self): + """return the X coordinate information if available""" + if self._x is None: + return None + return self._cached_dimensions_info.get(self._x) + + @property + def y(self): + """return the name that is currently mapped to the Y axis""" + return self._y + + @property + def y_size(self): + """return the size of the coordinate used for the Y axis""" + if self._y is None: + return 0 + return int(self.xr_dataset[self._y].size) + + @property + def y_info(self): + """return the Y coordinate information if available""" + if self._y is None: + return None + return self._cached_dimensions_info.get(self._y) + + @property + def z(self): + """return the name that is currently mapped to the Z axis""" + return self._z + + @property + def z_size(self): + """return the size of the coordinate used for the Z axis""" + if self._z is None: + return 0 + return int(self.xr_dataset[self._z].size) + + @property + def z_info(self): + """return the Z coordinate information if available""" + if self._z is None: + return None + return self._cached_dimensions_info.get(self._z) + + @property + def t(self): + """return the name that is currently mapped to the time axis""" + return self._t + + @property + def t_size(self): + """return the size of the coordinate used for the time axis""" + if self._t is None: + return 0 + return int(self.xr_dataset[self._t].size) + + @property + def t_info(self): + """return the T coordinate information if available""" + if self._t is None: + return None + return self._cached_dimensions_info.get(self._t) + + def get_info(self, name): + info = self._cached_dimensions_info.get(name) + if info is None: + info = self._cached_dimensions_info.setdefault( + name, 
constants.DimensionInformation(self.xr_dataset, name) + ) + return info + + @property + def mesh(self): + if not self._cached_dimensions_info: + print("no cache") + return None + + if len(self.array_selection) == 0: + print("no field") + return None + + # Most coordinate variables are defined by a variable the same name as the + # dimension they describe. Those are handled elsewhere. This class handles + # dependent variables that define coordinates that are not the same name as + # any dimension. This is only done when the coordinates cannot be expressed + # as a 1D table lookup from dimension index. This occurs in only two places + # in the CF convention. First, it happens for 2D coordinate variables with + # 4-sided cells. This is basically when the grid is a 2D curvilinear grid. + # Each i,j topological point can be placed anywhere in space. Second, it + # happens for multi-dimensional coordinate variables with p-sided cells. + # These are unstructured collections of polygons. + + # we need at least a 2D surface + if self.x is None and self.y is None: + print("no (x,y)") + return None + + if len(self.x_info.dims) != len(self.y_info.dims): + msg = f"Number of dimensions in different coordinate arrays do not match (x:{self.x_info.dims}, y:{self.y_info.dims})" + raise ValueError(msg) + + field_name = next(iter(self.array_selection)) + coord_mode = constants.CoordinateTypes.get_coordinate_type( + self.xr_dataset, field_name, use_spherical=True + ) + print(f"{coord_mode=}") + mesh_type = constants.MeshTypes.from_coord_type(coord_mode) + mesh = mesh_type.new() + + # ------------------------------------------------- + # Generate mesh structure + # ------------------------------------------------- + if mesh_type == constants.MeshTypes.VTK_IMAGE_DATA: + coords_mesh_rectilinear.add_imagedata(mesh, Coordinates(self.xr_dataset)) + elif mesh_type == constants.MeshTypes.VTK_RECTILINEAR_GRID: + if coord_mode in { + constants.CoordinateTypes.EUCLIDEAN_PSIDED_CELLS, + 
constants.CoordinateTypes.SPHERICAL_PSIDED_CELLS, + }: + # There is no sensible way to store p-sided cells in a structured grid. + # Just fake some coordinates (related to ParaView bug #11543). + coords_mesh_rectilinear.fake_rectilinear(mesh, self.xr_dataset) + else: + coords_mesh_rectilinear.add_rectilinear(mesh, self.xr_dataset) + elif mesh_type in { + constants.MeshTypes.VTK_STRUCTURED_GRID, + constants.MeshTypes.VTK_UNSTRUCTURED_GRID, + }: + if coord_mode in { + constants.CoordinateTypes.UNIFORM_RECTILINEAR, + constants.CoordinateTypes.NONUNIFORM_RECTILINEAR, + }: + if mesh_type == constants.MeshTypes.VTK_STRUCTURED_GRID: + coords_mesh_rectilinear.add_1d_structured(mesh, self.xr_dataset) + else: + coords_mesh_rectilinear.add_1d_unstructured(mesh, self.xr_dataset) + elif coord_mode == constants.CoordinateTypes.REGULAR_SPHERICAL: + if mesh_type == constants.MeshTypes.VTK_STRUCTURED_GRID: + coords_mesh_spherical.add_1d_structured( + mesh, Coordinates(self.xr_dataset) + ) + else: + coords_mesh_spherical.add_1d_unstructured(mesh, self.xr_dataset) + elif coord_mode in { + constants.CoordinateTypes.EUCLIDEAN_2D, + constants.CoordinateTypes.EUCLIDEAN_4SIDED_CELLS, + }: + if mesh_type == constants.MeshTypes.VTK_STRUCTURED_GRID: + coords_mesh_rectilinear.add_2d_structured(mesh, self.xr_dataset) + else: + coords_mesh_rectilinear.add_2d_unstructured(mesh, self.xr_dataset) + elif coord_mode in { + constants.CoordinateTypes.SPHERICAL_2D, + constants.CoordinateTypes.SPHERICAL_4SIDED_CELLS, + }: + + coords_mesh_spherical.add_2d_structured(mesh, self.xr_dataset) + elif coord_mode in { + constants.CoordinateTypes.EUCLIDEAN_PSIDED_CELLS, + constants.CoordinateTypes.SPHERICAL_PSIDED_CELLS, + }: + # There is no sensible way to store p-sided cells in a structured grid. + # Just fake some coordinates (ParaView bug #11543). 
+ coords_mesh_rectilinear.fake_structured(mesh, self.xr_dataset) + else: + msg = f"Unknown coordinate type {coord_mode}" + raise ValueError(msg) + elif mesh_type == constants.MeshTypes.VTK_UNSTRUCTURED_GRID: + if coord_mode in { + constants.CoordinateTypes.UNIFORM_RECTILINEAR, + constants.CoordinateTypes.NONUNIFORM_RECTILINEAR, + }: + coords_mesh_rectilinear.add_1d_unstructured(mesh, self.xr_dataset) + + # ------------------------------------------------- + # Add fields to mesh + # ------------------------------------------------- + for field_name in self.array_selection: + array = self.xr_dataset[field_name] + + # slice array with current time + if self.t: + array = array.isel({self.t: self._t_index}) + + if coord_mode.use_point_data: + mesh.point_data[field_name] = array.values.ravel(order="C") + else: + mesh.cell_data[field_name] = array.values.ravel(order="C") + + # ------------------------------------------------- + # Add time metadata + # ------------------------------------------------- + # TODO + + return mesh From afc12360ddaab0c9eebb1ba2bad05dc3395c137c Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Fri, 17 Jan 2025 17:37:09 -0700 Subject: [PATCH 06/10] wip: try to enable cell-data --- pan3d/xarray/cf/coords/meta.py | 14 ++++++++++---- pan3d/xarray/cf/mesh/structured.py | 3 +++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pan3d/xarray/cf/coords/meta.py b/pan3d/xarray/cf/coords/meta.py index 871ccf7..16f8647 100644 --- a/pan3d/xarray/cf/coords/meta.py +++ b/pan3d/xarray/cf/coords/meta.py @@ -327,8 +327,10 @@ def __repr__(self): - vertical : {self.vertical} - time : {self.time} Computed: - - has_bound : {self.coords_has_bounds} - - uniform : {self.uniform_spacing} + - has_bound : {self.coords_has_bounds} + - uniform (2D) : {self.uniform_lat_lon} + - uniform (all) : {self.uniform_spacing} + - coords 1d : {self.coords_1d} Data: {data_str} """ @@ -343,13 +345,17 @@ def coords_has_bounds(self): return lon_bnd is not None and lat_bnd 
is not None @property - def uniform_spacing(self): - uniform = ( + def uniform_lat_lon(self): + return ( self.coords_1d and is_uniform(self.xr_dataset[self.longitude].values) and is_uniform(self.xr_dataset[self.latitude].values) ) + @property + def uniform_spacing(self): + uniform = self.uniform_lat_lon + if self.vertical is not None and uniform: uniform = len(self.xr_dataset[self.vertical].dims) == 1 and is_uniform( self.xr_dataset[self.vertical].values diff --git a/pan3d/xarray/cf/mesh/structured.py b/pan3d/xarray/cf/mesh/structured.py index 0759cbe..f7c975a 100644 --- a/pan3d/xarray/cf/mesh/structured.py +++ b/pan3d/xarray/cf/mesh/structured.py @@ -96,6 +96,9 @@ def generate_bound_cells(metadata, dimensions, time_index, spherical): def generate_mesh_points(metadata, dimensions, time_index, spherical): data_location = "point_data" + if metadata.coords_1d and metadata.uniform_spacing: + print("Should put data on cell!") + # 2D or 3D dims_size = len(dimensions) assert dims_size == 2 or dims_size == 3 From 85eae17ccc382df84f28c88a4a4d8de4ec0a3591 Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Thu, 23 Jan 2025 13:36:59 -0700 Subject: [PATCH 07/10] wip: integrate new python reader --- pan3d/ui/preview.py | 148 +++-- pan3d/viewers/preview.py | 40 +- pan3d/xarray/cf/coords/convert.py | 44 ++ .../cf/coords/coords_mesh_rectilinear.py | 43 -- .../xarray/cf/coords/coords_mesh_spherical.py | 64 -- .../cf/coords/coords_mesh_unstructured.py | 10 - pan3d/xarray/cf/coords/meta.py | 42 +- pan3d/xarray/cf/mesh/rectilinear.py | 9 +- pan3d/xarray/cf/mesh/structured.py | 67 +- pan3d/xarray/cf/reader.py | 584 ++++++++++++++++++ 10 files changed, 851 insertions(+), 200 deletions(-) delete mode 100644 pan3d/xarray/cf/coords/coords_mesh_rectilinear.py delete mode 100644 pan3d/xarray/cf/coords/coords_mesh_spherical.py delete mode 100644 pan3d/xarray/cf/coords/coords_mesh_unstructured.py create mode 100644 pan3d/xarray/cf/reader.py diff --git a/pan3d/ui/preview.py 
b/pan3d/ui/preview.py index 17944c3..23f0feb 100644 --- a/pan3d/ui/preview.py +++ b/pan3d/ui/preview.py @@ -264,6 +264,7 @@ def __init__(self, source, update_rendering): self.state.setdefault("max_time_width", 0) self.state.setdefault("max_time_index_width", 0) self.state.setdefault("dataset_bounds", [0, 1, 0, 1, 0, 1]) + self.state.setdefault("projection_mode", "spherical") with self.content: v3.VSelect( @@ -576,63 +577,115 @@ def __init__(self, source, update_rendering): type="number", ) - # Actor scaling - with v3.VTooltip(text="Representation scaling"): + # Projection mode + scaling + with v3.VTooltip( + text=( + "projection_mode === 'spherical' ? `Spherical Projection: scaling=${spherical_scale} bias=${spherical_bias}` : 'Euclidian Projection'", + ) + ): with html.Template(v_slot_activator="{ props }"): - with v3.VRow( + with v3.VCol( v_bind="props", no_gutter=True, - classes="align-center my-0 mx-0 border-b-thin", + classes="align-center pa-0 my-0 mx-0 border-b-thin", ): - v3.VIcon( - "mdi-ruler-square", - classes="ml-2 text-medium-emphasis", + v3.VSelect( + prepend_inner_icon=( + "projection_mode === 'spherical' ? 
'mdi-earth' : 'mdi-earth-box'", + ), + v_model=("projection_mode", "spherical"), + items=("['spherical', 'euclidian']",), + hide_details=True, + density="compact", + flat=True, + variant="solo", + classes="mx-n1", ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[0]"): - v3.VTextField( - v_model=("scale_x", 1), + with v3.VCol( + no_gutter=True, + classes="align-center my-0 mx-0 pa-0", + v_if="projection_mode === 'spherical'", + ): + v3.VSlider( + prepend_icon="mdi-radius-outline", + v_model=("spherical_bias", 6378137), + min=1, + max=6378137, + step=100, hide_details=True, density="compact", flat=True, variant="solo", - reverse=True, - raw_attrs=[ - 'pattern="^\d*(\.\d)?$"', # noqa: W605 - 'min="0.001"', - 'step="0.1"', - ], - type="number", + classes="mr-4", + end=self.server.controller.view_update, ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[1]"): - v3.VTextField( - v_model=("scale_y", 1), + v3.VSlider( + prepend_icon="mdi-magnify", + v_model=("spherical_scale", 100), + min=1, + max=1000, + step=20, hide_details=True, density="compact", flat=True, variant="solo", - reverse=True, - raw_attrs=[ - 'pattern="^\d*(\.\d)?$"', # noqa: W605 - 'min="0.001"', - 'step="0.1"', - ], - type="number", + classes="mr-4", + end=self.server.controller.view_update, ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[2]"): - v3.VTextField( - v_model=("scale_z", 1), - hide_details=True, - density="compact", - flat=True, - variant="solo", - reverse=True, - raw_attrs=[ - 'pattern="^\d*(\.\d)?$"', # noqa: W605 - 'min="0.001"', - 'step="0.1"', - ], - type="number", + with v3.VRow( + no_gutter=True, + classes="align-center my-0 mx-0", + v_else=True, + ): + v3.VIcon( + "mdi-ruler-square", + classes="ml-2 text-medium-emphasis", ) + with v3.VCol(classes="pa-0", v_if="axis_names?.[0]"): + v3.VTextField( + v_model=("scale_x", 1), + hide_details=True, + density="compact", + flat=True, + variant="solo", + reverse=True, + raw_attrs=[ + 'pattern="^\d*(\.\d)?$"', # noqa: W605 + 
'min="0.001"', + 'step="0.1"', + ], + type="number", + ) + with v3.VCol(classes="pa-0", v_if="axis_names?.[1]"): + v3.VTextField( + v_model=("scale_y", 1), + hide_details=True, + density="compact", + flat=True, + variant="solo", + reverse=True, + raw_attrs=[ + 'pattern="^\d*(\.\d)?$"', # noqa: W605 + 'min="0.001"', + 'step="0.1"', + ], + type="number", + ) + with v3.VCol(classes="pa-0", v_if="axis_names?.[2]"): + v3.VTextField( + v_model=("scale_z", 1), + hide_details=True, + density="compact", + flat=True, + variant="solo", + reverse=True, + raw_attrs=[ + 'pattern="^\d*(\.\d)?$"', # noqa: W605 + 'min="0.001"', + 'step="0.1"', + ], + type="number", + ) # Time slider with v3.VTooltip( @@ -678,6 +731,11 @@ def update_from_source(self, source=None): self.state.color_by = None self.state.axis_names = [source.x, source.y, source.z] self.state.slice_extents = source.slice_extents + self.state.projection_mode = ( + "spherical" if source.spherical else "euclidean" + ) + self.state.spherical_bias = source.vertical_bias + self.state.spherical_scale = source.vertical_scale slices = source.slices for axis in XYZ: # default @@ -803,6 +861,14 @@ def _on_array_selection(self, data_arrays, **_): self.source.arrays = data_arrays + @change("spherical_bias", "spherical_scale", "projection_mode") + def _on_projection_change( + self, spherical_bias, spherical_scale, projection_mode, **_ + ): + self.source.spherical = projection_mode == "spherical" + self.source.vertical_bias = spherical_bias + self.source.vertical_scale = spherical_scale + class ControlPanel(v3.VCard): def __init__( diff --git a/pan3d/viewers/preview.py b/pan3d/viewers/preview.py index 38ed553..83f8dde 100644 --- a/pan3d/viewers/preview.py +++ b/pan3d/viewers/preview.py @@ -1,4 +1,3 @@ -import os from vtkmodules.vtkInteractionWidgets import vtkOrientationMarkerWidget from vtkmodules.vtkRenderingAnnotation import vtkAxesActor from vtkmodules.vtkCommonCore import vtkLookupTable @@ -26,7 +25,9 @@ from trame.ui.vuetify3 
import VAppLayout from trame.widgets import vuetify3 as v3 -from pan3d.xarray.algorithm import vtkXArrayRectilinearSource +from pan3d.xarray.cf.reader import vtkXArrayCFSource + +# from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from pan3d.utils.constants import has_gpu from pan3d.utils.convert import update_camera, to_image, to_float @@ -130,15 +131,17 @@ def _setup_vtk(self): self.interactor.GetInteractorStyle().SetCurrentStyleToTrackballCamera() self.lut = vtkLookupTable() - if "PAN3D_USE_VTK_XARRAY" in os.environ: - try: - from pan3d.xarray.vtk import vtkXArraySource - self.source = vtkXArraySource(input=self.xarray) - except ImportError: - self.source = vtkXArrayRectilinearSource(input=self.xarray) - else: - self.source = vtkXArrayRectilinearSource(input=self.xarray) + self.source = vtkXArrayCFSource(input=self.xarray) + # if "PAN3D_USE_VTK_XARRAY" in os.environ: + # try: + # from pan3d.xarray.vtk import vtkXArraySource + + # self.source = vtkXArraySource(input=self.xarray) + # except ImportError: + # self.source = vtkXArrayRectilinearSource(input=self.xarray) + # else: + # self.source = vtkXArrayRectilinearSource(input=self.xarray) # Need explicit geometry extraction when used with WASM self.geometry = vtkDataSetSurfaceFilter( @@ -319,13 +322,16 @@ def _on_color_preset( self.ctrl.view_update() - @change("scale_x", "scale_y", "scale_z") - def _on_scale_change(self, scale_x, scale_y, scale_z, **_): - self.actor.SetScale( - to_float(scale_x), - to_float(scale_y), - to_float(scale_z), - ) + @change("scale_x", "scale_y", "scale_z", "projection_mode") + def _on_scale_change(self, scale_x, scale_y, scale_z, projection_mode, **_): + if projection_mode == "spherical": + self.actor.SetScale(1, 1, 1) + else: + self.actor.SetScale( + to_float(scale_x), + to_float(scale_y), + to_float(scale_z), + ) if self.state.import_pending: return diff --git a/pan3d/xarray/cf/coords/convert.py b/pan3d/xarray/cf/coords/convert.py index ed082e6..377ae4d 100644 --- 
a/pan3d/xarray/cf/coords/convert.py +++ b/pan3d/xarray/cf/coords/convert.py @@ -1,4 +1,48 @@ import math +import numpy as np + + +def extract_uniform_info(array): + origin = float(array[0]) + spacing = (float(array[-1]) - origin) / (array.size - 1) + tolerance = 0.01 * spacing + + for i in range(array.size): + expected = origin + i * spacing + truth = float(array[i]) + if not np.isclose(expected, truth, atol=tolerance): + return None + + return (origin, spacing, array.size) + + +def is_uniform(array): + return extract_uniform_info(array) is not None + + +def cell_center_to_point(in_array): + uniform_data = extract_uniform_info(in_array) + if uniform_data is not None: + origin, spacing, size = uniform_data + return np.linspace( + start=origin - spacing * 0.5, + stop=origin - spacing * 0.5 + size * spacing, + num=size + 1, + endpoint=True, + dtype=np.double, + ) + + # generate fake coords + n_cells = in_array.size + n_points = n_cells + 1 + out_array = np.zeros(n_points, dtype=np.double) + for i in range(1, n_cells): + out_array[i] = 0.5 * (in_array[i - 1] + in_array[i]) + + out_array[0] = out_array[1] - 2 * (out_array[1] - in_array[0]) + out_array[-1] = out_array[-2] + 2 * (in_array[-1] - out_array[-2]) + + return out_array def point_insert(vtk_point, spherical, longitude, latitude, vertical): diff --git a/pan3d/xarray/cf/coords/coords_mesh_rectilinear.py b/pan3d/xarray/cf/coords/coords_mesh_rectilinear.py deleted file mode 100644 index 5d91425..0000000 --- a/pan3d/xarray/cf/coords/coords_mesh_rectilinear.py +++ /dev/null @@ -1,43 +0,0 @@ -def add_imagedata(image_data, coords): - print(f"{coords.extent=}") - print(f"{coords.origin=}") - print(f"{coords.spacing=}") - image_data.SetExtent(*coords.extent) - image_data.SetOrigin(*coords.origin) - image_data.SetSpacing(*coords.spacing) - - -def add_rectilinear(rectilinear_grid): - raise NotImplementedError() - - -def fake_rectilinear(rectilinear_grid): - raise NotImplementedError() - - -def add_1d_points(vtk_point, 
extent): - raise NotImplementedError() - - -def add_2d_points(vtk_point, extent): - raise NotImplementedError() - - -def add_1d_structured(vtk_structured_grid): - raise NotImplementedError() - - -def add_2d_structured(vtk_structured_grid): - raise NotImplementedError() - - -def fake_structured(vtk_structured_grid): - raise NotImplementedError() - - -def add_1d_unstructured(vtk_unstructured_grid, extent): - raise NotImplementedError() - - -def add_2d_unstructured(vtk_unstructured_grid, extent): - raise NotImplementedError() diff --git a/pan3d/xarray/cf/coords/coords_mesh_spherical.py b/pan3d/xarray/cf/coords/coords_mesh_spherical.py deleted file mode 100644 index cfbeec2..0000000 --- a/pan3d/xarray/cf/coords/coords_mesh_spherical.py +++ /dev/null @@ -1,64 +0,0 @@ -import math -from vtkmodules.vtkCommonCore import vtkPoints, vtkMath - - -def add_1d_points(vtk_point, coords): - extent = coords.extent - vtk_point.SetDataTypeToDouble() - vtk_point.Allocate( - (extent[1] - extent[0] + 1) - * (extent[3] - extent[2] + 1) - * (extent[5] - extent[4] + 1) - ) - - # check the height scale and bias - z_scale = coords.vertical_scale - z_bias = coords.vertical_bias - if coords.vertical: - z_min = float(coords.vertical_array.min()) - z_max = float(coords.vertical_array.max()) - if z_min * z_scale + z_bias < 0 or z_max * z_scale + z_bias < 0: - z_bias = -math.min(z_min, z_max) * z_scale - elif (z_scale + z_bias) <= 0: - z_scale = 1 - z_bias = 0 - - # Fill points - longitude = coords.longitude_bounds - latitude = coords.latitude_bounds - vertical = coords.vertical_bounds - for k in range(extent[4], extent[5] + 1): - h = vertical[k] if vertical else 1 - h = h * z_scale + z_bias - for j in range(extent[2], extent[3] + 1): - lat = vtkMath.RadiansFromDegrees(latitude[j]) - for i in range(extent[0], extent[1] + 1): - lon = vtkMath.RadiansFromDegrees(longitude[i]) - vtk_point.InsertNextPoint( - h * math.cos(lon) * math.cos(lat), - h * math.sin(lon) * math.cos(lat), - h * math.sin(lat), - 
) - - -def add_2d_points(vtk_point, extent): - raise NotImplementedError() - - -def add_1d_structured(vtk_structured, coords): - vtk_structured.SetExtent(coords.extent) - vtk_points = vtkPoints() - add_1d_points(vtk_points, coords) - vtk_structured.SetPoints(vtk_points) - - -def add_2d_structured(vtk_point, extent): - raise NotImplementedError() - - -def add_1d_unstructured(vtk_point, extent): - raise NotImplementedError() - - -def add_2d_unstructured(vtk_point, extent): - raise NotImplementedError() diff --git a/pan3d/xarray/cf/coords/coords_mesh_unstructured.py b/pan3d/xarray/cf/coords/coords_mesh_unstructured.py deleted file mode 100644 index bc89c61..0000000 --- a/pan3d/xarray/cf/coords/coords_mesh_unstructured.py +++ /dev/null @@ -1,10 +0,0 @@ -def add_structured_cells(unstructured_grid, extent): - raise NotImplementedError() - - -def add_unstructured_rectilinear_coordinates(unstructured_grid, extent): - raise NotImplementedError() - - -def add_unstructured_spherical_coordinates(unstructured_grid, extent): - raise NotImplementedError() diff --git a/pan3d/xarray/cf/coords/meta.py b/pan3d/xarray/cf/coords/meta.py index 16f8647..bb6c8d8 100644 --- a/pan3d/xarray/cf/coords/meta.py +++ b/pan3d/xarray/cf/coords/meta.py @@ -1,6 +1,7 @@ from enum import Enum import numpy as np from pan3d.xarray.cf import mesh +from pan3d.xarray.cf.coords.convert import is_uniform PRESSURE_UNITS = { "bar", @@ -145,20 +146,6 @@ } -def is_uniform(array): - origin = float(array[0]) - spacing = (float(array[-1]) - origin) / (array.size - 1) - tolerance = 0.01 * spacing - - for i in range(array.size): - expected = origin + i * spacing - truth = float(array[i]) - if not np.isclose(expected, truth, atol=tolerance): - return False - - return True - - class CoordinateType(Enum): LONGITUDE = "longitude" LATITUDE = "latitude" @@ -250,11 +237,9 @@ def can_be_time(cls, xr_array): class MetaArrayMapping: def __init__(self, xr_dataset): - self.xr_dataset = xr_dataset - self.conventions = ( - 
xr_dataset.Conventions if hasattr(xr_dataset, "Conventions") else None - ) + self.xr_dataset = None self.data_arrays = {} + self.conventions = None self.longitude = None self.latitude = None self.vertical = None @@ -263,6 +248,24 @@ def __init__(self, xr_dataset): self.vertical_bias = 6378137 self.vertical_scale = 100 + self.update(xr_dataset) + + def update(self, xr_dataset): + self.xr_dataset = xr_dataset + self.valid = False + self.data_arrays = {} + self.conventions = None + self.longitude = None + self.latitude = None + self.vertical = None + self.time = None + + if xr_dataset is None: + return + + self.conventions = ( + xr_dataset.Conventions if hasattr(xr_dataset, "Conventions") else None + ) if self.conventions is not None: for convention in {"COARDS", "CF-1"}: if convention in self.conventions: @@ -391,6 +394,9 @@ def use_coords(self, dims): return True def get_mesh(self, time_index=0, spherical=True, fields=None): + if self.xr_dataset is None or not fields: + return None + vtk_mesh, data_location = None, None # ensure similar dimension across array names diff --git a/pan3d/xarray/cf/mesh/rectilinear.py b/pan3d/xarray/cf/mesh/rectilinear.py index cdf26e5..09fe0b8 100644 --- a/pan3d/xarray/cf/mesh/rectilinear.py +++ b/pan3d/xarray/cf/mesh/rectilinear.py @@ -1,9 +1,10 @@ import numpy as np from vtkmodules.vtkCommonDataModel import vtkRectilinearGrid +from ..coords.convert import cell_center_to_point def generate_mesh(metadata, dimensions, time_index): - data_location = "point_data" + data_location = "cell_data" extent = [0, 0, 0, 0, 0, 0] empty_coords = np.zeros((1,), dtype=np.double) arrays = [empty_coords, empty_coords, empty_coords] @@ -12,11 +13,11 @@ def generate_mesh(metadata, dimensions, time_index): for idx in range(len(dimensions)): array = metadata.xr_dataset[dimensions[-(1 + idx)]] - arrays[idx] = array.values + arrays[idx] = cell_center_to_point(array.values) # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] - # And extent 
include both index so (len-1) - extent[idx * 2 + 1] = array.size - 1 + # And extent include both index but add 1 point so (len-1+1) => len + extent[idx * 2 + 1] = array.size mesh = vtkRectilinearGrid() mesh.x_coordinates = arrays[0] diff --git a/pan3d/xarray/cf/mesh/structured.py b/pan3d/xarray/cf/mesh/structured.py index f7c975a..d0f9d53 100644 --- a/pan3d/xarray/cf/mesh/structured.py +++ b/pan3d/xarray/cf/mesh/structured.py @@ -3,7 +3,7 @@ from ..coords.parametric_vertical import get_formula as get_z_formula from ..coords.index_mapping import get_formula as get_coords_formula -from ..coords.convert import point_insert +from ..coords.convert import point_insert, cell_center_to_point def generate_mesh(metadata, dimensions, time_index, spherical): @@ -96,8 +96,14 @@ def generate_bound_cells(metadata, dimensions, time_index, spherical): def generate_mesh_points(metadata, dimensions, time_index, spherical): data_location = "point_data" - if metadata.coords_1d and metadata.uniform_spacing: - print("Should put data on cell!") + if ( + metadata.coords_1d + and metadata.uniform_lat_lon + and metadata.use_coords(dimensions) + ): + return generate_mesh_points_data_on_cell( + metadata, dimensions, time_index, spherical + ) # 2D or 3D dims_size = len(dimensions) @@ -167,3 +173,58 @@ def generate_mesh_points(metadata, dimensions, time_index, spherical): mesh.extent = extent return mesh, data_location + + +def generate_mesh_points_data_on_cell(metadata, dimensions, time_index, spherical): + data_location = "cell_data" + # 2D or 3D + dims_size = len(dimensions) + assert dims_size == 2 or dims_size == 3 + + extent = [0, 0, 0, 0, 0, 0] + n_points = 1 + for idx in range(len(dimensions)): + array = metadata.xr_dataset[dimensions[-(1 + idx)]] + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # And extent include both index and adding 1 point so (len-1+1) => len + extent[idx * 2 + 1] = array.size + n_points *= array.size + 1 + + # Points + vtk_points = 
vtkPoints() + vtk_points.SetDataTypeToDouble() + vtk_points.Allocate(n_points) + + # debug + # print(f"{dimensions=}") + # print(f"{extent=}") + # print(f"{n_points=}") + + # Check if direct coord mapping + if dims_size == 2: # 2D + x_array = cell_center_to_point(metadata.xr_dataset[metadata.longitude].values) + y_array = cell_center_to_point(metadata.xr_dataset[metadata.latitude].values) + z = 0 + for j in range(extent[3] + 1): + lat = y_array[j] + for i in range(extent[1] + 1): + lon = x_array[i] + point_insert(vtk_points, spherical, lon, lat, z) + else: # 3D + x_array = cell_center_to_point(metadata.xr_dataset[metadata.longitude].values) + y_array = cell_center_to_point(metadata.xr_dataset[metadata.latitude].values) + z_array = cell_center_to_point(metadata.xr_dataset[metadata.vertical].values) + for k in range(extent[5] + 1): + z = z_array[k] + for j in range(extent[3] + 1): + lat = y_array[j] + for i in range(extent[1] + 1): + lon = x_array[i] + point_insert(vtk_points, spherical, lon, lat, z) + + # Mesh + mesh = vtkStructuredGrid() + mesh.points = vtk_points + mesh.extent = extent + + return mesh, data_location diff --git a/pan3d/xarray/cf/reader.py b/pan3d/xarray/cf/reader.py new file mode 100644 index 0000000..d43f7b9 --- /dev/null +++ b/pan3d/xarray/cf/reader.py @@ -0,0 +1,584 @@ +from typing import List, Optional + +import xarray as xr +import numpy as np +import pandas as pd + +from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase +from vtkmodules.vtkCommonCore import vtkVariant +from vtkmodules.vtkCommonExecutionModel import vtkStreamingDemandDrivenPipeline +from vtkmodules.vtkCommonDataModel import vtkImageData, vtkDataObject +from vtkmodules.vtkFiltersCore import vtkArrayCalculator +from vtkmodules.util import numpy_support + +from pan3d.xarray.cf.coords.meta import MetaArrayMapping + +# ----------------------------------------------------------------------------- +# Helper functions +# 
----------------------------------------------------------------------------- + + +def is_time_array(xarray, name, values): + if values.dtype.type == np.datetime64 or values.dtype.type == np.timedelta64: + un = np.datetime_data(values.dtype) + # unit = ns and 1 base unit in a spep + if un[0] == "ns" and un[1] == 1 and name in xarray.coords.keys(): + return True + return False + + +# ----------------------------------------------------------------------------- + + +def attr_value_to_vtk(value): + if np.issubdtype(type(value), np.integer): + return vtkVariant(int(value)) + + if np.issubdtype(type(value), np.floating): + return vtkVariant(float(value)) + + if isinstance(value, np.ndarray): + return vtkVariant(numpy_support.numpy_to_vtk(value)) + + return vtkVariant(value) + + +# ----------------------------------------------------------------------------- + + +def get_time_labels(times): + return [pd.to_datetime(time).strftime("%Y-%m-%d %H:%M:%S") for time in times] + + +# ----------------------------------------------------------------------------- +# VTK Algorithms +# ----------------------------------------------------------------------------- + + +class CFHelper: + def __init__(self, xr_dataset): + self._xr_dataset = xr_dataset + self._meta = MetaArrayMapping(xr_dataset) + self._t_index = 0 + self._mesh = None + self._array_names = [] + self._spherical = True + + def update(self, xr_dataset): + self._array_names = [] + self._mesh = None + self._xr_dataset = xr_dataset + self._meta.update(xr_dataset) + + def _reset(self): + self._mesh = None + + @property + def x(self): + return self._meta.longitude + + @property + def y(self): + return self._meta.latitude + + @property + def z(self): + return self._meta.vertical + + @property + def t(self): + return self._meta.time + + @property + def spherical(self): + return self._spherical + + @spherical.setter + def spherical(self, v): + if self._spherical != v: + self._spherical = v + self._reset() + + @property + def 
vertical_bias(self): + return self._meta.vertical_bias + + @vertical_bias.setter + def vertical_bias(self, v): + if self._meta.vertical_bias != v: + self._meta.vertical_bias = v + self._reset() + + @property + def vertical_scale(self): + return self._meta.vertical_scale + + @vertical_scale.setter + def vertical_scale(self, v): + if self._meta.vertical_scale != v: + self._meta.vertical_scale = v + self._reset() + + @property + def t_index(self): + """return the current selected time index""" + return self._t_index + + @t_index.setter + def t_index(self, t_index: int): + """update the current selected time index""" + if t_index != self._t_index: + self._t_index = t_index + self._reset() + + @property + def t_size(self): + """return the size of the coordinate used for the time""" + if self.t is None or self._xr_dataset is None: + return 0 + return int(self._xr_dataset[self.t].size) + + @property + def arrays(self): + """return the list of arrays that are currently selected to be added to the generated VTK mesh""" + return list(self._array_names) + + @arrays.setter + def arrays(self, array_names: List[str]): + """update the list of arrays to load on the generated VTK mesh""" + new_names = set(array_names or []) + if new_names != self._array_names: + self._array_names = new_names + self._reset() + + @property + def available_arrays(self): + all_list = [] + for arrays in self._meta.data_arrays.values(): + all_list.extend(arrays) + + return all_list + + @property + def mesh(self): + if self._mesh is None: + self._mesh = self._meta.get_mesh(self._t_index, self.spherical, self.arrays) + + return self._mesh + + +class vtkXArrayCFSource(VTKPythonAlgorithmBase): + """vtk source for converting XArray into a VTK mesh""" + + def __init__( + self, + input: Optional[xr.Dataset] = None, + arrays: Optional[List[str]] = None, + ): + """ + Create vtkXArraySource + + Parameters: + input (xr.Dataset): Provide an XArray to use as input. The load() method will replace it. 
+ arrays (list[str]): List of field to load onto the generated VTK mesh. + """ + VTKPythonAlgorithmBase.__init__( + self, + nInputPorts=0, + nOutputPorts=1, + outputType="vtkDataObject", + ) + # Data source + self._input = input + self._pipeline = None + self._computed = {} + self._data_origin = None + + # Data sub-selection + self._array_names = set(arrays or []) + self._t_index = 0 + self._slices = None + + # vtk internal vars + self._arrays = {} + + # Create reader if xarray available + self._meta = CFHelper(self._input) + + # ------------------------------------------------------------------------- + # Information + # ------------------------------------------------------------------------- + + def __str__(self): + return """VTK XArray CF (Python) reader""" + + # ------------------------------------------------------------------------- + # Data input + # ------------------------------------------------------------------------- + + @property + def input(self): + """return current input XArray""" + return self._input + + @input.setter + def input(self, xarray_dataset: xr.Dataset): + """update input with a new XArray""" + self._slices = None + self._array_names.clear() + self._input = xarray_dataset + self._meta.update(self._input) + self._meta.t_index = self.t_index + self.Modified() + + # ------------------------------------------------------------------------- + # Array selectors + # ------------------------------------------------------------------------- + + @property + def x(self): + """return the name that is currently mapped to the X axis""" + return self._meta.x + + @property + def x_size(self): + """return the size of the coordinate used for the X axis""" + if self.x is None: + return 0 + return int(self._input[self.x].size) + + @property + def y(self): + """return the name that is currently mapped to the Y axis""" + return self._meta.y + + @property + def y_size(self): + """return the size of the coordinate used for the Y axis""" + if self.y is None: + 
return 0 + return int(self._input[self.y].size) + + @property + def z(self): + """return the name that is currently mapped to the Z axis""" + return self._meta.z + + @property + def z_size(self): + """return the size of the coordinate used for the Z axis""" + if self.z is None: + return 0 + return int(self._input[self.z].size) + + @property + def t(self): + """return the name that is currently mapped to the time axis""" + return self._meta.t + + @property + def slice_extents(self): + """return a dictionary for the X, Y, Z dimensions with the corresponding extent [0, size-1]""" + return { + coord_name: [0, self.input[coord_name].size - 1] + for coord_name in [self.x, self.y, self.z] + if coord_name is not None + } + + @property + def available_coords(self): + """List available coordinates arrays that have are 1D""" + if self._input is None: + return [] + + return [k for k, v in self._input.coords.items() if len(v.shape) == 1] + + # ------------------------------------------------------------------------- + # Projection management + # ------------------------------------------------------------------------- + + @property + def spherical(self): + return self._meta.spherical + + @spherical.setter + def spherical(self, v): + if self._meta.spherical != v: + self._meta.spherical = v + self.Modified() + + @property + def vertical_bias(self): + return self._meta.vertical_bias + + @vertical_bias.setter + def vertical_bias(self, v): + if self._meta.vertical_bias != v: + self._meta.vertical_bias = v + self.Modified() + + @property + def vertical_scale(self): + return self._meta.vertical_scale + + @vertical_scale.setter + def vertical_scale(self, v): + if self._meta.vertical_scale != v: + self._meta.vertical_scale = v + self.Modified() + + # ------------------------------------------------------------------------- + # Data sub-selection + # ------------------------------------------------------------------------- + + @property + def t_index(self): + """return the current 
selected time index""" + return self._t_index + + @t_index.setter + def t_index(self, t_index: int): + """update the current selected time index""" + if t_index != self._t_index: + self._t_index = t_index + self._meta.t_index = t_index + self.Modified() + + @property + def t_size(self): + """return the size of the coordinate used for the time""" + if self.t is None: + return 0 + return int(self._input[self.t].size) + + @property + def t_labels(self): + """return a list of string that match the various time values available""" + if self.t is None: + return [] + + t_array = self._input[self.t] + t_type = t_array.dtype + if np.issubdtype(t_type, np.datetime64): + return get_time_labels(t_array.values) + return [str(t) for t in t_array.values] + + @property + def arrays(self): + """return the list of arrays that are currently selected to be added to the generated VTK mesh""" + return list(self._array_names) + + @arrays.setter + def arrays(self, array_names: List[str]): + """update the list of arrays to load on the generated VTK mesh""" + new_names = set(array_names or []) + if new_names != self._array_names: + self._array_names = new_names + self._meta.arrays = new_names + self.Modified() + + @property + def available_arrays(self): + """List all available data fields for the `arrays` option""" + if self._input is None or self._meta is None: + return [] + + return self._meta.available_arrays + + @property + def slices(self): + """return the current slicing information which include axes crop/cut and time selection""" + result = dict(self._slices or {}) + if self.t is not None: + result[self.t] = self.t_index + return result + + @slices.setter + def slices(self, v): + """update the slicing of the data along axes""" + if v != self._slices: + self._slices = v + # FIXME !!! 
update accessor + # Ask Dan + # self.Modified() + # raise NotImplementedError() + print("set slices not implemented", v) + if "time" in v: + self.t_index = v.get("time", 0) + + # ------------------------------------------------------------------------- + # add-on logic + # ------------------------------------------------------------------------- + + @property + def computed(self): + """return the current description of the computed/derived fields on the VTK mesh""" + return self._computed + + @computed.setter + def computed(self, v): + """ + update the computed/derived fields to add on the VTK mesh + + The layout of the dictionary provided should be as follow: + - key: name of the field to be added + - value: formula to apply for the given field name. The syntax is captured in the document (https://docs.paraview.org/en/latest/UsersGuide/filteringData.html#calculator) + + Then additional keys need to be provided to describe your formula dependencies: + `_use_scalars` and `_use_vectors` which should be a list of string matching the name of the fields you are using in your expression. 
+ + + Please find below an example: + + ``` + { + "_use_scalars": ["u", "v"], # (u,v) needed for "vec" and "m2" + "vec": "(u * iHat) + (v * jHat)", # 2D vector + "m2": "u*u + v*v", + } + ``` + """ + if self._computed != v: + self._computed = v or {} + self._pipeline = None + scalar_arrays = self._computed.get("_use_scalars", []) + vector_arrays = self._computed.get("_use_vectors", []) + + for output_name, func in self._computed.items(): + if output_name[0] == "_": + continue + filter = vtkArrayCalculator( + result_array_name=output_name, + function=func, + ) + + # register array dependencies + for scalar_array in scalar_arrays: + filter.AddScalarArrayName(scalar_array) + for vector_array in vector_arrays: + filter.AddVectorArrayName(vector_array) + + if self._pipeline is None: + self._pipeline = filter + else: + self._pipeline = self._pipeline >> filter + + self.Modified() + + def load(self, data_info): + """ + create a new XArray input with the `data_origin` and `dataset_config` information. + + Here is an example of the layout of the parameter + + ``` + { + "data_origin": { + "source": "url", # one of [file, url, xarray, pangeo, esgf] + "id": "https://ncsa.osn.xsede.org/Pangeo/pangeo-forge/noaa-coastwatch-geopolar-sst-feedstock/noaa-coastwatch-geopolar-sst.zarr", + "order": "C" # (optional) order to use in numpy + }, + "dataset_config": { + "x": "lon", # (optional) coord name for X + "y": "lat", # (optional) coord name for Y + "z": null, # (optional) coord name for Z + "t": "time", # (optional) coord name for time + "slices": { # (optional) array slicing + "lon": [ + 1000, + 6000, + 20 + ], + "lat": [ + 500, + 3000, + 20 + ], + "time": 5 + }, + "t_index": 5, # (optional) selected time index + "arrays": [ # (optional) names of arrays to load onto VTK mesh. + "analysed_sst" # If missing no array will be loaded + ] # onto the mesh. 
+ } + } + ``` + """ + if "data_origin" not in data_info: + raise ValueError("Only state with data_origin can be loaded") + + from pan3d import catalogs + + self._data_origin = data_info["data_origin"] + self.input = catalogs.load_dataset( + self._data_origin["source"], self._data_origin["id"] + ) + + dataset_config = data_info.get("dataset_config") + if dataset_config is None: + self.arrays = self.available_arrays + else: + # self.slices = dataset_config.get("slices") # FIXME: not implemented yet + self.t_index = dataset_config.get("t_index", 0) + self.arrays = dataset_config.get("arrays", self.available_arrays) + + @property + def state(self): + """return current state that can be reused in a load() later on""" + if self._data_origin is None: + raise RuntimeError( + "No state available without data origin. Need to use the load method to set the data origin." + ) + + return { + "data_origin": self._data_origin, + "dataset_config": { + k: getattr(self, k) + for k in ["x", "y", "z", "t", "slices", "t_index", "arrays"] + }, + } + + # ------------------------------------------------------------------------- + # Algorithm + # ------------------------------------------------------------------------- + + def RequestDataObject(self, request, inInfo, outInfo): + output = vtkImageData() + if self._meta.mesh is not None: + output = self._meta.mesh.NewInstance() + print(f"RequestDataObject::{output.GetClassName()}") + + outInfo.GetInformationObject(0).Set(vtkDataObject.DATA_OBJECT(), output) + return 1 + + def RequestInformation(self, request, inInfo, outInfo): + if ( + self._meta._xr_dataset + and self._meta.mesh + and hasattr(self._meta.mesh, "extent") + ): + whole_extent = self._meta.mesh.extent + print(f"RequestInformation::{whole_extent=}") + outInfo.GetInformationObject(0).Set( + vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), *whole_extent + ) + + return 1 + + def RequestData(self, request, inInfo, outInfo): + """implementation of the vtk algorithm for generating the 
VTK mesh""" + # Use open data_array handle to fetch data at + # desired Level of Detail + mesh = self._meta.mesh + if mesh is not None: + pdo = self.GetOutputData(outInfo, 0) + + # Compute derived quantity + if self._pipeline is not None: + mesh = self._pipeline(mesh) + pdo.ShallowCopy(mesh) + else: + pdo.ShallowCopy(mesh) + + return 1 + return 0 From 891cef44a652b7a4ca85390707a5343bd836c9a5 Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Thu, 23 Jan 2025 18:40:30 -0700 Subject: [PATCH 08/10] wip: cleaner reader --- pan3d/ui/preview.py | 11 ++- pan3d/xarray/cf/constants.py | 5 ++ pan3d/xarray/cf/coords/meta.py | 122 +++++++++++++++++++++++++++---- pan3d/xarray/cf/reader.py | 126 +++++++++++++++++++-------------- 4 files changed, 194 insertions(+), 70 deletions(-) diff --git a/pan3d/ui/preview.py b/pan3d/ui/preview.py index 23f0feb..d6bcb61 100644 --- a/pan3d/ui/preview.py +++ b/pan3d/ui/preview.py @@ -11,6 +11,7 @@ from pan3d.ui.css import base, preview from pan3d.ui.collapsible import CollapsableSection +from pan3d.xarray.cf.constants import Projection class SummaryToolbar(v3.VCard): @@ -732,7 +733,9 @@ def update_from_source(self, source=None): self.state.axis_names = [source.x, source.y, source.z] self.state.slice_extents = source.slice_extents self.state.projection_mode = ( - "spherical" if source.spherical else "euclidean" + "spherical" + if source.projection == Projection.SPHERICAL + else "euclidean" ) self.state.spherical_bias = source.vertical_bias self.state.spherical_scale = source.vertical_scale @@ -865,7 +868,11 @@ def _on_array_selection(self, data_arrays, **_): def _on_projection_change( self, spherical_bias, spherical_scale, projection_mode, **_ ): - self.source.spherical = projection_mode == "spherical" + self.source.projection = ( + Projection.SPHERICAL + if projection_mode == "spherical" + else Projection.EUCLIDEAN + ) self.source.vertical_bias = spherical_bias self.source.vertical_scale = spherical_scale diff --git 
a/pan3d/xarray/cf/constants.py b/pan3d/xarray/cf/constants.py index cf6468b..0570d48 100644 --- a/pan3d/xarray/cf/constants.py +++ b/pan3d/xarray/cf/constants.py @@ -89,6 +89,11 @@ def __repr__(self): """ +class Projection(Enum): + SPHERICAL = "Spherical" + EUCLIDEAN = "Euclidean" + + class Scale(Enum): da = (1e1, {"deca", "deka"}) h = (1e2, {"hecto"}) diff --git a/pan3d/xarray/cf/coords/meta.py b/pan3d/xarray/cf/coords/meta.py index bb6c8d8..e5afd95 100644 --- a/pan3d/xarray/cf/coords/meta.py +++ b/pan3d/xarray/cf/coords/meta.py @@ -2,6 +2,13 @@ import numpy as np from pan3d.xarray.cf import mesh from pan3d.xarray.cf.coords.convert import is_uniform +from pan3d.xarray.cf.constants import Projection +from vtkmodules.vtkCommonDataModel import ( + vtkImageData, + vtkRectilinearGrid, + vtkStructuredGrid, + vtkUnstructuredGrid, +) PRESSURE_UNITS = { "bar", @@ -21,7 +28,6 @@ "kilometer", "km", } - COORDINATES_DETECTION = { "longitude": { "units": { @@ -393,18 +399,104 @@ def use_coords(self, dims): return False return True - def get_mesh(self, time_index=0, spherical=True, fields=None): - if self.xr_dataset is None or not fields: - return None + def compatible_fields(self, fields=None): + if not fields: + return [] + data_dims = self.xr_dataset[fields[0]].dims + return [n for n in fields if self.xr_dataset[n].dims == data_dims] + + def dimensions(self, field): + return self.xr_dataset[field].dims + + def timeless_dimensions(self, field): + dims = self.xr_dataset[field].dims + return dims[1:] if dims[0] == self.time else dims + + def field_extent(self, field): + extent = [0, 0, 0, 0, 0, 0] + dimensions = self.timeless_dimensions(field) + for idx in range(len(dimensions)): + array = self.xr_dataset[dimensions[-(1 + idx)]] + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # And extent include both index so (len-1) + extent[idx * 2 + 1] = array.size - 1 + + return extent + + def get_vtk_mesh_type(self, projection, fields=None): + fields = 
self.compatible_fields(fields) + + if self.longitude is None or self.latitude is None or not fields: + # default empty mesh + return vtkImageData() + + # unstructured + timeless_dims = self.timeless_dimensions(fields[0]) + if len(timeless_dims) == 1: + return vtkUnstructuredGrid() + + # structured + if ( + self.coords_has_bounds + or projection == Projection.SPHERICAL + or not self.coords_1d + ): + return vtkStructuredGrid() + + # rectilinear + if not self.uniform_spacing: + return vtkRectilinearGrid() + + # imagedata + return vtkImageData() + + def get_vtk_whole_extent(self, projection, fields=None): + if self.longitude is None or self.latitude is None or not fields: + return [ + 0, + 0, + 0, + 0, + 0, + 0, + ] + + mesh_type = self.get_vtk_mesh_type(projection, fields) + fields = self.compatible_fields(fields) + extent = self.field_extent(fields[0]) + dimensions = self.timeless_dimensions(fields[0]) + + print(f"before {extent=}") + print(f"class {mesh_type.GetClassName()}") + + if mesh_type.IsA("vtkStructuredGrid") and not ( + self.uniform_lat_lon and self.use_coords(dimensions) + ): + # point data + return extent + + # cell data, need to +1 on the extent + for i in range(3): + if extent[i * 2 + 1] > 0: + extent[i * 2 + 1] += 1 + print(f"after {extent=}") + + return extent + + def get_vtk_mesh(self, time_index=0, projection=None, fields=None): vtk_mesh, data_location = None, None + if self.xr_dataset is None or not fields: + return vtk_mesh + + # resolve projection + if projection is None: + projection = Projection.SPHERICAL + spherical_proj = projection == Projection.SPHERICAL # ensure similar dimension across array names - data_dims = self.xr_dataset[fields[0]].dims - data_dims_no_time = data_dims[1:] if data_dims[0] == self.time else data_dims - valid_data_array_names = [ - n for n in fields if self.xr_dataset[n].dims == data_dims - ] + fields = self.compatible_fields(fields) + data_dims_no_time = self.timeless_dimensions(fields[0]) # No mesh if no lon/lat if 
self.longitude is None or self.latitude is None: @@ -413,20 +505,20 @@ def get_mesh(self, time_index=0, spherical=True, fields=None): # Unstructured if len(data_dims_no_time) == 1: vtk_mesh, data_location = mesh.unstructured.generate_mesh( - self, data_dims_no_time, time_index, spherical + self, data_dims_no_time, time_index, spherical_proj ) # Structured if vtk_mesh is None and ( - self.coords_has_bounds or spherical or not self.coords_1d + self.coords_has_bounds or spherical_proj or not self.coords_1d ): vtk_mesh, data_location = mesh.structured.generate_mesh( - self, data_dims_no_time, time_index, spherical + self, data_dims_no_time, time_index, spherical_proj ) - # This should only happen if we don't want spherical + # This should only happen if we don't want spherical_proj if vtk_mesh is None: - assert not spherical + assert not spherical_proj # Rectilinear if vtk_mesh is None and not self.uniform_spacing: @@ -443,7 +535,7 @@ def get_mesh(self, time_index=0, spherical=True, fields=None): # Add fields if vtk_mesh: container = getattr(vtk_mesh, data_location) - for field_name in valid_data_array_names: + for field_name in fields: field = ( self.xr_dataset[field_name][time_index].values if self.time diff --git a/pan3d/xarray/cf/reader.py b/pan3d/xarray/cf/reader.py index d43f7b9..02852cd 100644 --- a/pan3d/xarray/cf/reader.py +++ b/pan3d/xarray/cf/reader.py @@ -7,11 +7,12 @@ from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase from vtkmodules.vtkCommonCore import vtkVariant from vtkmodules.vtkCommonExecutionModel import vtkStreamingDemandDrivenPipeline -from vtkmodules.vtkCommonDataModel import vtkImageData, vtkDataObject +from vtkmodules.vtkCommonDataModel import vtkDataObject from vtkmodules.vtkFiltersCore import vtkArrayCalculator from vtkmodules.util import numpy_support from pan3d.xarray.cf.coords.meta import MetaArrayMapping +from pan3d.xarray.cf.constants import Projection # 
----------------------------------------------------------------------------- # Helper functions @@ -55,7 +56,7 @@ def get_time_labels(times): # ----------------------------------------------------------------------------- -class CFHelper: +class CFHelperXXX: def __init__(self, xr_dataset): self._xr_dataset = xr_dataset self._meta = MetaArrayMapping(xr_dataset) @@ -155,7 +156,13 @@ def arrays(self, array_names: List[str]): def available_arrays(self): all_list = [] for arrays in self._meta.data_arrays.values(): - all_list.extend(arrays) + if self._array_names: + if self._array_names.intersection(arrays): + # only add compatible arrays + all_list.extend(arrays) + else: + # when no array selected, add all of them + all_list.extend(arrays) return all_list @@ -202,15 +209,22 @@ def __init__( # vtk internal vars self._arrays = {} + # projection + self._proj_mode = Projection.SPHERICAL + self._proj_vertical_bias = 6378137 # earth radius in meter + self._proj_vertical_scale = ( + 100 # increase z scaling to see something compare to earth radius + ) + # Create reader if xarray available - self._meta = CFHelper(self._input) + self._mapping = MetaArrayMapping(self._input) # ------------------------------------------------------------------------- # Information # ------------------------------------------------------------------------- def __str__(self): - return """VTK XArray CF (Python) reader""" + return f"VTK XArray CF (Python) reader\n{self._mapping}" # ------------------------------------------------------------------------- # Data input @@ -227,8 +241,8 @@ def input(self, xarray_dataset: xr.Dataset): self._slices = None self._array_names.clear() self._input = xarray_dataset - self._meta.update(self._input) - self._meta.t_index = self.t_index + self._mapping.update(self._input) + self.t_index = 0 self.Modified() # ------------------------------------------------------------------------- @@ -238,7 +252,7 @@ def input(self, xarray_dataset: xr.Dataset): @property def 
x(self): """return the name that is currently mapped to the X axis""" - return self._meta.x + return self._mapping.longitude @property def x_size(self): @@ -250,7 +264,7 @@ def x_size(self): @property def y(self): """return the name that is currently mapped to the Y axis""" - return self._meta.y + return self._mapping.latitude @property def y_size(self): @@ -262,7 +276,7 @@ def y_size(self): @property def z(self): """return the name that is currently mapped to the Z axis""" - return self._meta.z + return self._mapping.vertical @property def z_size(self): @@ -274,11 +288,13 @@ def z_size(self): @property def t(self): """return the name that is currently mapped to the time axis""" - return self._meta.t + return self._mapping.time @property def slice_extents(self): """return a dictionary for the X, Y, Z dimensions with the corresponding extent [0, size-1]""" + # !!! can be different based on which field is selected !!! + # -> result not obvious based on the mesh type return { coord_name: [0, self.input[coord_name].size - 1] for coord_name in [self.x, self.y, self.z] @@ -287,7 +303,8 @@ def slice_extents(self): @property def available_coords(self): - """List available coordinates arrays that have are 1D""" + """List available coordinates arrays that are 1D""" + # !!! Do we use that ??? 
if self._input is None: return [] @@ -298,33 +315,35 @@ def available_coords(self): # ------------------------------------------------------------------------- @property - def spherical(self): - return self._meta.spherical + def projection(self): + return self._proj_mode - @spherical.setter - def spherical(self, v): - if self._meta.spherical != v: - self._meta.spherical = v + @projection.setter + def projection(self, mode=None): + if mode in Projection and self._proj_mode != mode: + if isinstance(mode, str): + mode = Projection(mode) + self._proj_mode = mode self.Modified() @property def vertical_bias(self): - return self._meta.vertical_bias + return self._proj_vertical_bias @vertical_bias.setter def vertical_bias(self, v): - if self._meta.vertical_bias != v: - self._meta.vertical_bias = v + if self._proj_vertical_bias != v: + self._proj_vertical_bias = v self.Modified() @property def vertical_scale(self): - return self._meta.vertical_scale + return self._proj_vertical_scale @vertical_scale.setter def vertical_scale(self, v): - if self._meta.vertical_scale != v: - self._meta.vertical_scale = v + if self._proj_vertical_scale != v: + self._proj_vertical_scale = v self.Modified() # ------------------------------------------------------------------------- @@ -341,7 +360,6 @@ def t_index(self, t_index: int): """update the current selected time index""" if t_index != self._t_index: self._t_index = t_index - self._meta.t_index = t_index self.Modified() @property @@ -374,16 +392,25 @@ def arrays(self, array_names: List[str]): new_names = set(array_names or []) if new_names != self._array_names: self._array_names = new_names - self._meta.arrays = new_names self.Modified() @property def available_arrays(self): """List all available data fields for the `arrays` option""" - if self._input is None or self._meta is None: + if self._input is None: return [] - return self._meta.available_arrays + all_list = [] + for arrays in self._mapping.data_arrays.values(): + if 
self._array_names: + if self._array_names.intersection(arrays): + # only add compatible arrays + all_list.extend(arrays) + else: + # when no array selected, add all of them + all_list.extend(arrays) + + return all_list @property def slices(self): @@ -398,14 +425,11 @@ def slices(self, v): """update the slicing of the data along axes""" if v != self._slices: self._slices = v - # FIXME !!! update accessor - # Ask Dan - # self.Modified() - # raise NotImplementedError() - print("set slices not implemented", v) if "time" in v: self.t_index = v.get("time", 0) + self.Modified() + # ------------------------------------------------------------------------- # add-on logic # ------------------------------------------------------------------------- @@ -518,9 +542,9 @@ def load(self, data_info): if dataset_config is None: self.arrays = self.available_arrays else: - # self.slices = dataset_config.get("slices") # FIXME: not implemented yet + self.slices = dataset_config.get("slices") self.t_index = dataset_config.get("t_index", 0) - self.arrays = dataset_config.get("arrays", self.available_arrays) + self.arrays = dataset_config.get("arrays", [self.available_arrays[0]]) @property def state(self): @@ -543,25 +567,18 @@ def state(self): # ------------------------------------------------------------------------- def RequestDataObject(self, request, inInfo, outInfo): - output = vtkImageData() - if self._meta.mesh is not None: - output = self._meta.mesh.NewInstance() - print(f"RequestDataObject::{output.GetClassName()}") - + output = self._mapping.get_vtk_mesh_type(self._proj_mode, self.arrays) outInfo.GetInformationObject(0).Set(vtkDataObject.DATA_OBJECT(), output) + print(f"RequestDataObject::{output.GetClassName()=}") return 1 def RequestInformation(self, request, inInfo, outInfo): - if ( - self._meta._xr_dataset - and self._meta.mesh - and hasattr(self._meta.mesh, "extent") - ): - whole_extent = self._meta.mesh.extent - print(f"RequestInformation::{whole_extent=}") - 
outInfo.GetInformationObject(0).Set( - vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), *whole_extent - ) + whole_extent = self._mapping.get_vtk_whole_extent(self._proj_mode, self.arrays) + print(f"{whole_extent=}") + outInfo.GetInformationObject(0).Set( + vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), + *whole_extent, + ) return 1 @@ -569,8 +586,12 @@ def RequestData(self, request, inInfo, outInfo): """implementation of the vtk algorithm for generating the VTK mesh""" # Use open data_array handle to fetch data at # desired Level of Detail - mesh = self._meta.mesh + mesh = self._mapping.get_vtk_mesh( + time_index=self.t_index, projection=self._proj_mode, fields=self.arrays + ) if mesh is not None: + print(f"{mesh.extent=}") + print(f"{mesh.GetClassName()=}") pdo = self.GetOutputData(outInfo, 0) # Compute derived quantity @@ -580,5 +601,4 @@ def RequestData(self, request, inInfo, outInfo): else: pdo.ShallowCopy(mesh) - return 1 - return 0 + return 1 From 4edb1ab653b461971924dca28868688881cc2bfd Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Fri, 24 Jan 2025 17:02:24 -0700 Subject: [PATCH 09/10] wip: start supporting slicing --- pan3d/ui/preview.py | 56 +++++--- pan3d/xarray/cf/coords/convert.py | 31 +++++ pan3d/xarray/cf/coords/index_mapping.py | 2 +- pan3d/xarray/cf/coords/meta.py | 62 ++++++--- pan3d/xarray/cf/coords/parametric_vertical.py | 1 - pan3d/xarray/cf/mesh/rectilinear.py | 2 +- pan3d/xarray/cf/mesh/structured.py | 125 +++++++++++------- pan3d/xarray/cf/mesh/uniform.py | 2 +- pan3d/xarray/cf/mesh/unstructured.py | 2 +- pan3d/xarray/cf/reader.py | 51 +++++-- 10 files changed, 228 insertions(+), 106 deletions(-) diff --git a/pan3d/ui/preview.py b/pan3d/ui/preview.py index d6bcb61..702c11d 100644 --- a/pan3d/ui/preview.py +++ b/pan3d/ui/preview.py @@ -227,6 +227,19 @@ def update_information(self, xr, available_arrays=None): "value": f"[{xr[name].values[0]}, {xr[name].values[-1]}]", } ) + elif len(shape) > 1: + attrs.append( + { + "key": "dims", + 
"value": f'({", ".join(xr[name].dims)})', + } + ) + attrs.append( + { + "key": "shape", + "value": f'({", ".join([str(v) for v in xr[name].shape])})', + } + ) if name in data: icon = "mdi-database" order = 2 @@ -530,10 +543,11 @@ def __init__(self, source, update_rendering): size="sm", classes="mx-2", ) - v3.VDivider() # Slice steps - with v3.VTooltip(text="Level Of Details / Slice stepping"): + with v3.VTooltip( + text="Level Of Details / Slice stepping", v_if="axis_names.length" + ): with html.Template(v_slot_activator="{ props }"): with v3.VRow( v_bind="props", @@ -544,7 +558,7 @@ def __init__(self, source, update_rendering): "mdi-stairs", classes="ml-2 text-medium-emphasis", ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[0]"): + with v3.VCol(classes="pa-0", v_if="axis_names.length > 0"): v3.VTextField( v_model_number=("slice_x_step", 1), hide_details=True, @@ -555,7 +569,7 @@ def __init__(self, source, update_rendering): raw_attrs=['min="1"'], type="number", ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[1]"): + with v3.VCol(classes="pa-0", v_if="axis_names.length > 1"): v3.VTextField( v_model_number=("slice_y_step", 1), hide_details=True, @@ -566,7 +580,7 @@ def __init__(self, source, update_rendering): raw_attrs=['min="1"'], type="number", ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[2]"): + with v3.VCol(classes="pa-0", v_if="axis_names.length > 2"): v3.VTextField( v_model_number=("slice_z_step", 1), hide_details=True, @@ -730,7 +744,7 @@ def update_from_source(self, source=None): self.state.data_arrays_available = source.available_arrays self.state.data_arrays = source.arrays self.state.color_by = None - self.state.axis_names = [source.x, source.y, source.z] + self.state.axis_names = [] self.state.slice_extents = source.slice_extents self.state.projection_mode = ( "spherical" @@ -740,16 +754,19 @@ def update_from_source(self, source=None): self.state.spherical_bias = source.vertical_bias self.state.spherical_scale = source.vertical_scale 
slices = source.slices - for axis in XYZ: + for idx, name in enumerate(self.state.slice_extents): + axis = XYZ[idx] + self.state.axis_names.append(name) + self.state.dirty("axis_names") # default - axis_extent = self.state.slice_extents.get(getattr(source, axis)) + axis_extent = self.state.slice_extents.get(name) self.state[f"slice_{axis}_range"] = axis_extent self.state[f"slice_{axis}_cut"] = 0 self.state[f"slice_{axis}_step"] = 1 self.state[f"slice_{axis}_type"] = "range" # use slice info if available - axis_slice = slices.get(getattr(source, axis)) + axis_slice = slices.get(name) if axis_slice is not None: if isinstance(axis_slice, int): # cut @@ -862,19 +879,24 @@ def _on_array_selection(self, data_arrays, **_): elif len(data_arrays) == 0: self.state.color_by = None - self.source.arrays = data_arrays + if set(self.source.arrays) != set(data_arrays): + self.source.arrays = data_arrays + self.update_from_source(self.source) @change("spherical_bias", "spherical_scale", "projection_mode") def _on_projection_change( self, spherical_bias, spherical_scale, projection_mode, **_ ): - self.source.projection = ( - Projection.SPHERICAL - if projection_mode == "spherical" - else Projection.EUCLIDEAN - ) - self.source.vertical_bias = spherical_bias - self.source.vertical_scale = spherical_scale + if projection_mode == "spherical": + self.source.projection = Projection.SPHERICAL + self.source.vertical_bias = spherical_bias + self.source.vertical_scale = spherical_scale + else: + self.source.projection = Projection.EUCLIDEAN + self.source.vertical_bias = 0 + self.source.vertical_scale = 1 + + self.ctrl.view_reset_camera() class ControlPanel(v3.VCard): diff --git a/pan3d/xarray/cf/coords/convert.py b/pan3d/xarray/cf/coords/convert.py index 377ae4d..d8ccd55 100644 --- a/pan3d/xarray/cf/coords/convert.py +++ b/pan3d/xarray/cf/coords/convert.py @@ -2,6 +2,33 @@ import numpy as np +def to_isel(slices_info, *array_names): + slices = {} + for name in array_names: + if name is None: + 
continue + + info = slices_info.get(name) + if info is None: + continue + if isinstance(info, int): + slices[name] = info + else: + start, stop, step = info + stop -= (stop - start) % step + slices[name] = slice(start, stop, step) + + return slices if slices else None + + +def slice_array(array_name, dataset, slice_info): + if array_name is None: + return np.zeros(1, dtype=np.float32) + array = dataset[array_name] + dims = array.dims + return array.isel(to_isel(slice_info, *dims)).values + + def extract_uniform_info(array): origin = float(array[0]) spacing = (float(array[-1]) - origin) / (array.size - 1) @@ -21,6 +48,10 @@ def is_uniform(array): def cell_center_to_point(in_array): + if in_array.size == 1: + print("size 1") + return [float(in_array)] + uniform_data = extract_uniform_info(in_array) if uniform_data is not None: origin, spacing, size = uniform_data diff --git a/pan3d/xarray/cf/coords/index_mapping.py b/pan3d/xarray/cf/coords/index_mapping.py index d75e973..9bc9df6 100644 --- a/pan3d/xarray/cf/coords/index_mapping.py +++ b/pan3d/xarray/cf/coords/index_mapping.py @@ -33,7 +33,7 @@ def __init__(self, xr_dataset, in_dims, out_name): name_to_ijk = {in_dims[-(i + 1)]: "ijk"[i] for i in range(len(in_dims))} out_dims = xr_dataset[out_name].dims map_method_name = "".join([name_to_ijk[name] for name in out_dims]) - print(out_name, "=>", map_method_name) + # print(out_name, "=>", map_method_name) setattr(self, "fn", getattr(self, map_method_name)) diff --git a/pan3d/xarray/cf/coords/meta.py b/pan3d/xarray/cf/coords/meta.py index e5afd95..6ac2496 100644 --- a/pan3d/xarray/cf/coords/meta.py +++ b/pan3d/xarray/cf/coords/meta.py @@ -1,7 +1,7 @@ from enum import Enum import numpy as np from pan3d.xarray.cf import mesh -from pan3d.xarray.cf.coords.convert import is_uniform +from pan3d.xarray.cf.coords.convert import is_uniform, slice_array from pan3d.xarray.cf.constants import Projection from vtkmodules.vtkCommonDataModel import ( vtkImageData, @@ -412,17 +412,34 @@ 
def timeless_dimensions(self, field): dims = self.xr_dataset[field].dims return dims[1:] if dims[0] == self.time else dims - def field_extent(self, field): + def dims_extent(self, dimensions, slices=None): extent = [0, 0, 0, 0, 0, 0] - dimensions = self.timeless_dimensions(field) + + if slices is None: + slices = {} + for idx in range(len(dimensions)): - array = self.xr_dataset[dimensions[-(1 + idx)]] + name = dimensions[-(1 + idx)] + array = self.xr_dataset[name] # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] # And extent include both index so (len-1) - extent[idx * 2 + 1] = array.size - 1 + if name in slices: + slice_info = slices[name] + if isinstance(slice_info, int): + # size of 1 + pass + else: + size = int((slice_info[1] - slice_info[0]) / slice_info[2]) + extent[idx * 2 + 1] = size - 1 + else: + extent[idx * 2 + 1] = array.size - 1 return extent + def field_extent(self, field, slices=None): + dimensions = self.timeless_dimensions(field) + return self.dims_extent(dimensions, slices) + def get_vtk_mesh_type(self, projection, fields=None): fields = self.compatible_fields(fields) @@ -450,7 +467,7 @@ def get_vtk_mesh_type(self, projection, fields=None): # imagedata return vtkImageData() - def get_vtk_whole_extent(self, projection, fields=None): + def get_vtk_whole_extent(self, projection, fields=None, slices=None): if self.longitude is None or self.latitude is None or not fields: return [ 0, @@ -463,11 +480,8 @@ def get_vtk_whole_extent(self, projection, fields=None): mesh_type = self.get_vtk_mesh_type(projection, fields) fields = self.compatible_fields(fields) - extent = self.field_extent(fields[0]) dimensions = self.timeless_dimensions(fields[0]) - - print(f"before {extent=}") - print(f"class {mesh_type.GetClassName()}") + extent = self.dims_extent(dimensions, slices) if mesh_type.IsA("vtkStructuredGrid") and not ( self.uniform_lat_lon and self.use_coords(dimensions) @@ -480,11 +494,14 @@ def get_vtk_whole_extent(self, projection, 
fields=None): if extent[i * 2 + 1] > 0: extent[i * 2 + 1] += 1 - print(f"after {extent=}") + print(f"Whole extent: {extent}") return extent - def get_vtk_mesh(self, time_index=0, projection=None, fields=None): + def get_vtk_mesh(self, time_index=0, projection=None, fields=None, slices=None): + if slices is None: + slices = {} + vtk_mesh, data_location = None, None if self.xr_dataset is None or not fields: return vtk_mesh @@ -505,7 +522,7 @@ def get_vtk_mesh(self, time_index=0, projection=None, fields=None): # Unstructured if len(data_dims_no_time) == 1: vtk_mesh, data_location = mesh.unstructured.generate_mesh( - self, data_dims_no_time, time_index, spherical_proj + self, data_dims_no_time, time_index, spherical_proj, slices ) # Structured @@ -513,7 +530,7 @@ def get_vtk_mesh(self, time_index=0, projection=None, fields=None): self.coords_has_bounds or spherical_proj or not self.coords_1d ): vtk_mesh, data_location = mesh.structured.generate_mesh( - self, data_dims_no_time, time_index, spherical_proj + self, data_dims_no_time, time_index, spherical_proj, slices ) # This should only happen if we don't want spherical_proj @@ -523,24 +540,27 @@ def get_vtk_mesh(self, time_index=0, projection=None, fields=None): # Rectilinear if vtk_mesh is None and not self.uniform_spacing: vtk_mesh, data_location = mesh.rectilinear.generate_mesh( - self, data_dims_no_time, time_index + self, data_dims_no_time, time_index, slices ) # Uniform if vtk_mesh is None: vtk_mesh, data_location = mesh.uniform.generate_mesh( - self, data_dims_no_time, time_index + self, data_dims_no_time, time_index, slices ) # Add fields if vtk_mesh: container = getattr(vtk_mesh, data_location) for field_name in fields: - field = ( - self.xr_dataset[field_name][time_index].values - if self.time - else self.xr_dataset[field_name].values - ) + field = slice_array(field_name, self.xr_dataset, slices) + # # FIXME to select slices + # print(f"fields: {slices=}") + # field = ( + # 
self.xr_dataset[field_name][time_index].values + # if self.time + # else self.xr_dataset[field_name].values + # ) container[field_name] = field.ravel() else: print(" !!! No mesh for data") diff --git a/pan3d/xarray/cf/coords/parametric_vertical.py b/pan3d/xarray/cf/coords/parametric_vertical.py index a443a84..7736e90 100644 --- a/pan3d/xarray/cf/coords/parametric_vertical.py +++ b/pan3d/xarray/cf/coords/parametric_vertical.py @@ -71,7 +71,6 @@ def __init__(self, formula, bias=0, scale=1): self._fn = formula self._bias = bias self._scale = scale - print(f"{bias=} {scale=}") def __call__(self, n=0, k=0, j=0, i=0): return self._bias + self._scale * self._fn(n=n, k=k, j=j, i=i) diff --git a/pan3d/xarray/cf/mesh/rectilinear.py b/pan3d/xarray/cf/mesh/rectilinear.py index 09fe0b8..b95807e 100644 --- a/pan3d/xarray/cf/mesh/rectilinear.py +++ b/pan3d/xarray/cf/mesh/rectilinear.py @@ -3,7 +3,7 @@ from ..coords.convert import cell_center_to_point -def generate_mesh(metadata, dimensions, time_index): +def generate_mesh(metadata, dimensions, time_index, slices): data_location = "cell_data" extent = [0, 0, 0, 0, 0, 0] empty_coords = np.zeros((1,), dtype=np.double) diff --git a/pan3d/xarray/cf/mesh/structured.py b/pan3d/xarray/cf/mesh/structured.py index d0f9d53..2043531 100644 --- a/pan3d/xarray/cf/mesh/structured.py +++ b/pan3d/xarray/cf/mesh/structured.py @@ -3,26 +3,28 @@ from ..coords.parametric_vertical import get_formula as get_z_formula from ..coords.index_mapping import get_formula as get_coords_formula -from ..coords.convert import point_insert, cell_center_to_point +from ..coords.convert import point_insert, cell_center_to_point, slice_array -def generate_mesh(metadata, dimensions, time_index, spherical): +def generate_mesh(metadata, dimensions, time_index, spherical, slices): # Data location and extend depend if we can extrapolate cell locations # bounds or uniform allow to define cell bounds if metadata.coords_has_bounds: print(" => structured: bounds") - return 
generate_bound_cells(metadata, dimensions, time_index, spherical) + return generate_bound_cells(metadata, dimensions, time_index, spherical, slices) if metadata.uniform_spacing: print(" => structured: uniform spacing") - return generate_uniform_cells(metadata, dimensions, time_index, spherical) + return generate_uniform_cells( + metadata, dimensions, time_index, spherical, slices + ) # We can only figure out the point location print(" => structured: on points") - return generate_mesh_points(metadata, dimensions, time_index, spherical) + return generate_mesh_points(metadata, dimensions, time_index, spherical, slices) -def generate_uniform_cells(metadata, dimensions, time_index, spherical): +def generate_uniform_cells(metadata, dimensions, time_index, spherical, slices): data_location = "cell_data" assert spherical @@ -33,67 +35,69 @@ def generate_uniform_cells(metadata, dimensions, time_index, spherical): # extract extent, origin, spacing origin = [0, 0, 0] spacing = [1, 1, 1] - extent = [0, 0, 0, 0, 0, 0] - n_points = 1 - dims_origin_spacing = [] + sizes = [1, 1, 1] for idx in range(len(dimensions)): - array = metadata.xr_dataset[dimensions[-(1 + idx)]] - - # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] - # Use size as end of extent as we are adding 1 point for map data on cell - extent[idx * 2 + 1] = array.size - n_points *= array.size + name = dimensions[-(1 + idx)] + array = slice_array(name, metadata.xr_dataset, slices) # axis origin/spacing - axis_spacing = (array[-1].values - array[0].values) / (array.size - 1) - axis_origin = float(array[0].values) - axis_spacing * 0.5 + if array.size == 1: + continue - # update global origin/spacing - origin[idx] = axis_origin - spacing[idx] = axis_spacing + axis_spacing = (array[-1] - array[0]) / (array.size - 1) + axis_origin = float(array[0]) - axis_spacing * 0.5 - # Add (origin, spacing) for missing coords - while len(dims_origin_spacing) < 3: - dims_origin_spacing.insert(0, (0, 1)) + # update 
global origin/spacing + origin[idx] = float(axis_origin) + spacing[idx] = float(axis_spacing) + sizes[idx] = array.size + 1 # debug # print(f"{dimensions=}") # print(f"{extent=}") # print(f"{n_points=}") # print(f"{dimensions=}") - # print(f"{extent=}") - # print(f"{dims_origin_spacing=}") + # print(f"{origin=}") + # print(f"{spacing=}") + # print(f"{sizes=}") # Points vtk_points = vtkPoints() vtk_points.SetDataTypeToDouble() - vtk_points.Allocate(n_points) + vtk_points.Allocate(sizes[0] * sizes[1] * sizes[2]) # Check if direct coord mapping - for k in range(extent[5] + 1): + for k in range(sizes[2]): z = origin[2] + k * spacing[2] z = metadata.vertical_bias + metadata.vertical_scale * z - for j in range(extent[3] + 1): + for j in range(sizes[1]): lat = origin[1] + j * spacing[1] - for i in range(extent[1] + 1): + for i in range(sizes[0]): lon = origin[0] + i * spacing[0] point_insert(vtk_points, spherical, lon, lat, z) # Mesh mesh = vtkStructuredGrid() mesh.points = vtk_points - mesh.extent = extent + mesh.extent = [ + 0, + sizes[0] - 1, + 0, + sizes[1] - 1, + 0, + sizes[2] - 1, + ] return mesh, data_location -def generate_bound_cells(metadata, dimensions, time_index, spherical): +def generate_bound_cells(metadata, dimensions, time_index, spherical, slices): data_location = "cell_data" raise NotImplementedError("structured::generate_bound_cells") return False, data_location -def generate_mesh_points(metadata, dimensions, time_index, spherical): +def generate_mesh_points(metadata, dimensions, time_index, spherical, slices): data_location = "point_data" if ( @@ -101,8 +105,9 @@ def generate_mesh_points(metadata, dimensions, time_index, spherical): and metadata.uniform_lat_lon and metadata.use_coords(dimensions) ): + print(" ==> move datat to cell") return generate_mesh_points_data_on_cell( - metadata, dimensions, time_index, spherical + metadata, dimensions, time_index, spherical, slices ) # 2D or 3D @@ -175,20 +180,28 @@ def generate_mesh_points(metadata, 
dimensions, time_index, spherical): return mesh, data_location -def generate_mesh_points_data_on_cell(metadata, dimensions, time_index, spherical): +def generate_mesh_points_data_on_cell( + metadata, dimensions, time_index, spherical, slices +): + # Slice Ready data_location = "cell_data" + # 2D or 3D dims_size = len(dimensions) assert dims_size == 2 or dims_size == 3 - extent = [0, 0, 0, 0, 0, 0] + # compute extent between dimensions and slices + extent = metadata.dims_extent(dimensions, slices) n_points = 1 - for idx in range(len(dimensions)): - array = metadata.xr_dataset[dimensions[-(1 + idx)]] - # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] - # And extent include both index and adding 1 point so (len-1+1) => len - extent[idx * 2 + 1] = array.size - n_points *= array.size + 1 + + # Increase extent by 1 since we add points to put data on cells + for i in range(3): + if extent[i * 2 + 1] > 0: + extent[i * 2 + 1] += 1 + + # Extract point count + for i in range(3): + n_points *= extent[i * 2 + 1] - extent[i * 2] + 1 # Points vtk_points = vtkPoints() @@ -196,26 +209,38 @@ def generate_mesh_points_data_on_cell(metadata, dimensions, time_index, spherica vtk_points.Allocate(n_points) # debug - # print(f"{dimensions=}") - # print(f"{extent=}") - # print(f"{n_points=}") + print(f"{dimensions=}") + print(f"{extent=}") + print(f"{n_points=}") + print(f"{slices=}") # Check if direct coord mapping if dims_size == 2: # 2D - x_array = cell_center_to_point(metadata.xr_dataset[metadata.longitude].values) - y_array = cell_center_to_point(metadata.xr_dataset[metadata.latitude].values) - z = 0 + x_array = cell_center_to_point( + slice_array(metadata.longitude, metadata.xr_dataset, slices) + ) + y_array = cell_center_to_point( + slice_array(metadata.latitude, metadata.xr_dataset, slices) + ) + z = metadata.vertical_bias for j in range(extent[3] + 1): lat = y_array[j] for i in range(extent[1] + 1): lon = x_array[i] point_insert(vtk_points, spherical, lon, 
lat, z) else: # 3D - x_array = cell_center_to_point(metadata.xr_dataset[metadata.longitude].values) - y_array = cell_center_to_point(metadata.xr_dataset[metadata.latitude].values) - z_array = cell_center_to_point(metadata.xr_dataset[metadata.vertical].values) + x_array = cell_center_to_point( + slice_array(metadata.longitude, metadata.xr_dataset, slices) + ) + y_array = cell_center_to_point( + slice_array(metadata.latitude, metadata.xr_dataset, slices) + ) + z_array = cell_center_to_point( + slice_array(metadata.vertical, metadata.xr_dataset, slices) + ) for k in range(extent[5] + 1): z = z_array[k] + print(f"{z=}") for j in range(extent[3] + 1): lat = y_array[j] for i in range(extent[1] + 1): diff --git a/pan3d/xarray/cf/mesh/uniform.py b/pan3d/xarray/cf/mesh/uniform.py index 0d070ae..dcd34a0 100644 --- a/pan3d/xarray/cf/mesh/uniform.py +++ b/pan3d/xarray/cf/mesh/uniform.py @@ -1,7 +1,7 @@ from vtkmodules.vtkCommonDataModel import vtkImageData -def generate_mesh(metadata, dimensions, time_index): +def generate_mesh(metadata, dimensions, time_index, slices): data_location = "cell_data" # data to capture diff --git a/pan3d/xarray/cf/mesh/unstructured.py b/pan3d/xarray/cf/mesh/unstructured.py index 8991645..c11cc70 100644 --- a/pan3d/xarray/cf/mesh/unstructured.py +++ b/pan3d/xarray/cf/mesh/unstructured.py @@ -1,4 +1,4 @@ -def generate_mesh(metadata, dimensions, spherical): +def generate_mesh(metadata, dimensions, spherical, slices): print(" => unstructured: cell_data") raise NotImplementedError("unstructured::generate_mesh") return False, "cell_data" diff --git a/pan3d/xarray/cf/reader.py b/pan3d/xarray/cf/reader.py index 02852cd..56c373d 100644 --- a/pan3d/xarray/cf/reader.py +++ b/pan3d/xarray/cf/reader.py @@ -1,5 +1,6 @@ from typing import List, Optional +import gc import xarray as xr import numpy as np import pandas as pd @@ -292,14 +293,26 @@ def t(self): @property def slice_extents(self): - """return a dictionary for the X, Y, Z dimensions with the 
corresponding extent [0, size-1]""" - # !!! can be different based on which field is selected !!! - # -> result not obvious based on the mesh type - return { - coord_name: [0, self.input[coord_name].size - 1] - for coord_name in [self.x, self.y, self.z] - if coord_name is not None + """ + return a dictionary for the field dimensions with + the corresponding extent [0, size-1]. + + For some dataset, it is possible that the extent does not have + a direct mapping with the coordinate system. + """ + fields = self.arrays + if not fields: + return {} + dims = self._mapping.timeless_dimensions(fields[0]) + # ensure order to match x, y, z + dims = dims[::-1] + + print(f"slice_extents::{dims=}") + result = { + coord_name: [0, self.input[coord_name].size - 1] for coord_name in dims } + print(f" => {result}") + return result @property def available_coords(self): @@ -334,6 +347,7 @@ def vertical_bias(self): def vertical_bias(self, v): if self._proj_vertical_bias != v: self._proj_vertical_bias = v + self._mapping.vertical_bias = v self.Modified() @property @@ -344,6 +358,7 @@ def vertical_scale(self): def vertical_scale(self, v): if self._proj_vertical_scale != v: self._proj_vertical_scale = v + self._mapping.vertical_scale = v self.Modified() # ------------------------------------------------------------------------- @@ -425,6 +440,7 @@ def slices(self, v): """update the slicing of the data along axes""" if v != self._slices: self._slices = v + print(f"{self._slices=}") if "time" in v: self.t_index = v.get("time", 0) @@ -569,12 +585,14 @@ def state(self): def RequestDataObject(self, request, inInfo, outInfo): output = self._mapping.get_vtk_mesh_type(self._proj_mode, self.arrays) outInfo.GetInformationObject(0).Set(vtkDataObject.DATA_OBJECT(), output) - print(f"RequestDataObject::{output.GetClassName()=}") + # print(f"RequestDataObject::{output.GetClassName()=}") return 1 def RequestInformation(self, request, inInfo, outInfo): - whole_extent = 
self._mapping.get_vtk_whole_extent(self._proj_mode, self.arrays) - print(f"{whole_extent=}") + whole_extent = self._mapping.get_vtk_whole_extent( + self._proj_mode, self.arrays, self._slices + ) + # print(f"{whole_extent=}") outInfo.GetInformationObject(0).Set( vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), *whole_extent, @@ -586,12 +604,17 @@ def RequestData(self, request, inInfo, outInfo): """implementation of the vtk algorithm for generating the VTK mesh""" # Use open data_array handle to fetch data at # desired Level of Detail + print("RequestData") + print("+" * 60) mesh = self._mapping.get_vtk_mesh( - time_index=self.t_index, projection=self._proj_mode, fields=self.arrays + time_index=self.t_index, + projection=self._proj_mode, + fields=self.arrays, + slices=self._slices, ) if mesh is not None: - print(f"{mesh.extent=}") - print(f"{mesh.GetClassName()=}") + # print(f"{mesh.extent=}") + # print(f"{mesh.GetClassName()=}") pdo = self.GetOutputData(outInfo, 0) # Compute derived quantity @@ -601,4 +624,6 @@ def RequestData(self, request, inInfo, outInfo): else: pdo.ShallowCopy(mesh) + gc.collect() + return 1 From a8edf18d734c9de389cdac4c58b7d2ca3b232c6a Mon Sep 17 00:00:00 2001 From: Sebastien Jourdain Date: Mon, 27 Jan 2025 16:52:24 -0700 Subject: [PATCH 10/10] wip: more CF coverage with slicing --- pan3d/catalogs/xarray.py | 8 +- pan3d/ui/preview.py | 21 ++- pan3d/xarray/cf/coords/convert.py | 3 +- pan3d/xarray/cf/coords/meta.py | 8 +- pan3d/xarray/cf/coords/parametric_vertical.py | 18 ++- pan3d/xarray/cf/mesh/rectilinear.py | 25 +++- pan3d/xarray/cf/mesh/structured.py | 132 +++++++++++++++--- pan3d/xarray/cf/mesh/uniform.py | 39 +++++- pan3d/xarray/cf/reader.py | 131 +---------------- 9 files changed, 206 insertions(+), 179 deletions(-) diff --git a/pan3d/catalogs/xarray.py b/pan3d/catalogs/xarray.py index e28a0ff..60f851b 100644 --- a/pan3d/catalogs/xarray.py +++ b/pan3d/catalogs/xarray.py @@ -14,10 +14,10 @@ "description": "Dataset with ocean basins marked 
using integers", }, # ------------------------------------------------------------------------- - { - "name": "ASE_ice_velocity", - "description": "MEaSUREs InSAR-Based Ice Velocity of the Amundsen Sea Embayment, Antarctica, Version 1", - }, + # { + # "name": "ASE_ice_velocity", + # "description": "MEaSUREs InSAR-Based Ice Velocity of the Amundsen Sea Embayment, Antarctica, Version 1", + # }, { "name": "rasm", "description": "Output of the Regional Arctic System Model (RASM)", diff --git a/pan3d/ui/preview.py b/pan3d/ui/preview.py index 702c11d..eebdec2 100644 --- a/pan3d/ui/preview.py +++ b/pan3d/ui/preview.py @@ -278,7 +278,7 @@ def __init__(self, source, update_rendering): self.state.setdefault("max_time_width", 0) self.state.setdefault("max_time_index_width", 0) self.state.setdefault("dataset_bounds", [0, 1, 0, 1, 0, 1]) - self.state.setdefault("projection_mode", "spherical") + self.state.setdefault("projection_mode", "Spherical") with self.content: v3.VSelect( @@ -595,7 +595,7 @@ def __init__(self, source, update_rendering): # Projection mode + scaling with v3.VTooltip( text=( - "projection_mode === 'spherical' ? `Spherical Projection: scaling=${spherical_scale} bias=${spherical_bias}` : 'Euclidian Projection'", + "projection_mode === 'Spherical' ? `Spherical Projection: scaling=${spherical_scale} bias=${spherical_bias}` : 'Euclidean Projection'", ) ): with html.Template(v_slot_activator="{ props }"): @@ -606,10 +606,10 @@ def __init__(self, source, update_rendering): ): v3.VSelect( prepend_inner_icon=( - "projection_mode === 'spherical' ? 'mdi-earth' : 'mdi-earth-box'", + "projection_mode === 'Spherical' ? 
'mdi-earth' : 'mdi-earth-box'", ), - v_model=("projection_mode", "spherical"), - items=("['spherical', 'euclidian']",), + v_model=("projection_mode", "Spherical"), + items=("['Spherical', 'Euclidean']",), hide_details=True, density="compact", flat=True, @@ -619,7 +619,7 @@ def __init__(self, source, update_rendering): with v3.VCol( no_gutter=True, classes="align-center my-0 mx-0 pa-0", - v_if="projection_mode === 'spherical'", + v_if="projection_mode === 'Spherical'", ): v3.VSlider( prepend_icon="mdi-radius-outline", @@ -747,9 +747,9 @@ def update_from_source(self, source=None): self.state.axis_names = [] self.state.slice_extents = source.slice_extents self.state.projection_mode = ( - "spherical" + "Spherical" if source.projection == Projection.SPHERICAL - else "euclidean" + else "Euclidean" ) self.state.spherical_bias = source.vertical_bias self.state.spherical_scale = source.vertical_scale @@ -836,8 +836,7 @@ def on_change(self, slice_t, **_): return slices = {self.source.t: slice_t} - for axis in XYZ: - axis_name = getattr(self.source, axis) + for axis, axis_name in zip(XYZ, self.source.slice_extents.keys()): if axis_name is None: continue @@ -887,7 +886,7 @@ def _on_array_selection(self, data_arrays, **_): def _on_projection_change( self, spherical_bias, spherical_scale, projection_mode, **_ ): - if projection_mode == "spherical": + if projection_mode == "Spherical": self.source.projection = Projection.SPHERICAL self.source.vertical_bias = spherical_bias self.source.vertical_scale = spherical_scale diff --git a/pan3d/xarray/cf/coords/convert.py b/pan3d/xarray/cf/coords/convert.py index d8ccd55..80eb961 100644 --- a/pan3d/xarray/cf/coords/convert.py +++ b/pan3d/xarray/cf/coords/convert.py @@ -49,8 +49,7 @@ def is_uniform(array): def cell_center_to_point(in_array): if in_array.size == 1: - print("size 1") - return [float(in_array)] + return in_array uniform_data = extract_uniform_info(in_array) if uniform_data is not None: diff --git 
a/pan3d/xarray/cf/coords/meta.py b/pan3d/xarray/cf/coords/meta.py index 6ac2496..4151f3a 100644 --- a/pan3d/xarray/cf/coords/meta.py +++ b/pan3d/xarray/cf/coords/meta.py @@ -491,10 +491,14 @@ def get_vtk_whole_extent(self, projection, fields=None, slices=None): # cell data, need to +1 on the extent for i in range(3): - if extent[i * 2 + 1] > 0: + if i < 2: + # minimum 2D cells + extent[i * 2 + 1] += 1 + elif extent[i * 2 + 1] > 0: + # Only add cell in 3rd dimension if content available extent[i * 2 + 1] += 1 - print(f"Whole extent: {extent}") + # print(f"Whole extent: {extent}") return extent diff --git a/pan3d/xarray/cf/coords/parametric_vertical.py b/pan3d/xarray/cf/coords/parametric_vertical.py index 7736e90..75367aa 100644 --- a/pan3d/xarray/cf/coords/parametric_vertical.py +++ b/pan3d/xarray/cf/coords/parametric_vertical.py @@ -15,6 +15,11 @@ # Factory method # ----------------------------------------------------------------------------- def get_formula(xr_dataset, name, bias=0, scale=1): + # No z array => no formula, just a constant + if name is None: + return ConstantFormulaAdapter(bias) + + # Let's extract formula array_attributes = xr_dataset[name].attrs std_name = array_attributes.get("standard_name") formula_terms = array_attributes.get("formula_terms") @@ -28,7 +33,8 @@ def get_formula(xr_dataset, name, bias=0, scale=1): scale=scale, ) - return None + # Fallback to constant + return ConstantFormulaAdapter(bias) # ----------------------------------------------------------------------------- @@ -76,6 +82,16 @@ def __call__(self, n=0, k=0, j=0, i=0): return self._bias + self._scale * self._fn(n=n, k=k, j=j, i=i) +class ConstantFormulaAdapter: + name = "__internal__" + + def __init__(self, const_value=0): + self._value = const_value + + def __call__(self, **_): + return self._value + + # ----------------------------------------------------------------------------- # Atmosphere natural log pressure coordinate # 
----------------------------------------------------------------------------- diff --git a/pan3d/xarray/cf/mesh/rectilinear.py b/pan3d/xarray/cf/mesh/rectilinear.py index b95807e..b6938a1 100644 --- a/pan3d/xarray/cf/mesh/rectilinear.py +++ b/pan3d/xarray/cf/mesh/rectilinear.py @@ -1,9 +1,21 @@ import numpy as np from vtkmodules.vtkCommonDataModel import vtkRectilinearGrid -from ..coords.convert import cell_center_to_point +from ..coords.convert import cell_center_to_point, slice_array def generate_mesh(metadata, dimensions, time_index, slices): + """ + - [X] Initial implementation + - [X] Support range slicing + - [X] Support index slicing + - [ ] Automatic testing + + Testing process: + 1. load xarray tutorial dataset: eraint_uvz + 2. Switch projection to Euclidean + 3. Play with range sliders + 4. Switch one range slider to a cut + """ data_location = "cell_data" extent = [0, 0, 0, 0, 0, 0] empty_coords = np.zeros((1,), dtype=np.double) @@ -12,12 +24,15 @@ def generate_mesh(metadata, dimensions, time_index, slices): assert metadata.coords_1d for idx in range(len(dimensions)): - array = metadata.xr_dataset[dimensions[-(1 + idx)]] - arrays[idx] = cell_center_to_point(array.values) + name = dimensions[-(1 + idx)] + + arrays[idx] = cell_center_to_point( + slice_array(name, metadata.xr_dataset, slices) + ) # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] - # And extent include both index but add 1 point so (len-1+1) => len - extent[idx * 2 + 1] = array.size + # And extent include both index (len-1) + extent[idx * 2 + 1] = arrays[idx].size - 1 mesh = vtkRectilinearGrid() mesh.x_coordinates = arrays[0] diff --git a/pan3d/xarray/cf/mesh/structured.py b/pan3d/xarray/cf/mesh/structured.py index 2043531..d0d88bb 100644 --- a/pan3d/xarray/cf/mesh/structured.py +++ b/pan3d/xarray/cf/mesh/structured.py @@ -10,21 +10,28 @@ def generate_mesh(metadata, dimensions, time_index, spherical, slices): # Data location and extend depend if we can extrapolate 
cell locations # bounds or uniform allow to define cell bounds if metadata.coords_has_bounds: - print(" => structured: bounds") return generate_bound_cells(metadata, dimensions, time_index, spherical, slices) if metadata.uniform_spacing: - print(" => structured: uniform spacing") return generate_uniform_cells( metadata, dimensions, time_index, spherical, slices ) # We can only figure out the point location - print(" => structured: on points") return generate_mesh_points(metadata, dimensions, time_index, spherical, slices) def generate_uniform_cells(metadata, dimensions, time_index, spherical, slices): + """ + - [X] Initial implementation + - [X] Support range slicing + - [X] Support index slicing + - [ ] Automatic testing + + xarray tutorial dataset: + - air_temperature + - ersstv5 + """ data_location = "cell_data" assert spherical @@ -40,8 +47,14 @@ def generate_uniform_cells(metadata, dimensions, time_index, spherical, slices): name = dimensions[-(1 + idx)] array = slice_array(name, metadata.xr_dataset, slices) - # axis origin/spacing + # axis origin/spacing (maybe) if array.size == 1: + origin[idx] = float(array) + all_array = metadata.xr_dataset[name] + if all_array.size > 1: + all_array = all_array[:2].values + spacing[idx] = float(all_array[1] - all_array[0]) + origin[idx] -= 0.5 * spacing[idx] continue axis_spacing = (array[-1] - array[0]) / (array.size - 1) @@ -92,12 +105,32 @@ def generate_uniform_cells(metadata, dimensions, time_index, spherical, slices): def generate_bound_cells(metadata, dimensions, time_index, spherical, slices): + """ + - [ ] Initial implementation + - [ ] Support range slicing + - [ ] Support index slicing + - [ ] Automatic testing + """ data_location = "cell_data" raise NotImplementedError("structured::generate_bound_cells") return False, data_location def generate_mesh_points(metadata, dimensions, time_index, spherical, slices): + """ + - [ ] Initial implementation + - [ ] Support range slicing + - [ ] Support index slicing + - [ ] 
Automatic testing + + xarray tutorial dataset: + - ROMS_example: CUSTOM Z FORMULA + - rasm: CUSTOM Z FORMULA but + - non CF format + - 2d coord remapping + + + """ data_location = "point_data" if ( @@ -105,7 +138,6 @@ def generate_mesh_points(metadata, dimensions, time_index, spherical, slices): and metadata.uniform_lat_lon and metadata.use_coords(dimensions) ): - print(" ==> move datat to cell") return generate_mesh_points_data_on_cell( metadata, dimensions, time_index, spherical, slices ) @@ -115,13 +147,34 @@ def generate_mesh_points(metadata, dimensions, time_index, spherical, slices): assert dims_size == 2 or dims_size == 3 extent = [0, 0, 0, 0, 0, 0] + dim_ranges = [ + (0, 1, 1), # i + (0, 1, 1), # j + (0, 1, 1), # k + ] n_points = 1 for idx in range(len(dimensions)): - array = metadata.xr_dataset[dimensions[-(1 + idx)]] - # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] - # And extent include both index so (len-1) - extent[idx * 2 + 1] = array.size - 1 - n_points *= array.size + name = dimensions[-(1 + idx)] + array = slice_array(name, metadata.xr_dataset, slices) + all_array = metadata.xr_dataset[name] + if array.size != all_array.size: + slice_info = slices.get(name) + if isinstance(slice_info, int): + start = slice_info + dim_ranges[idx] = (start, start + 1) + else: + start, stop, step = slice_info + stop -= (stop - start) % step + size = int((stop - start) / step) + extent[idx * 2 + 1] = size - 1 + dim_ranges[idx] = (start, stop, step) + n_points *= size + else: + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # And extent include both index so (len-1) + extent[idx * 2 + 1] = array.size - 1 + n_points *= array.size + dim_ranges[idx] = (0, array.size, 1) # Points vtk_points = vtkPoints() @@ -135,27 +188,31 @@ def generate_mesh_points(metadata, dimensions, time_index, spherical, slices): # Check if direct coord mapping if metadata.coords_1d and metadata.use_coords(dimensions): + print(" => 1D coords - 
Need data to test") if dims_size == 2: # 2D + print(" => 2D dataset - Need data to test") x_array = metadata.xr_dataset[metadata.longitude].values y_array = metadata.xr_dataset[metadata.latitude].values z = 0 - for j in range(extent[3] + 1): + for j in range(*dim_ranges[1]): lat = y_array[j] - for i in range(extent[1] + 1): + for i in range(*dim_ranges[0]): lon = x_array[i] point_insert(vtk_points, spherical, lon, lat, z) else: # 3D + print(" => 3D dataset - Need data to test") x_array = metadata.xr_dataset[metadata.longitude].values y_array = metadata.xr_dataset[metadata.latitude].values z_array = metadata.xr_dataset[metadata.vertical].values - for k in range(extent[5] + 1): + for k in range(*dim_ranges[2]): z = z_array[k] - for j in range(extent[3] + 1): + for j in range(*dim_ranges[1]): lat = y_array[j] - for i in range(extent[1] + 1): + for i in range(*dim_ranges[0]): lon = x_array[i] point_insert(vtk_points, spherical, lon, lat, z) else: + # CUSTOM Z FORMULA # need some index mapping z_formula = get_z_formula( metadata.xr_dataset, @@ -164,9 +221,9 @@ def generate_mesh_points(metadata, dimensions, time_index, spherical, slices): scale=metadata.vertical_scale, ) coords_formula = get_coords_formula(metadata, dimensions) - for k in range(extent[5] + 1): - for j in range(extent[3] + 1): - for i in range(extent[1] + 1): + for k in range(*dim_ranges[2]): + for j in range(*dim_ranges[1]): + for i in range(*dim_ranges[0]): lon, lat = coords_formula(i=i, j=j, k=k) z = z_formula(i=i, j=j, k=k, n=time_index) # print(f"{time_index=}, {k=}, {j=}, {i=} = {lon=}, {lat=}, {z=}") @@ -183,6 +240,22 @@ def generate_mesh_points(metadata, dimensions, time_index, spherical, slices): def generate_mesh_points_data_on_cell( metadata, dimensions, time_index, spherical, slices ): + """ + - [X] Initial implementation + - [x] Support range slicing + - [x] Support index slicing + - [ ] Automatic testing + + Testing process 2D: <------------------- MISSING + 1. xxxx + 2. 
Play with range sliders + 3. Switch one range slider to a cut + + Testing process 3D: + 1. load xarray tutorial dataset: eraint_uvz | basin_mask + 2. Play with range sliders + 3. Switch one range slider to a cut + """ # Slice Ready data_location = "cell_data" @@ -209,19 +282,27 @@ def generate_mesh_points_data_on_cell( vtk_points.Allocate(n_points) # debug - print(f"{dimensions=}") - print(f"{extent=}") - print(f"{n_points=}") - print(f"{slices=}") + # print(f"{dimensions=}") + # print(f"{extent=}") + # print(f"{n_points=}") + # print(f"{slices=}") # Check if direct coord mapping if dims_size == 2: # 2D + print("#" * 60) + print("structured::generate_mesh_points_data_on_cell") + print("We are in 2D mode\n" * 5) + print("#" * 60) x_array = cell_center_to_point( slice_array(metadata.longitude, metadata.xr_dataset, slices) ) y_array = cell_center_to_point( slice_array(metadata.latitude, metadata.xr_dataset, slices) ) + if y_array.size == 1: + y_array = [y_array] + if x_array.size == 1: + x_array = [x_array] z = metadata.vertical_bias for j in range(extent[3] + 1): lat = y_array[j] @@ -238,9 +319,14 @@ def generate_mesh_points_data_on_cell( z_array = cell_center_to_point( slice_array(metadata.vertical, metadata.xr_dataset, slices) ) + if z_array.size == 1: + z_array = [z_array] + if y_array.size == 1: + y_array = [y_array] + if x_array.size == 1: + x_array = [x_array] for k in range(extent[5] + 1): z = z_array[k] - print(f"{z=}") for j in range(extent[3] + 1): lat = y_array[j] for i in range(extent[1] + 1): diff --git a/pan3d/xarray/cf/mesh/uniform.py b/pan3d/xarray/cf/mesh/uniform.py index dcd34a0..3826de9 100644 --- a/pan3d/xarray/cf/mesh/uniform.py +++ b/pan3d/xarray/cf/mesh/uniform.py @@ -1,7 +1,20 @@ from vtkmodules.vtkCommonDataModel import vtkImageData +from ..coords.convert import slice_array def generate_mesh(metadata, dimensions, time_index, slices): + """ + - [X] Initial implementation + - [X] Support range slicing + - [X] Support index slicing + - [ ] 
Automatic testing + + Testing process: + 1. load xarray tutorial dataset: air_temperature + 2. Switch projection to Euclidean + 3. Play with range sliders + 4. Switch one range slider to a cut + """ data_location = "cell_data" # data to capture @@ -11,18 +24,34 @@ def generate_mesh(metadata, dimensions, time_index, slices): # extract information from dimensions for idx in range(len(dimensions)): - array = metadata.xr_dataset[dimensions[-(1 + idx)]] + name = dimensions[-(1 + idx)] + + array = slice_array(name, metadata.xr_dataset, slices) # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] # Use size as end of extent as we are adding 1 point for map data on cell extent[idx * 2 + 1] = array.size + # No spacing just origin (maybe) + if array.size == 1: + origin[idx] = float(array) + all_array = metadata.xr_dataset[name] + if all_array.size > 1: + all_array = all_array[:2].values + spacing[idx] = float(all_array[1] - all_array[0]) + origin[idx] -= 0.5 * spacing[idx] + continue + # axis origin/spacing - axis_spacing = (array[-1].values - array[0].values) / (array.size - 1) - axis_origin = float(array[0].values) - axis_spacing * 0.5 + axis_spacing = (array[-1] - array[0]) / (array.size - 1) + axis_origin = float(array[0]) - axis_spacing * 0.5 # update global origin/spacing - origin[idx] = axis_origin - spacing[idx] = axis_spacing + origin[idx] = float(axis_origin) + spacing[idx] = float(axis_spacing) + + # print(f"{origin=}") + # print(f"{spacing=}") + # print(f"{extent=}") # Configure mesh mesh = vtkImageData(origin=origin, spacing=spacing, extent=extent) diff --git a/pan3d/xarray/cf/reader.py b/pan3d/xarray/cf/reader.py index 56c373d..cabc911 100644 --- a/pan3d/xarray/cf/reader.py +++ b/pan3d/xarray/cf/reader.py @@ -55,126 +55,6 @@ def get_time_labels(times): # ----------------------------------------------------------------------------- # VTK Algorithms # ----------------------------------------------------------------------------- - - -class 
CFHelperXXX: - def __init__(self, xr_dataset): - self._xr_dataset = xr_dataset - self._meta = MetaArrayMapping(xr_dataset) - self._t_index = 0 - self._mesh = None - self._array_names = [] - self._spherical = True - - def update(self, xr_dataset): - self._array_names = [] - self._mesh = None - self._xr_dataset = xr_dataset - self._meta.update(xr_dataset) - - def _reset(self): - self._mesh = None - - @property - def x(self): - return self._meta.longitude - - @property - def y(self): - return self._meta.latitude - - @property - def z(self): - return self._meta.vertical - - @property - def t(self): - return self._meta.time - - @property - def spherical(self): - return self._spherical - - @spherical.setter - def spherical(self, v): - if self._spherical != v: - self._spherical = v - self._reset() - - @property - def vertical_bias(self): - return self._meta.vertical_bias - - @vertical_bias.setter - def vertical_bias(self, v): - if self._meta.vertical_bias != v: - self._meta.vertical_bias = v - self._reset() - - @property - def vertical_scale(self): - return self._meta.vertical_scale - - @vertical_scale.setter - def vertical_scale(self, v): - if self._meta.vertical_scale != v: - self._meta.vertical_scale = v - self._reset() - - @property - def t_index(self): - """return the current selected time index""" - return self._t_index - - @t_index.setter - def t_index(self, t_index: int): - """update the current selected time index""" - if t_index != self._t_index: - self._t_index = t_index - self._reset() - - @property - def t_size(self): - """return the size of the coordinate used for the time""" - if self.t is None or self._xr_dataset is None: - return 0 - return int(self._xr_dataset[self.t].size) - - @property - def arrays(self): - """return the list of arrays that are currently selected to be added to the generated VTK mesh""" - return list(self._array_names) - - @arrays.setter - def arrays(self, array_names: List[str]): - """update the list of arrays to load on the generated 
VTK mesh""" - new_names = set(array_names or []) - if new_names != self._array_names: - self._array_names = new_names - self._reset() - - @property - def available_arrays(self): - all_list = [] - for arrays in self._meta.data_arrays.values(): - if self._array_names: - if self._array_names.intersection(arrays): - # only add compatible arrays - all_list.extend(arrays) - else: - # when no array selected, add all of them - all_list.extend(arrays) - - return all_list - - @property - def mesh(self): - if self._mesh is None: - self._mesh = self._meta.get_mesh(self._t_index, self.spherical, self.arrays) - - return self._mesh - - class vtkXArrayCFSource(VTKPythonAlgorithmBase): """vtk source for converting XArray into a VTK mesh""" @@ -307,11 +187,11 @@ def slice_extents(self): # ensure order to match x, y, z dims = dims[::-1] - print(f"slice_extents::{dims=}") + # print(f"slice_extents::{dims=}") result = { coord_name: [0, self.input[coord_name].size - 1] for coord_name in dims } - print(f" => {result}") + # print(f" => {result}") return result @property @@ -440,7 +320,7 @@ def slices(self, v): """update the slicing of the data along axes""" if v != self._slices: self._slices = v - print(f"{self._slices=}") + # print(f"{self._slices=}") if "time" in v: self.t_index = v.get("time", 0) @@ -592,7 +472,7 @@ def RequestInformation(self, request, inInfo, outInfo): whole_extent = self._mapping.get_vtk_whole_extent( self._proj_mode, self.arrays, self._slices ) - # print(f"{whole_extent=}") + # print(f"RequestInformation::{whole_extent=}") outInfo.GetInformationObject(0).Set( vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), *whole_extent, @@ -604,8 +484,7 @@ def RequestData(self, request, inInfo, outInfo): """implementation of the vtk algorithm for generating the VTK mesh""" # Use open data_array handle to fetch data at # desired Level of Detail - print("RequestData") - print("+" * 60) + # print("RequestData") mesh = self._mapping.get_vtk_mesh( time_index=self.t_index, 
projection=self._proj_mode,