diff --git a/.codespellrc b/.codespellrc index 5c5ba3b..84f98a9 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,2 +1,3 @@ [codespell] skip = CHANGELOG.md,tests/data/* +ignore-words-list = degreee \ No newline at end of file diff --git a/.gitignore b/.gitignore index d3213a4..62ac0f0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ node_modules .venv* *.nc *.data-origin.json +*.vtk examples/jupyter/*.gif site/* diff --git a/pan3d/catalogs/xarray.py b/pan3d/catalogs/xarray.py index ece97c5..60f851b 100644 --- a/pan3d/catalogs/xarray.py +++ b/pan3d/catalogs/xarray.py @@ -18,21 +18,14 @@ # "name": "ASE_ice_velocity", # "description": "MEaSUREs InSAR-Based Ice Velocity of the Amundsen Sea Embayment, Antarctica, Version 1", # }, - # ------------------------------------------------------------------------- - # { - # "name": "rasm", - # "description": "Output of the Regional Arctic System Model (RASM)", - # }, - # ------------------------------------------------------------------------- - # { - # "name": "ROMS_example", - # "description": "Regional Ocean Model System (ROMS) output", - # }, - # ------------------------------------------------------------------------- - # { - # "name": "tiny", - # "description": "small synthetic dataset with a 1D data variable", - # }, + { + "name": "rasm", + "description": "Output of the Regional Arctic System Model (RASM)", + }, + { + "name": "ROMS_example", + "description": "Regional Ocean Model System (ROMS) output", + }, # ------------------------------------------------------------------------- # needs pandas[xarray] # { diff --git a/pan3d/ui/preview.py b/pan3d/ui/preview.py index b87b7db..eebdec2 100644 --- a/pan3d/ui/preview.py +++ b/pan3d/ui/preview.py @@ -11,6 +11,7 @@ from pan3d.ui.css import base, preview from pan3d.ui.collapsible import CollapsableSection +from pan3d.xarray.cf.constants import Projection class SummaryToolbar(v3.VCard): @@ -226,6 +227,19 @@ def update_information(self, xr, available_arrays=None): 
"value": f"[{xr[name].values[0]}, {xr[name].values[-1]}]", } ) + elif len(shape) > 1: + attrs.append( + { + "key": "dims", + "value": f'({", ".join(xr[name].dims)})', + } + ) + attrs.append( + { + "key": "shape", + "value": f'({", ".join([str(v) for v in xr[name].shape])})', + } + ) if name in data: icon = "mdi-database" order = 2 @@ -264,6 +278,7 @@ def __init__(self, source, update_rendering): self.state.setdefault("max_time_width", 0) self.state.setdefault("max_time_index_width", 0) self.state.setdefault("dataset_bounds", [0, 1, 0, 1, 0, 1]) + self.state.setdefault("projection_mode", "Spherical") with self.content: v3.VSelect( @@ -528,10 +543,11 @@ def __init__(self, source, update_rendering): size="sm", classes="mx-2", ) - v3.VDivider() # Slice steps - with v3.VTooltip(text="Level Of Details / Slice stepping"): + with v3.VTooltip( + text="Level Of Details / Slice stepping", v_if="axis_names.length" + ): with html.Template(v_slot_activator="{ props }"): with v3.VRow( v_bind="props", @@ -542,7 +558,7 @@ def __init__(self, source, update_rendering): "mdi-stairs", classes="ml-2 text-medium-emphasis", ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[0]"): + with v3.VCol(classes="pa-0", v_if="axis_names.length > 0"): v3.VTextField( v_model_number=("slice_x_step", 1), hide_details=True, @@ -553,7 +569,7 @@ def __init__(self, source, update_rendering): raw_attrs=['min="1"'], type="number", ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[1]"): + with v3.VCol(classes="pa-0", v_if="axis_names.length > 1"): v3.VTextField( v_model_number=("slice_y_step", 1), hide_details=True, @@ -564,7 +580,7 @@ def __init__(self, source, update_rendering): raw_attrs=['min="1"'], type="number", ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[2]"): + with v3.VCol(classes="pa-0", v_if="axis_names.length > 2"): v3.VTextField( v_model_number=("slice_z_step", 1), hide_details=True, @@ -576,63 +592,115 @@ def __init__(self, source, update_rendering): type="number", ) - # Actor scaling 
- with v3.VTooltip(text="Representation scaling"): + # Projection mode + scaling + with v3.VTooltip( + text=( + "projection_mode === 'Spherical' ? `Spherical Projection: scaling=${spherical_scale} bias=${spherical_bias}` : 'Euclidean Projection'", + ) + ): with html.Template(v_slot_activator="{ props }"): - with v3.VRow( + with v3.VCol( v_bind="props", no_gutter=True, - classes="align-center my-0 mx-0 border-b-thin", + classes="align-center pa-0 my-0 mx-0 border-b-thin", ): - v3.VIcon( - "mdi-ruler-square", - classes="ml-2 text-medium-emphasis", + v3.VSelect( + prepend_inner_icon=( + "projection_mode === 'Spherical' ? 'mdi-earth' : 'mdi-earth-box'", + ), + v_model=("projection_mode", "Spherical"), + items=("['Spherical', 'Euclidean']",), + hide_details=True, + density="compact", + flat=True, + variant="solo", + classes="mx-n1", ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[0]"): - v3.VTextField( - v_model=("scale_x", 1), + with v3.VCol( + no_gutter=True, + classes="align-center my-0 mx-0 pa-0", + v_if="projection_mode === 'Spherical'", + ): + v3.VSlider( + prepend_icon="mdi-radius-outline", + v_model=("spherical_bias", 6378137), + min=1, + max=6378137, + step=100, hide_details=True, density="compact", flat=True, variant="solo", - reverse=True, - raw_attrs=[ - 'pattern="^\d*(\.\d)?$"', # noqa: W605 - 'min="0.001"', - 'step="0.1"', - ], - type="number", + classes="mr-4", + end=self.server.controller.view_update, ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[1]"): - v3.VTextField( - v_model=("scale_y", 1), + v3.VSlider( + prepend_icon="mdi-magnify", + v_model=("spherical_scale", 100), + min=1, + max=1000, + step=20, hide_details=True, density="compact", flat=True, variant="solo", - reverse=True, - raw_attrs=[ - 'pattern="^\d*(\.\d)?$"', # noqa: W605 - 'min="0.001"', - 'step="0.1"', - ], - type="number", + classes="mr-4", + end=self.server.controller.view_update, ) - with v3.VCol(classes="pa-0", v_if="axis_names?.[2]"): - v3.VTextField( - 
v_model=("scale_z", 1), - hide_details=True, - density="compact", - flat=True, - variant="solo", - reverse=True, - raw_attrs=[ - 'pattern="^\d*(\.\d)?$"', # noqa: W605 - 'min="0.001"', - 'step="0.1"', - ], - type="number", + with v3.VRow( + no_gutter=True, + classes="align-center my-0 mx-0", + v_else=True, + ): + v3.VIcon( + "mdi-ruler-square", + classes="ml-2 text-medium-emphasis", ) + with v3.VCol(classes="pa-0", v_if="axis_names?.[0]"): + v3.VTextField( + v_model=("scale_x", 1), + hide_details=True, + density="compact", + flat=True, + variant="solo", + reverse=True, + raw_attrs=[ + 'pattern="^\d*(\.\d)?$"', # noqa: W605 + 'min="0.001"', + 'step="0.1"', + ], + type="number", + ) + with v3.VCol(classes="pa-0", v_if="axis_names?.[1]"): + v3.VTextField( + v_model=("scale_y", 1), + hide_details=True, + density="compact", + flat=True, + variant="solo", + reverse=True, + raw_attrs=[ + 'pattern="^\d*(\.\d)?$"', # noqa: W605 + 'min="0.001"', + 'step="0.1"', + ], + type="number", + ) + with v3.VCol(classes="pa-0", v_if="axis_names?.[2]"): + v3.VTextField( + v_model=("scale_z", 1), + hide_details=True, + density="compact", + flat=True, + variant="solo", + reverse=True, + raw_attrs=[ + 'pattern="^\d*(\.\d)?$"', # noqa: W605 + 'min="0.001"', + 'step="0.1"', + ], + type="number", + ) # Time slider with v3.VTooltip( @@ -676,19 +744,29 @@ def update_from_source(self, source=None): self.state.data_arrays_available = source.available_arrays self.state.data_arrays = source.arrays self.state.color_by = None - self.state.axis_names = [source.x, source.y, source.z] + self.state.axis_names = [] self.state.slice_extents = source.slice_extents + self.state.projection_mode = ( + "Spherical" + if source.projection == Projection.SPHERICAL + else "Euclidean" + ) + self.state.spherical_bias = source.vertical_bias + self.state.spherical_scale = source.vertical_scale slices = source.slices - for axis in XYZ: + for idx, name in enumerate(self.state.slice_extents): + axis = XYZ[idx] + 
self.state.axis_names.append(name) + self.state.dirty("axis_names") # default - axis_extent = self.state.slice_extents.get(getattr(source, axis)) + axis_extent = self.state.slice_extents.get(name) self.state[f"slice_{axis}_range"] = axis_extent self.state[f"slice_{axis}_cut"] = 0 self.state[f"slice_{axis}_step"] = 1 self.state[f"slice_{axis}_type"] = "range" # use slice info if available - axis_slice = slices.get(getattr(source, axis)) + axis_slice = slices.get(name) if axis_slice is not None: if isinstance(axis_slice, int): # cut @@ -758,8 +836,7 @@ def on_change(self, slice_t, **_): return slices = {self.source.t: slice_t} - for axis in XYZ: - axis_name = getattr(self.source, axis) + for axis, axis_name in zip(XYZ, self.source.slice_extents.keys()): if axis_name is None: continue @@ -776,6 +853,7 @@ def on_change(self, slice_t, **_): self.source.slices = slices ds = self.source() + self.state.dataset_bounds = ds.bounds self.ctrl.view_reset_clipping_range() @@ -800,7 +878,24 @@ def _on_array_selection(self, data_arrays, **_): elif len(data_arrays) == 0: self.state.color_by = None - self.source.arrays = data_arrays + if set(self.source.arrays) != set(data_arrays): + self.source.arrays = data_arrays + self.update_from_source(self.source) + + @change("spherical_bias", "spherical_scale", "projection_mode") + def _on_projection_change( + self, spherical_bias, spherical_scale, projection_mode, **_ + ): + if projection_mode == "Spherical": + self.source.projection = Projection.SPHERICAL + self.source.vertical_bias = spherical_bias + self.source.vertical_scale = spherical_scale + else: + self.source.projection = Projection.EUCLIDEAN + self.source.vertical_bias = 0 + self.source.vertical_scale = 1 + + self.ctrl.view_reset_camera() class ControlPanel(v3.VCard): diff --git a/pan3d/viewers/preview.py b/pan3d/viewers/preview.py index aa12e6c..83f8dde 100644 --- a/pan3d/viewers/preview.py +++ b/pan3d/viewers/preview.py @@ -25,7 +25,9 @@ from trame.ui.vuetify3 import VAppLayout 
from trame.widgets import vuetify3 as v3 -from pan3d.xarray.algorithm import vtkXArrayRectilinearSource +from pan3d.xarray.cf.reader import vtkXArrayCFSource + +# from pan3d.xarray.algorithm import vtkXArrayRectilinearSource from pan3d.utils.constants import has_gpu from pan3d.utils.convert import update_camera, to_image, to_float @@ -129,7 +131,17 @@ def _setup_vtk(self): self.interactor.GetInteractorStyle().SetCurrentStyleToTrackballCamera() self.lut = vtkLookupTable() - self.source = vtkXArrayRectilinearSource(input=self.xarray) + + self.source = vtkXArrayCFSource(input=self.xarray) + # if "PAN3D_USE_VTK_XARRAY" in os.environ: + # try: + # from pan3d.xarray.vtk import vtkXArraySource + + # self.source = vtkXArraySource(input=self.xarray) + # except ImportError: + # self.source = vtkXArrayRectilinearSource(input=self.xarray) + # else: + # self.source = vtkXArrayRectilinearSource(input=self.xarray) # Need explicit geometry extraction when used with WASM self.geometry = vtkDataSetSurfaceFilter( @@ -278,6 +290,17 @@ def _on_color_by(self, color_by, **__): self.mapper.SetScalarModeToUsePointFieldData() self.mapper.InterpolateScalarsBeforeMappingOn() self.mapper.SetScalarVisibility(1) + elif color_by in ds.cell_data.keys(): + array = ds.cell_data[color_by] + min_value, max_value = array.GetRange() + + self.state.color_min = min_value + self.state.color_max = max_value + + self.mapper.SelectColorArray(color_by) + self.mapper.SetScalarModeToUseCellFieldData() + self.mapper.InterpolateScalarsBeforeMappingOn() + self.mapper.SetScalarVisibility(1) else: self.mapper.SetScalarVisibility(0) self.state.color_min = 0 @@ -299,13 +322,16 @@ def _on_color_preset( self.ctrl.view_update() - @change("scale_x", "scale_y", "scale_z") - def _on_scale_change(self, scale_x, scale_y, scale_z, **_): - self.actor.SetScale( - to_float(scale_x), - to_float(scale_y), - to_float(scale_z), - ) + @change("scale_x", "scale_y", "scale_z", "projection_mode") + def _on_scale_change(self, scale_x, 
scale_y, scale_z, projection_mode, **_): + if projection_mode == "spherical": + self.actor.SetScale(1, 1, 1) + else: + self.actor.SetScale( + to_float(scale_x), + to_float(scale_y), + to_float(scale_z), + ) if self.state.import_pending: return diff --git a/pan3d/xarray/cf/__init__.py b/pan3d/xarray/cf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pan3d/xarray/cf/__main__.py b/pan3d/xarray/cf/__main__.py new file mode 100644 index 0000000..b6008d6 --- /dev/null +++ b/pan3d/xarray/cf/__main__.py @@ -0,0 +1,130 @@ +import sys +from xarray.tutorial import open_dataset +from .utils import XArrayDataSetCFHelper +from pathlib import Path +import xarray as xr +from vtkmodules.vtkIOLegacy import vtkDataSetWriter + +from pan3d.catalogs.xarray import ALL_ENTRIES +from .coords.meta import MetaArrayMapping + + +def main(): + output_name = "test.vtk" + if Path(sys.argv[-1]).exists(): + input_file = Path(sys.argv[-1]).resolve() + output_name = f"{input_file.name}.vtk" + ds = xr.open_dataset(str(input_file)) + else: + output_name = f"{sys.argv[-1]}.vtk" + ds = open_dataset(sys.argv[-1]) + + helper = XArrayDataSetCFHelper(ds) + print("Arrays:", helper.available_arrays) + + helper.array_selection = [helper.available_arrays[0]] + + print("=" * 60) + print("Coordinates:") + print("=" * 60) + for name in ds.coords: + print(helper.get_info(name)) + print("=" * 60) + print("Data:") + print("=" * 60) + for name in ds.data_vars: + print(helper.get_info(name)) + print("=" * 60) + + if helper.x_info: + print(f"{'-'*60}\nX\n{'-'*60}", helper.x_info) + if helper.y_info: + print(f"{'-'*60}\nY\n{'-'*60}", helper.y_info) + if helper.z_info: + print(f"{'-'*60}\nZ\n{'-'*60}", helper.z_info) + if helper.t_info: + print(f"{'-'*60}\nT\n{'-'*60}", helper.t_info) + + mesh = helper.mesh + print(mesh) + + writer = vtkDataSetWriter() + writer.SetInputData(mesh) + writer.SetFileName(output_name) + writer.Write() + + +# 
----------------------------------------------------------------------------- +def main2(): + DATASETS = [item.get("name") for item in ALL_ENTRIES] + FILES = ["/Users/sebastien.jourdain/Downloads/sampleGenGrid3.nc"] + for ds_name in DATASETS: + ds = open_dataset(ds_name) + meta = MetaArrayMapping(ds) + + print("-" * 60) + print(ds_name) + print("-" * 60) + print(meta) + + for f_path in FILES: + input_file = Path(f_path).resolve() + ds = xr.open_dataset(str(input_file)) + meta = MetaArrayMapping(ds) + + print("-" * 60) + print(input_file.name) + print("-" * 60) + print(meta) + + +# ----------------------------------------------------------------------------- +def save_dataset(ds_name, field): + ds = open_dataset(ds_name) + meta = MetaArrayMapping(ds) + + print(meta) + + print("-" * 60) + print("spherical=ON") + salt_spherical = meta.get_mesh(fields=[field], spherical=True) + + print("-" * 60) + print("spherical=OFF") + salt_euclidian = meta.get_mesh(fields=[field], spherical=False) + print("-" * 60) + + writer = vtkDataSetWriter() + + if salt_spherical: + writer.SetInputData(salt_spherical) + writer.SetFileName(f"{ds_name}-{field}-spherical.vtk") + writer.Write() + + if salt_euclidian: + writer.SetInputData(salt_euclidian) + writer.SetFileName(f"{ds_name}-{field}-euclidian.vtk") + writer.Write() + + +def main3(): + write = [ + # ("ROMS_example","salt"), # ok + # ("ROMS_example","zeta"), # ok + # ("air_temperature", "air"), # ok + # ("air_temperature_gradient", "Tair"), # ok + # ("basin_mask", "basin"), # ok + # ("rasm", "Tair"), # not following convention + ("eraint_uvz", "v"), + # ("ersstv5", "sst"), # bounds + ] + + for ds, field in write: + print("#" * 80) + print(ds) + print("#" * 80) + save_dataset(ds, field) + + +if __name__ == "__main__": + main3() diff --git a/pan3d/xarray/cf/constants.py b/pan3d/xarray/cf/constants.py new file mode 100644 index 0000000..0570d48 --- /dev/null +++ b/pan3d/xarray/cf/constants.py @@ -0,0 +1,367 @@ +from enum import Enum +import 
re +import numpy as np +from vtkmodules.vtkCommonDataModel import ( + vtkImageData, + vtkRectilinearGrid, + vtkStructuredGrid, + vtkUnstructuredGrid, +) + +LATITUDE_REGEXP = "degrees?_?n" # degrees_north +LONGITUDE_REGEXP = "degrees?_?e" # degrees_east + + +class DimensionInformation: + def __init__(self, xr_dataset, name): + coord = xr_dataset[name] + + self._xr_dataset = xr_dataset + self.name = name + self.attrs = coord.attrs + self.origin = 0 + self.spacing = 1 + self.has_regular_spacing = True # will be updated later + self.unit = Units.UNDEFINED_UNITS + self.dims = coord.dims + self.has_bounds = False + self.dim_size = -1 + + # unit handling + self.unit = Units.extract(coord) + + # is as direct mapping + if coord.dims == (name,): + # coordinates handling (origin, spacing, has_regular_spacing) + self.dim_size = 1 + self.coords = coord.values + self.origin = float(coord.values[0]) + self.spacing = (float(coord.values[-1]) - self.origin) / (coord.size - 1) + tolerance = 0.01 * self.spacing + + for i in range(coord.size): + expected = self.origin + i * self.spacing + truth = float(coord.values[i]) + if not np.isclose(expected, truth, atol=tolerance): + self.has_regular_spacing = False + break + + # direction + if coord.attrs.get("positive", "") == "down": + self.spacing *= -1 + + # bounds + bounds = coord.attrs.get("bounds") + if bounds: + # TODO ? 
only min/max + self.bounds = None + self.has_bounds = bounds + elif np.issubdtype(coord.dtype, np.number): + self.coords = coord.values + self.bounds = np.zeros(coord.size + 1, dtype=np.double) + self.bounds[0] = float(coord.values[0]) - 0.5 * self.spacing + for i in range(1, coord.size): + center = 0.5 * float(coord.values[i - 1] + coord.values[i]) + self.bounds[i] = center + self.bounds[-1] = float(coord.values[-1]) + 0.5 * self.spacing + else: # Fake coordinates + xr_dataset + # TODO should do a better job + self.origin = 0 + self.spacing = 1 + self.coords = np.linspace( + 0, coord.size - 1, num=coord.size, dtype=np.double + ) + self.bounds = np.linspace( + -0.5, coord.size - 0.5, num=coord.size + 1, dtype=np.double + ) + + def __repr__(self): + return f""" +{self.name}: + - origin : {self.origin} + - spacing : {self.spacing} + - uniform : {self.has_regular_spacing} + - unit : {self.unit} + - dimensions : {self.dims} + - bounds : {self.has_bounds} + - attributes : {self.attrs} + """ + + +class Projection(Enum): + SPHERICAL = "Spherical" + EUCLIDEAN = "Euclidean" + + +class Scale(Enum): + da = (1e1, {"deca", "deka"}) + h = (1e2, {"hecto"}) + k = (1e3, {"kilo"}) + M = (1e6, {"mega"}) + G = (1e9, {"giga"}) + T = (1e12, {"tera"}) + P = (1e15, {"peta"}) + E = (1e18, {"exa"}) + Z = (1e21, {"zetta"}) + Y = (1e24, {"yotta"}) + d = (1e-1, {"deci"}) + c = (1e-2, {"centi"}) + m = (1e-3, {"milli"}) + u = (1e-6, {"micro"}) + n = (1e-9, {"nano"}) + p = (1e-12, {"pico"}) + f = (1e-15, {"femto"}) + a = (1e-18, {"atto"}) + z = (1e-21, {"zepto"}) + y = (1e-24, {"yocto"}) + IDENTITY = (1, {}) + + def __float__(self): + return self.value[0] + + @classmethod + def from_name(cls, name): + for scale in cls: + if name in scale.value[1]: + return scale + return cls.IDENTITY + + +class Units(Enum): + UNDEFINED_UNITS = "undefined" + TIME_UNITS = "time" + LATITUDE_UNITS = "latitude" + LONGITUDE_UNITS = "longitude" + VERTICAL_UNITS = "vertical" + NUMBER_OF_UNITS = "number" + + @classmethod 
+ def extract(cls, xr_array): + # 1. check "units" + units = xr_array.attrs.get("units", "").lower() + if units: + # time + if units in KNOWN_TIME_UNITS: + return cls.TIME_UNITS + for time_str in KNOWN_TIME_CONTAINS: + if time_str in units: + return cls.TIME_UNITS + + # latitude + if re.search(LATITUDE_REGEXP, units) is not None: + return cls.LATITUDE_UNITS + + # longitude + if re.search(LONGITUDE_REGEXP, units) is not None: + return cls.LONGITUDE_UNITS + + # 2. check "axis" + axis = xr_array.attrs.get("axis", "").lower() + if axis: + if axis == "x": + return cls.LONGITUDE_UNITS + if axis == "y": + return cls.LATITUDE_UNITS + if axis == "z": + return cls.VERTICAL_UNITS + if axis == "t": + return cls.TIME_UNITS + + # 3. use field name + field_name = xr_array.name.lower() + if "time" in field_name: + return cls.TIME_UNITS + if "lat" in field_name: + return cls.LATITUDE_UNITS + if "lon" in field_name: + return cls.LONGITUDE_UNITS + + # 4. data type + if np.issubdtype(xr_array.dtype, np.datetime64): + print("time because of type datetime64") + return cls.TIME_UNITS + if np.issubdtype(xr_array.dtype, np.dtype("O")): + print("time because of type O") + return cls.TIME_UNITS + + # Don't know + return cls.UNDEFINED_UNITS + + +MESH_3D_UNITS = { + Units.LATITUDE_UNITS, + Units.LONGITUDE_UNITS, + Units.VERTICAL_UNITS, +} +MESH_2D_UNITS = { + Units.LATITUDE_UNITS, + Units.LONGITUDE_UNITS, +} + + +class MeshTypes(Enum): + VTK_IMAGE_DATA = vtkImageData + VTK_RECTILINEAR_GRID = vtkRectilinearGrid + VTK_STRUCTURED_GRID = vtkStructuredGrid + VTK_UNSTRUCTURED_GRID = vtkUnstructuredGrid + + @classmethod + def from_coord_type(cls, coord_type): + if coord_type == CoordinateTypes.UNIFORM_RECTILINEAR: + return cls.VTK_IMAGE_DATA + if coord_type == CoordinateTypes.NONUNIFORM_RECTILINEAR: + return cls.VTK_RECTILINEAR_GRID + if coord_type in { + CoordinateTypes.REGULAR_SPHERICAL, + CoordinateTypes.EUCLIDEAN_2D, + CoordinateTypes.SPHERICAL_2D, + CoordinateTypes.EUCLIDEAN_4SIDED_CELLS, + 
CoordinateTypes.SPHERICAL_4SIDED_CELLS, + }: + return cls.VTK_STRUCTURED_GRID + if coord_type in { + CoordinateTypes.EUCLIDEAN_PSIDED_CELLS, + CoordinateTypes.SPHERICAL_PSIDED_CELLS, + }: + return cls.VTK_STRUCTURED_GRID + + msg = f"Don't have a matching mesh for {coord_type}" + raise ValueError(msg) + + def new(self): + return self.value() + + +class CoordinateTypes(Enum): + UNIFORM_RECTILINEAR = "uniform rectilinear" + NONUNIFORM_RECTILINEAR = "non-uniform rectilinear" + REGULAR_SPHERICAL = "regular spherical" + EUCLIDEAN_2D = "2d euclidean" + SPHERICAL_2D = "2d spherical" + EUCLIDEAN_4SIDED_CELLS = "euclidean 4 sided cells" + SPHERICAL_4SIDED_CELLS = "spherical 4 sided cells" + EUCLIDEAN_PSIDED_CELLS = "euclidean p-sided cells" + SPHERICAL_PSIDED_CELLS = "spherical p-sided cells" + + @classmethod + def get_coordinate_type(cls, xr_dataset, field_name, use_spherical=False): + print(f"{use_spherical=}") + # Remove time axis + dims = [ + array_name + for array_name in xr_dataset[field_name].dims + if Units.extract(xr_dataset[array_name]) != Units.TIME_UNITS + ] + coords = xr_dataset[field_name].coords + cells_unstructured = len(coords) != len(dims) and len(dims) == 1 + has_bounds_count = 0 + + # Check bounds + for coord_name in coords: + if Units.extract(xr_dataset[coord_name]) in MESH_2D_UNITS: + if xr_dataset[coord_name].attrs.get("bounds"): + has_bounds_count += 1 + + if cells_unstructured: + return ( + cls.SPHERICAL_PSIDED_CELLS + if use_spherical + else cls.EUCLIDEAN_PSIDED_CELLS + ) + + if has_bounds_count == 2: + return cls.SPHERICAL_2D if use_spherical else cls.EUCLIDEAN_2D + + name_to_unit = { + dim_name: Units.extract(xr_dataset[dim_name]) for dim_name in dims + } + dim_units = set(name_to_unit.values()) + if use_spherical: + if dim_units >= MESH_2D_UNITS or dim_units >= MESH_3D_UNITS: + return cls.REGULAR_SPHERICAL + + # Check irregular spacing + for name, unit in name_to_unit.items(): + if unit in MESH_3D_UNITS: + info = DimensionInformation(xr_dataset, 
name) + if not info.has_regular_spacing: + return cls.NONUNIFORM_RECTILINEAR + + return cls.UNIFORM_RECTILINEAR + + @property + def use_point_data(self): + # point_data for only the following types + return self in { + CoordinateTypes.UNIFORM_RECTILINEAR, + CoordinateTypes.NONUNIFORM_RECTILINEAR, + CoordinateTypes.EUCLIDEAN_2D, + CoordinateTypes.SPHERICAL_2D, + } + + +KNOWN_TIME_CONTAINS = { + " since ", + " after ", +} +KNOWN_TIME_UNITS = { + "second", + "seconds", + "day", + "days", + "hour", + "hours", + "minute", + "minutes", + "s", + "sec", + "secs", + "shake", + "shakes", + "sidereal_day", + "sidereal_days", + "sidereal_hour", + "sidereal_hours", + "sidereal_minute", + "sidereal_minutes", + "sidereal_second", + "sidereal_seconds", + "sidereal_year", + "sidereal_years", + "tropical_year", + "tropical_years", + "lunar_month", + "lunar_months", + "common_year", + "common_years", + "leap_year", + "leap_years", + "Julian_year", + "Julian_years", + "Gregorian_year", + "Gregorian_years", + "sidereal_month", + "sidereal_months", + "tropical_month", + "tropical_months", + "d", + "min", + "mins", + "hrs", + "h", + "fortnight", + "fortnights", + "week", + "jiffy", + "jiffies", + "year", + "years", + "yr", + "yrs", + "a", + "eon", + "eons", + "month", + "months", +} diff --git a/pan3d/xarray/cf/coords/__init__.py b/pan3d/xarray/cf/coords/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pan3d/xarray/cf/coords/convert.py b/pan3d/xarray/cf/coords/convert.py new file mode 100644 index 0000000..80eb961 --- /dev/null +++ b/pan3d/xarray/cf/coords/convert.py @@ -0,0 +1,92 @@ +import math +import numpy as np + + +def to_isel(slices_info, *array_names): + slices = {} + for name in array_names: + if name is None: + continue + + info = slices_info.get(name) + if info is None: + continue + if isinstance(info, int): + slices[name] = info + else: + start, stop, step = info + stop -= (stop - start) % step + slices[name] = slice(start, stop, step) + + return slices if 
slices else None + + +def slice_array(array_name, dataset, slice_info): + if array_name is None: + return np.zeros(1, dtype=np.float32) + array = dataset[array_name] + dims = array.dims + return array.isel(to_isel(slice_info, *dims)).values + + +def extract_uniform_info(array): + origin = float(array[0]) + spacing = (float(array[-1]) - origin) / (array.size - 1) + tolerance = 0.01 * spacing + + for i in range(array.size): + expected = origin + i * spacing + truth = float(array[i]) + if not np.isclose(expected, truth, atol=tolerance): + return None + + return (origin, spacing, array.size) + + +def is_uniform(array): + return extract_uniform_info(array) is not None + + +def cell_center_to_point(in_array): + if in_array.size == 1: + return in_array + + uniform_data = extract_uniform_info(in_array) + if uniform_data is not None: + origin, spacing, size = uniform_data + return np.linspace( + start=origin - spacing * 0.5, + stop=origin - spacing * 0.5 + size * spacing, + num=size + 1, + endpoint=True, + dtype=np.double, + ) + + # generate fake coords + n_cells = in_array.size + n_points = n_cells + 1 + out_array = np.zeros(n_points, dtype=np.double) + for i in range(1, n_cells): + out_array[i] = 0.5 * (in_array[i - 1] + in_array[i]) + + out_array[0] = out_array[1] - 2 * (out_array[1] - in_array[0]) + out_array[-1] = out_array[-2] + 2 * (in_array[-1] - out_array[-2]) + + return out_array + + +def point_insert(vtk_point, spherical, longitude, latitude, vertical): + if spherical: + longitude = math.pi * longitude / 180 + latitude = math.pi * latitude / 180 + vtk_point.InsertNextPoint( + vertical * math.cos(longitude) * math.cos(latitude), + vertical * math.sin(longitude) * math.cos(latitude), + vertical * math.sin(latitude), + ) + else: + vtk_point.InsertNextPoint( + longitude, + latitude, + vertical, + ) diff --git a/pan3d/xarray/cf/coords/index_mapping.py b/pan3d/xarray/cf/coords/index_mapping.py new file mode 100644 index 0000000..9bc9df6 --- /dev/null +++ 
b/pan3d/xarray/cf/coords/index_mapping.py @@ -0,0 +1,95 @@ +def get_formula(metadata, kji_dims): + lon_mapper = get_value_mapper(metadata.xr_dataset, kji_dims, metadata.longitude) + lat_mapper = get_value_mapper(metadata.xr_dataset, kji_dims, metadata.latitude) + return CoordMapper(lon_mapper, lat_mapper) + + +class CoordMapper: + def __init__(self, lon_mapper, lat_mapper): + self.lon = lon_mapper + self.lat = lat_mapper + + def __call__(self, i=0, j=0, k=0): + return self.lon(i=i, j=j, k=k), self.lat(i=i, j=j, k=k) + + +def get_value_mapper(xr_dataset, in_dims, out_name): + out_slice_size = len(xr_dataset[out_name].dims) + if out_slice_size == 3: + return IndexMapper3(xr_dataset, in_dims, out_name) + if out_slice_size == 2: + return IndexMapper2(xr_dataset, in_dims, out_name) + if out_slice_size == 1: + return IndexMapper1(xr_dataset, in_dims, out_name) + + msg = f"No IndexMapper for dimensions {xr_dataset[out_name].dims}" + raise ValueError(msg) + + +class IndexMapper: + def __init__(self, xr_dataset, in_dims, out_name): + self.out_array = xr_dataset[out_name].values + + name_to_ijk = {in_dims[-(i + 1)]: "ijk"[i] for i in range(len(in_dims))} + out_dims = xr_dataset[out_name].dims + map_method_name = "".join([name_to_ijk[name] for name in out_dims]) + # print(out_name, "=>", map_method_name) + + setattr(self, "fn", getattr(self, map_method_name)) + + def __call__(self, **kwargs): + return self.fn(**kwargs) + + +class IndexMapper3(IndexMapper): + + def ijk(self, i=0, j=0, k=0, **_): + return self.out_array[i, j, k] + + def ikj(self, i=0, j=0, k=0, **_): + return self.out_array[i, k, j] + + def jki(self, i=0, j=0, k=0, **_): + return self.out_array[j, k, i] + + def kij(self, i=0, j=0, k=0, **_): + return self.out_array[k, i, j] + + def jik(self, i=0, j=0, k=0, **_): + return self.out_array[j, i, k] + + def kji(self, i=0, j=0, k=0, **_): + return self.out_array[k, j, i] + + +class IndexMapper2(IndexMapper): + + def ij(self, i=0, j=0, k=0, **_): + return 
self.out_array[i, j] + + def ji(self, i=0, j=0, k=0, **_): + return self.out_array[j, i] + + def ik(self, i=0, j=0, k=0, **_): + return self.out_array[i, k] + + def ki(self, i=0, j=0, k=0, **_): + return self.out_array[k, i] + + def jk(self, i=0, j=0, k=0, **_): + return self.out_array[j, k] + + def kj(self, i=0, j=0, k=0, **_): + return self.out_array[k, j] + + +class IndexMapper1(IndexMapper): + + def i(self, i=0, j=0, k=0, **_): + return self.out_array[i] + + def j(self, i=0, j=0, k=0, **_): + return self.out_array[j] + + def k(self, i=0, j=0, k=0, **_): + return self.out_array[k] diff --git a/pan3d/xarray/cf/coords/meta.py b/pan3d/xarray/cf/coords/meta.py new file mode 100644 index 0000000..4151f3a --- /dev/null +++ b/pan3d/xarray/cf/coords/meta.py @@ -0,0 +1,572 @@ +from enum import Enum +import numpy as np +from pan3d.xarray.cf import mesh +from pan3d.xarray.cf.coords.convert import is_uniform, slice_array +from pan3d.xarray.cf.constants import Projection +from vtkmodules.vtkCommonDataModel import ( + vtkImageData, + vtkRectilinearGrid, + vtkStructuredGrid, + vtkUnstructuredGrid, +) + +PRESSURE_UNITS = { + "bar", + "bar", + "millibar", + "decibar", + "atmosphere", + "atm", + "pascal", + "Pa", + "hPa", +} +LENGTH_UNITS = { + "meter", + "metre", + "m", + "kilometer", + "km", +} +COORDINATES_DETECTION = { + "longitude": { + "units": { + "degrees_east", + "degree_east", + "degree_e", + "degrees_e", + "degreee", # codespell:ignore + "degreese", + }, + "axis": {"X"}, + "std_name": { + "longitude", + }, + }, + "latitude": { + "units": { + "degrees_north", + "degree_north", + "degree_n", + "degrees_n", + "degreen", + "degreesn", + }, + "axis": {"Y"}, + "std_name": { + "latitude", + }, + }, + "vertical": { + "units": {}, + "axis": {"Z"}, + "std_name": { + "depth", + "level", + }, + "positive": { + "up", + "down", + }, + }, + "time": { + "units": { + "since", + "second", + "seconds", + "day", + "days", + "hour", + "hours", + "minute", + "minutes", + "s", + "sec", + 
"secs", + "shake", + "shakes", + "sidereal_day", + "sidereal_days", + "sidereal_hour", + "sidereal_hours", + "sidereal_minute", + "sidereal_minutes", + "sidereal_second", + "sidereal_seconds", + "sidereal_year", + "sidereal_years", + "tropical_year", + "tropical_years", + "lunar_month", + "lunar_months", + "common_year", + "common_years", + "leap_year", + "leap_years", + "Julian_year", + "Julian_years", + "Gregorian_year", + "Gregorian_years", + "sidereal_month", + "sidereal_months", + "tropical_month", + "tropical_months", + "d", + "min", + "mins", + "hrs", + "h", + "fortnight", + "fortnights", + "week", + "jiffy", + "jiffies", + "year", + "years", + "yr", + "yrs", + "a", + "eon", + "eons", + "month", + "months", + }, + "axis": {"T"}, + "std_name": { + "time", + }, + "calendar": { + "standard", + "utc", + "proleptic_gregorian", + "julian", + "tai", + "noleap", + "365_day", + "all_leap", + "366_day", + "360_day", + "none", + }, + }, +} + + +class CoordinateType(Enum): + LONGITUDE = "longitude" + LATITUDE = "latitude" + VERTICAL = "vertical" + TIME = "time" + UNKNOWN = "unknown" + + @classmethod + def from_array(cls, xr_array): + axis = xr_array.attrs.get("axis", "") + units = xr_array.attrs.get("units", "").lower() + std_name = xr_array.attrs.get("standard_name", "") + positive = xr_array.attrs.get("positive", "") + calendar = xr_array.attrs.get("calendar", "") + + # axis provide a direct mapping if available + if axis: + if axis == "X": + return cls.LONGITUDE + if axis == "Y": + return cls.LATITUDE + if axis == "Z": + return cls.VERTICAL + if axis == "T": + return cls.TIME + + # units + if units: + if units in COORDINATES_DETECTION["longitude"]["units"]: + return cls.LONGITUDE + if units in COORDINATES_DETECTION["latitude"]["units"]: + return cls.LATITUDE + + # Time unit check + t_unit = COORDINATES_DETECTION["time"]["units"] + unit_token = set(units.split(" ")) + if len(unit_token & t_unit) == 2: + return cls.TIME + + # unique attribute + if positive: + return 
cls.VERTICAL + if calendar: + return cls.TIME + + # std name + if std_name in COORDINATES_DETECTION["longitude"]["std_name"]: + return cls.LONGITUDE + if std_name in COORDINATES_DETECTION["latitude"]["std_name"]: + return cls.LATITUDE + if std_name in COORDINATES_DETECTION["time"]["std_name"]: + return cls.TIME + if std_name in COORDINATES_DETECTION["vertical"]["std_name"]: + return cls.VERTICAL + + # based on data type + if np.issubdtype(xr_array.dtype, np.datetime64): + return cls.TIME + if np.issubdtype(xr_array.dtype, np.dtype("O")): + return cls.TIME + + return cls.UNKNOWN + + @classmethod + def can_be_vertical(cls, xr_array): + units = xr_array.attrs.get("units", "").lower() + + if xr_array.dims != (xr_array.name,): + return False + + # vertical tends to be harder to detect + for unit in PRESSURE_UNITS: + if unit in units: + return True + + for unit in LENGTH_UNITS: + if unit in units: + return True + + return False + + @classmethod + def can_be_time(cls, xr_array): + if xr_array.dims != (xr_array.name,): + return False + + return xr_array.name in COORDINATES_DETECTION["time"]["units"] + + +class MetaArrayMapping: + def __init__(self, xr_dataset): + self.xr_dataset = None + self.data_arrays = {} + self.conventions = None + self.longitude = None + self.latitude = None + self.vertical = None + self.time = None + self.valid = False + self.vertical_bias = 6378137 + self.vertical_scale = 100 + + self.update(xr_dataset) + + def update(self, xr_dataset): + self.xr_dataset = xr_dataset + self.valid = False + self.data_arrays = {} + self.conventions = None + self.longitude = None + self.latitude = None + self.vertical = None + self.time = None + + if xr_dataset is None: + return + + self.conventions = ( + xr_dataset.Conventions if hasattr(xr_dataset, "Conventions") else None + ) + if self.conventions is not None: + for convention in {"COARDS", "CF-1"}: + if convention in self.conventions: + self.valid = True + break + + # start extracting coordinates from data 
dimensions + for array_name in xr_dataset: + # skip bounds array as data array + if "bnd" in array_name or "bound" in array_name: + continue + + dims = xr_dataset[array_name].dims + if dims not in self.data_arrays: + self.data_arrays[dims] = [array_name] + for coord_name in dims: + coord_type = CoordinateType.from_array(xr_dataset[coord_name]) + if coord_type != CoordinateType.UNKNOWN: + setattr(self, coord_type.value, coord_name) + + # Extended binding if not found + if self.vertical is None and CoordinateType.can_be_vertical( + xr_dataset[array_name] + ): + self.vertical = array_name + if self.time is None and CoordinateType.can_be_time( + xr_dataset[coord_name] + ): + self.time = coord_name + else: + self.data_arrays[dims].append(array_name) + + # inspect coordinates if not already filled + if any( + coord is None for coord in [self.longitude, self.latitude, self.vertical] + ): + for coord_name in xr_dataset.coords: + coord_type = CoordinateType.from_array(xr_dataset[coord_name]) + if ( + coord_type != CoordinateType.UNKNOWN + and getattr(self, coord_type.value) is None + ): + setattr(self, coord_type.value, coord_name) + + # Extended binding if not found + if self.vertical is None and CoordinateType.can_be_vertical( + xr_dataset[coord_name] + ): + self.vertical = coord_name + + def __repr__(self): + data_lines = [] + for dims, array_names in self.data_arrays.items(): + data_lines.append(f" - {dims}:") + for name in array_names: + data_lines.append(f" - {name}") + data_str = "\n".join(data_lines) + return f"""Conventions: {self.conventions} {'✅' if self.valid else '❌'} +Coordinates: + - longitude : {self.longitude} + - latitude : {self.latitude} + - vertical : {self.vertical} + - time : {self.time} +Computed: + - has_bound : {self.coords_has_bounds} + - uniform (2D) : {self.uniform_lat_lon} + - uniform (all) : {self.uniform_spacing} + - coords 1d : {self.coords_1d} +Data: +{data_str} +""" + + @property + def coords_has_bounds(self): + if self.latitude is None or 
self.longitude is None: + return None + + lon_bnd = self.xr_dataset[self.longitude].attrs.get("bounds") + lat_bnd = self.xr_dataset[self.latitude].attrs.get("bounds") + return lon_bnd is not None and lat_bnd is not None + + @property + def uniform_lat_lon(self): + return ( + self.coords_1d + and is_uniform(self.xr_dataset[self.longitude].values) + and is_uniform(self.xr_dataset[self.latitude].values) + ) + + @property + def uniform_spacing(self): + uniform = self.uniform_lat_lon + + if self.vertical is not None and uniform: + uniform = len(self.xr_dataset[self.vertical].dims) == 1 and is_uniform( + self.xr_dataset[self.vertical].values + ) + + return uniform + + @property + def coords_1d(self): + vertical_ok = True + if self.vertical is not None: + vertical_ok = len( + self.xr_dataset[self.vertical].dims + ) == 1 and self.xr_dataset[self.vertical].dims == (self.vertical,) + + return ( + vertical_ok + and self.longitude is not None + and len(self.xr_dataset[self.longitude].dims) == 1 + and self.xr_dataset[self.longitude].dims == (self.longitude,) + and self.latitude is not None + and len(self.xr_dataset[self.latitude].dims) == 1 + and self.xr_dataset[self.latitude].dims == (self.latitude,) + ) + + def use_coords(self, dims): + if self.longitude not in dims: + return False + if self.latitude not in dims: + return False + if self.vertical is not None and self.vertical not in dims: + return False + return True + + def compatible_fields(self, fields=None): + if not fields: + return [] + data_dims = self.xr_dataset[fields[0]].dims + return [n for n in fields if self.xr_dataset[n].dims == data_dims] + + def dimensions(self, field): + return self.xr_dataset[field].dims + + def timeless_dimensions(self, field): + dims = self.xr_dataset[field].dims + return dims[1:] if dims[0] == self.time else dims + + def dims_extent(self, dimensions, slices=None): + extent = [0, 0, 0, 0, 0, 0] + + if slices is None: + slices = {} + + for idx in range(len(dimensions)): + name = 
dimensions[-(1 + idx)] + array = self.xr_dataset[name] + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # And extent include both index so (len-1) + if name in slices: + slice_info = slices[name] + if isinstance(slice_info, int): + # size of 1 + pass + else: + size = int((slice_info[1] - slice_info[0]) / slice_info[2]) + extent[idx * 2 + 1] = size - 1 + else: + extent[idx * 2 + 1] = array.size - 1 + + return extent + + def field_extent(self, field, slices=None): + dimensions = self.timeless_dimensions(field) + return self.dims_extent(dimensions, slices) + + def get_vtk_mesh_type(self, projection, fields=None): + fields = self.compatible_fields(fields) + + if self.longitude is None or self.latitude is None or not fields: + # default empty mesh + return vtkImageData() + + # unstructured + timeless_dims = self.timeless_dimensions(fields[0]) + if len(timeless_dims) == 1: + return vtkUnstructuredGrid() + + # structured + if ( + self.coords_has_bounds + or projection == Projection.SPHERICAL + or not self.coords_1d + ): + return vtkStructuredGrid() + + # rectilinear + if not self.uniform_spacing: + return vtkRectilinearGrid() + + # imagedata + return vtkImageData() + + def get_vtk_whole_extent(self, projection, fields=None, slices=None): + if self.longitude is None or self.latitude is None or not fields: + return [ + 0, + 0, + 0, + 0, + 0, + 0, + ] + + mesh_type = self.get_vtk_mesh_type(projection, fields) + fields = self.compatible_fields(fields) + dimensions = self.timeless_dimensions(fields[0]) + extent = self.dims_extent(dimensions, slices) + + if mesh_type.IsA("vtkStructuredGrid") and not ( + self.uniform_lat_lon and self.use_coords(dimensions) + ): + # point data + return extent + + # cell data, need to +1 on the extent + for i in range(3): + if i < 2: + # minimum 2D cells + extent[i * 2 + 1] += 1 + elif extent[i * 2 + 1] > 0: + # Only add cell in 3rd dimension if content available + extent[i * 2 + 1] += 1 + + # print(f"Whole extent: 
{extent}") + + return extent + + def get_vtk_mesh(self, time_index=0, projection=None, fields=None, slices=None): + if slices is None: + slices = {} + + vtk_mesh, data_location = None, None + if self.xr_dataset is None or not fields: + return vtk_mesh + + # resolve projection + if projection is None: + projection = Projection.SPHERICAL + spherical_proj = projection == Projection.SPHERICAL + + # ensure similar dimension across array names + fields = self.compatible_fields(fields) + data_dims_no_time = self.timeless_dimensions(fields[0]) + + # No mesh if no lon/lat + if self.longitude is None or self.latitude is None: + return vtk_mesh + + # Unstructured + if len(data_dims_no_time) == 1: + vtk_mesh, data_location = mesh.unstructured.generate_mesh( + self, data_dims_no_time, time_index, spherical_proj, slices + ) + + # Structured + if vtk_mesh is None and ( + self.coords_has_bounds or spherical_proj or not self.coords_1d + ): + vtk_mesh, data_location = mesh.structured.generate_mesh( + self, data_dims_no_time, time_index, spherical_proj, slices + ) + + # This should only happen if we don't want spherical_proj + if vtk_mesh is None: + assert not spherical_proj + + # Rectilinear + if vtk_mesh is None and not self.uniform_spacing: + vtk_mesh, data_location = mesh.rectilinear.generate_mesh( + self, data_dims_no_time, time_index, slices + ) + + # Uniform + if vtk_mesh is None: + vtk_mesh, data_location = mesh.uniform.generate_mesh( + self, data_dims_no_time, time_index, slices + ) + + # Add fields + if vtk_mesh: + container = getattr(vtk_mesh, data_location) + for field_name in fields: + field = slice_array(field_name, self.xr_dataset, slices) + # # FIXME to select slices + # print(f"fields: {slices=}") + # field = ( + # self.xr_dataset[field_name][time_index].values + # if self.time + # else self.xr_dataset[field_name].values + # ) + container[field_name] = field.ravel() + else: + print(" !!! 
No mesh for data") + + return vtk_mesh diff --git a/pan3d/xarray/cf/coords/parametric_vertical.py b/pan3d/xarray/cf/coords/parametric_vertical.py new file mode 100644 index 0000000..75367aa --- /dev/null +++ b/pan3d/xarray/cf/coords/parametric_vertical.py @@ -0,0 +1,332 @@ +""" +Based on Parametric Vertical Coordinates appendix from CF-1.12 spec + +https://cfconventions.org/Data/cf-conventions/cf-conventions-1.12/cf-conventions.html#parametric-v-coord +""" + +import math +import sys +import inspect + +CONVENTION_BASE_URL = "https://cfconventions.org/Data/cf-conventions/cf-conventions-1.12/cf-conventions.html" + + +# ----------------------------------------------------------------------------- +# Factory method +# ----------------------------------------------------------------------------- +def get_formula(xr_dataset, name, bias=0, scale=1): + # No z array => no formula, just a constant + if name is None: + return ConstantFormulaAdapter(bias) + + # Let's extract formula + array_attributes = xr_dataset[name].attrs + std_name = array_attributes.get("standard_name") + formula_terms = array_attributes.get("formula_terms") + + formula_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass) + for klass_name, klass in formula_classes: + if std_name == klass.name: + return FormulaAdapter( + klass(xr_dataset, **extract_formula_terms(formula_terms)), + bias=bias, + scale=scale, + ) + + # Fallback to constant + return ConstantFormulaAdapter(bias) + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- +def extract_formula_terms(formula_terms: str): + tokens = formula_terms.split(" ") + if len(tokens) % 2 != 0: + msg = f"Invalid key/value pairing: {tokens}" + raise ValueError(msg) + + key_mapping = {} + nb_keys = int(len(tokens) / 2) + for i in range(nb_keys): + k = tokens[i * 2][:-1] + v = tokens[i * 2 + 1] + key_mapping[k] = v + + return 
key_mapping + + +# ----------------------------------------------------------------------------- +class AbstractFormula: + name = "__abstract__" + + def __init__(self, xr_dataset, **name_mapping): + for k, v in name_mapping.items(): + setattr(self, k, xr_dataset[v].values) + + self.select_formula(name_mapping) + + def select_formula(self, name_mapping): + pass + + +class FormulaAdapter: + name = "__internal__" + + def __init__(self, formula, bias=0, scale=1): + self._fn = formula + self._bias = bias + self._scale = scale + + def __call__(self, n=0, k=0, j=0, i=0): + return self._bias + self._scale * self._fn(n=n, k=k, j=j, i=i) + + +class ConstantFormulaAdapter: + name = "__internal__" + + def __init__(self, const_value=0): + self._value = const_value + + def __call__(self, **_): + return self._value + + +# ----------------------------------------------------------------------------- +# Atmosphere natural log pressure coordinate +# ----------------------------------------------------------------------------- +class AtmosphereNaturalLogPressureCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#atmosphere-natural-log-pressure-coordinate" + name = "atmosphere_ln_pressure_coordinate" + name_computed = "air_pressure" + std_p0 = "reference_air_pressure_for_atmosphere_vertical_coordinate" + + def __len__(self): + return 1 + + def __call__(self, k, **_): + return self.p0 * math.exp(-self.lev[k]) + + +# ----------------------------------------------------------------------------- +# Atmosphere sigma coordinate +# ----------------------------------------------------------------------------- +class AtmosphereSigmaCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_atmosphere_sigma_coordinate" + name = "atmosphere_sigma_coordinate" + name_computed = "air_pressure" + std_ptop = "air_pressure_at_top_of_atmosphere_model" + std_ps = "surface_air_pressure" + + def __len__(self): + return 4 + + def __call__(self, n, k, j, i): + return self.ptop + self.sigma[k] * 
(self.ps[n, j, i] - self.ptop) + + +# ----------------------------------------------------------------------------- +# Atmosphere hybrid sigma pressure coordinate +# ----------------------------------------------------------------------------- +class AtmosphereHybridSigmaPressureCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_atmosphere_hybrid_sigma_pressure_coordinate" + name = "atmosphere_hybrid_sigma_pressure_coordinate" + name_computed = "air_pressure" + std_p0 = "reference_air_pressure_for_atmosphere_vertical_coordinate" + std_ps = "surface_air_pressure" + + def select_formula(self, key_mapping): + if "p0" in key_mapping: + setattr(self, "__call__", self._a_p0) + else: + setattr(self, "__call__", self._ap) + + def _a_p0(self, n, k, j, i): + return self.a[k] * self.p0 + self.b[k] * self.ps[n, j, i] + + def _ap(self, n, k, j, i): + return self.ap[k] + self.b[k] * self.ps[n, j, i] + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Atmosphere hybrid height coordinate +# ----------------------------------------------------------------------------- +class AtmosphereHybridHeightCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#atmosphere-hybrid-height-coordinate" + name = "atmosphere_hybrid_height_coordinate" + name_computed = ("altitude", "height_above_geopotential_datum") + std_orog = ("surface_altitude", "surface_height_above_geopotential_datum") + + def __call__(self, n, k, j, i): + return self.a[k] + self.b[k] * self.orog[n, j, i] + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Atmosphere smooth level vertical (SLEVE) coordinate +# ----------------------------------------------------------------------------- +class AtmosphereSmoothLevelVerticalCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_atmosphere_smooth_level_vertical_sleve_coordinate" + name = "atmosphere_sleve_coordinate" + 
name_computed = ("altitude", "height_above_geopotential_datum")
+    std_ztop = (
+        "altitude_at_top_of_atmosphere_model",
+        "height_above_geopotential_datum_at_top_of_atmosphere_model",
+    )
+
+    def __call__(self, n, k, j, i):
+        return (
+            self.a[k] * self.ztop
+            + self.b1[k] * self.zsurf1[n, j, i]
+            + self.b2[k] * self.zsurf2[n, j, i]
+        )
+
+    def __len__(self):
+        return 4
+
+
+# -----------------------------------------------------------------------------
+# Ocean sigma coordinate
+# -----------------------------------------------------------------------------
+class OceanSigmaCoordinate(AbstractFormula):
+    url = f"{CONVENTION_BASE_URL}#_ocean_sigma_coordinate"
+    name = "ocean_sigma_coordinate"
+
+    def __call__(self, n, k, j, i):
+        return self.eta[n, j, i] + self.sigma[k] * (
+            self.depth[j, i] + self.eta[n, j, i]
+        )
+
+    def __len__(self):
+        return 4
+
+
+# -----------------------------------------------------------------------------
+# Ocean s-coordinate
+# -----------------------------------------------------------------------------
+class OceanSCoordinate(AbstractFormula):
+    url = f"{CONVENTION_BASE_URL}#_ocean_s_coordinate"
+    name = "ocean_s_coordinate"
+
+    def __call__(self, n, k, j, i):
+        c_k = (1 - self.b) * math.sinh(self.a * self.s[k]) / math.sinh(
+            self.a
+        ) + self.b * (
+            math.tanh(self.a * (self.s[k] + 0.5)) / (2 * math.tanh(0.5 * self.a)) - 0.5
+        )
+        return (
+            self.eta[n, j, i] * (1 + self.s[k])
+            + self.depth_c * self.s[k]
+            + (self.depth[j, i] - self.depth_c) * c_k
+        )
+
+    def __len__(self):
+        return 4
+
+
+# -----------------------------------------------------------------------------
+# Ocean s-coordinate, generic form 1
+# -----------------------------------------------------------------------------
+class OceanSCoordinateGenericForm1(AbstractFormula):
+    url = f"{CONVENTION_BASE_URL}#_ocean_s_coordinate_generic_form_1"
+    name = "ocean_s_coordinate_g1"
+
+    def __call__(self, n, k, j, i):
+        s_kji = self.depth_c * self.s[k] + (self.depth[j, i] - 
self.depth_c) * self.C[k]
+        return s_kji + self.eta[n, j, i] * (1 + s_kji / self.depth[j, i])
+
+    def __len__(self):
+        return 4
+
+
+# -----------------------------------------------------------------------------
+# Ocean s-coordinate, generic form 2
+# -----------------------------------------------------------------------------
+class OceanSCoordinateGenericForm2(AbstractFormula):
+    url = f"{CONVENTION_BASE_URL}#_ocean_s_coordinate_generic_form_2"
+    name = "ocean_s_coordinate_g2"
+    std_names = {
+        "altitude": {
+            "zlev": "altitude",
+            "eta": "sea_surface_height_above_geoid",
+            "depth": "sea_floor_depth_below_geoid",
+        },
+        "height_above_geopotential_datum": {
+            "zlev": "height_above_geopotential_datum",
+            "eta": "sea_surface_height_above_geopotential_datum",
+            "depth": "sea_floor_depth_below_geopotential_datum",
+        },
+        "height_above_reference_ellipsoid": {
+            "zlev": "height_above_reference_ellipsoid",
+            "eta": "sea_surface_height_above_reference_ellipsoid",
+            "depth": "sea_floor_depth_below_reference_ellipsoid",
+        },
+        "height_above_mean_sea_level": {
+            "zlev": "height_above_mean_sea_level",
+            "eta": "sea_surface_height_above_mean_sea_level",
+            "depth": "sea_floor_depth_below_mean_sea_level",
+        },
+    }
+
+    def __call__(self, n, k, j, i):
+        s_kji = (self.depth_c * self.s[k] + self.depth[j, i] * self.C[k]) / (
+            self.depth_c + self.depth[j, i]
+        )
+        return self.eta[n, j, i] + (self.eta[n, j, i] + self.depth[j, i]) * s_kji
+
+    def __len__(self):
+        return 4
+
+
+# -----------------------------------------------------------------------------
+# Ocean sigma over z coordinate
+# -----------------------------------------------------------------------------
+class OceanSigmaOverZCoordinate(AbstractFormula):
+    url = f"{CONVENTION_BASE_URL}#_ocean_sigma_over_z_coordinate"
+    name = "ocean_sigma_z_coordinate"
+
+    def select_formula(self, key_mapping):
+        if "sigma" in key_mapping and "zlev" not in key_mapping:
+            setattr(self, "__call__", self._sigma)
+        elif 
"zlev" in key_mapping: + setattr(self, "__call__", self._z_lev) + else: + msg = f"No formula for 'ocean_sigma_z_coordinate' with given formula: {key_mapping}" + raise ValueError(msg) + + def _z_lev(self, n, k, j, i): + return self.zlev[k] + + def _sigma(self, n, k, j, i): + return self.eta[n, j, i] + self.sigma[k] * ( + min(self.depth_c, self.depth[j, i]) + self.eta[n, j, i] + ) + + def __len__(self): + return 4 + + +# ----------------------------------------------------------------------------- +# Ocean double sigma coordinate +# ----------------------------------------------------------------------------- +class OceanDoubleSigmaCoordinate(AbstractFormula): + url = f"{CONVENTION_BASE_URL}#_ocean_double_sigma_coordinate" + name = "ocean_double_sigma_coordinate" + + def __call__(self, k, j, i, **_): + f_ji = 0.5 * (self.z1 + self.z2) + 0.5 * (self.z1 - self.z2) * math.tanh( + 2 * self.a / (self.z1 - self.z2) * (self.depth[j, i] - self.href) + ) + + if k <= self.k_c: + return self.sigma[k] * f_ji + else: + return f_ji + (self.sigma[k] - 1) * (self.depth[j, i] - f_ji) + + def __len__(self): + return 3 diff --git a/pan3d/xarray/cf/mesh/__init__.py b/pan3d/xarray/cf/mesh/__init__.py new file mode 100644 index 0000000..bccb22b --- /dev/null +++ b/pan3d/xarray/cf/mesh/__init__.py @@ -0,0 +1,4 @@ +from . import unstructured # noqa: F401 +from . import structured # noqa: F401 +from . import rectilinear # noqa: F401 +from . 
import uniform # noqa: F401 diff --git a/pan3d/xarray/cf/mesh/rectilinear.py b/pan3d/xarray/cf/mesh/rectilinear.py new file mode 100644 index 0000000..b6938a1 --- /dev/null +++ b/pan3d/xarray/cf/mesh/rectilinear.py @@ -0,0 +1,43 @@ +import numpy as np +from vtkmodules.vtkCommonDataModel import vtkRectilinearGrid +from ..coords.convert import cell_center_to_point, slice_array + + +def generate_mesh(metadata, dimensions, time_index, slices): + """ + - [X] Initial implementation + - [X] Support range slicing + - [X] Support index slicing + - [ ] Automatic testing + + Testing process: + 1. load xarray tutorial dataset: eraint_uvz + 2. Switch projection to Euclidean + 3. Play with range sliders + 4. Switch one range slider to a cut + """ + data_location = "cell_data" + extent = [0, 0, 0, 0, 0, 0] + empty_coords = np.zeros((1,), dtype=np.double) + arrays = [empty_coords, empty_coords, empty_coords] + + assert metadata.coords_1d + + for idx in range(len(dimensions)): + name = dimensions[-(1 + idx)] + + arrays[idx] = cell_center_to_point( + slice_array(name, metadata.xr_dataset, slices) + ) + + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # And extent include both index (len-1) + extent[idx * 2 + 1] = arrays[idx].size - 1 + + mesh = vtkRectilinearGrid() + mesh.x_coordinates = arrays[0] + mesh.y_coordinates = arrays[1] + mesh.z_coordinates = arrays[2] + mesh.extent = extent + + return mesh, data_location diff --git a/pan3d/xarray/cf/mesh/structured.py b/pan3d/xarray/cf/mesh/structured.py new file mode 100644 index 0000000..d0d88bb --- /dev/null +++ b/pan3d/xarray/cf/mesh/structured.py @@ -0,0 +1,341 @@ +from vtkmodules.vtkCommonCore import vtkPoints +from vtkmodules.vtkCommonDataModel import vtkStructuredGrid + +from ..coords.parametric_vertical import get_formula as get_z_formula +from ..coords.index_mapping import get_formula as get_coords_formula +from ..coords.convert import point_insert, cell_center_to_point, slice_array + + +def 
generate_mesh(metadata, dimensions, time_index, spherical, slices): + # Data location and extend depend if we can extrapolate cell locations + # bounds or uniform allow to define cell bounds + if metadata.coords_has_bounds: + return generate_bound_cells(metadata, dimensions, time_index, spherical, slices) + + if metadata.uniform_spacing: + return generate_uniform_cells( + metadata, dimensions, time_index, spherical, slices + ) + + # We can only figure out the point location + return generate_mesh_points(metadata, dimensions, time_index, spherical, slices) + + +def generate_uniform_cells(metadata, dimensions, time_index, spherical, slices): + """ + - [X] Initial implementation + - [X] Support range slicing + - [X] Support index slicing + - [ ] Automatic testing + + xarray tutorial dataset: + - air_temperature + - ersstv5 + """ + data_location = "cell_data" + assert spherical + + # 2D or 3D + dims_size = len(dimensions) + assert dims_size == 2 or dims_size == 3 + + # extract extent, origin, spacing + origin = [0, 0, 0] + spacing = [1, 1, 1] + sizes = [1, 1, 1] + for idx in range(len(dimensions)): + name = dimensions[-(1 + idx)] + array = slice_array(name, metadata.xr_dataset, slices) + + # axis origin/spacing (maybe) + if array.size == 1: + origin[idx] = float(array) + all_array = metadata.xr_dataset[name] + if all_array.size > 1: + all_array = all_array[:2].values + spacing[idx] = float(all_array[1] - all_array[0]) + origin[idx] -= 0.5 * spacing[idx] + continue + + axis_spacing = (array[-1] - array[0]) / (array.size - 1) + axis_origin = float(array[0]) - axis_spacing * 0.5 + + # update global origin/spacing + origin[idx] = float(axis_origin) + spacing[idx] = float(axis_spacing) + sizes[idx] = array.size + 1 + + # debug + # print(f"{dimensions=}") + # print(f"{extent=}") + # print(f"{n_points=}") + # print(f"{dimensions=}") + # print(f"{origin=}") + # print(f"{spacing=}") + # print(f"{sizes=}") + + # Points + vtk_points = vtkPoints() + 
vtk_points.SetDataTypeToDouble() + vtk_points.Allocate(sizes[0] * sizes[1] * sizes[2]) + + # Check if direct coord mapping + for k in range(sizes[2]): + z = origin[2] + k * spacing[2] + z = metadata.vertical_bias + metadata.vertical_scale * z + for j in range(sizes[1]): + lat = origin[1] + j * spacing[1] + for i in range(sizes[0]): + lon = origin[0] + i * spacing[0] + point_insert(vtk_points, spherical, lon, lat, z) + + # Mesh + mesh = vtkStructuredGrid() + mesh.points = vtk_points + mesh.extent = [ + 0, + sizes[0] - 1, + 0, + sizes[1] - 1, + 0, + sizes[2] - 1, + ] + + return mesh, data_location + + +def generate_bound_cells(metadata, dimensions, time_index, spherical, slices): + """ + - [ ] Initial implementation + - [ ] Support range slicing + - [ ] Support index slicing + - [ ] Automatic testing + """ + data_location = "cell_data" + raise NotImplementedError("structured::generate_bound_cells") + return False, data_location + + +def generate_mesh_points(metadata, dimensions, time_index, spherical, slices): + """ + - [ ] Initial implementation + - [ ] Support range slicing + - [ ] Support index slicing + - [ ] Automatic testing + + xarray tutorial dataset: + - ROMS_example: CUSTOM Z FORMULA + - rasm: CUSTOM Z FORMULA but + - non CF format + - 2d coord remapping + + + """ + data_location = "point_data" + + if ( + metadata.coords_1d + and metadata.uniform_lat_lon + and metadata.use_coords(dimensions) + ): + return generate_mesh_points_data_on_cell( + metadata, dimensions, time_index, spherical, slices + ) + + # 2D or 3D + dims_size = len(dimensions) + assert dims_size == 2 or dims_size == 3 + + extent = [0, 0, 0, 0, 0, 0] + dim_ranges = [ + (0, 1, 1), # i + (0, 1, 1), # j + (0, 1, 1), # k + ] + n_points = 1 + for idx in range(len(dimensions)): + name = dimensions[-(1 + idx)] + array = slice_array(name, metadata.xr_dataset, slices) + all_array = metadata.xr_dataset[name] + if array.size != all_array.size: + slice_info = slices.get(name) + if isinstance(slice_info, 
int): + start = slice_info + dim_ranges[idx] = (start, start + 1) + else: + start, stop, step = slice_info + stop -= (stop - start) % step + size = int((stop - start) / step) + extent[idx * 2 + 1] = size - 1 + dim_ranges[idx] = (start, stop, step) + n_points *= size + else: + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # And extent include both index so (len-1) + extent[idx * 2 + 1] = array.size - 1 + n_points *= array.size + dim_ranges[idx] = (0, array.size, 1) + + # Points + vtk_points = vtkPoints() + vtk_points.SetDataTypeToDouble() + vtk_points.Allocate(n_points) + + # debug + # print(f"{dimensions=}") + # print(f"{extent=}") + # print(f"{n_points=}") + + # Check if direct coord mapping + if metadata.coords_1d and metadata.use_coords(dimensions): + print(" => 1D coords - Need data to test") + if dims_size == 2: # 2D + print(" => 2D dataset - Need data to test") + x_array = metadata.xr_dataset[metadata.longitude].values + y_array = metadata.xr_dataset[metadata.latitude].values + z = 0 + for j in range(*dim_ranges[1]): + lat = y_array[j] + for i in range(*dim_ranges[0]): + lon = x_array[i] + point_insert(vtk_points, spherical, lon, lat, z) + else: # 3D + print(" => 3D dataset - Need data to test") + x_array = metadata.xr_dataset[metadata.longitude].values + y_array = metadata.xr_dataset[metadata.latitude].values + z_array = metadata.xr_dataset[metadata.vertical].values + for k in range(*dim_ranges[2]): + z = z_array[k] + for j in range(*dim_ranges[1]): + lat = y_array[j] + for i in range(*dim_ranges[0]): + lon = x_array[i] + point_insert(vtk_points, spherical, lon, lat, z) + else: + # CUSTOM Z FORMULA + # need some index mapping + z_formula = get_z_formula( + metadata.xr_dataset, + metadata.vertical, + bias=metadata.vertical_bias, + scale=metadata.vertical_scale, + ) + coords_formula = get_coords_formula(metadata, dimensions) + for k in range(*dim_ranges[2]): + for j in range(*dim_ranges[1]): + for i in range(*dim_ranges[0]): + 
lon, lat = coords_formula(i=i, j=j, k=k) + z = z_formula(i=i, j=j, k=k, n=time_index) + # print(f"{time_index=}, {k=}, {j=}, {i=} = {lon=}, {lat=}, {z=}") + point_insert(vtk_points, spherical, lon, lat, z) + + # Mesh + mesh = vtkStructuredGrid() + mesh.points = vtk_points + mesh.extent = extent + + return mesh, data_location + + +def generate_mesh_points_data_on_cell( + metadata, dimensions, time_index, spherical, slices +): + """ + - [X] Initial implementation + - [x] Support range slicing + - [x] Support index slicing + - [ ] Automatic testing + + Testing process 2D: <------------------- MISSING + 1. xxxx + 2. Play with range sliders + 3. Switch one range slider to a cut + + Testing process 3D: + 1. load xarray tutorial dataset: eraint_uvz | basin_mask + 2. Play with range sliders + 3. Switch one range slider to a cut + """ + # Slice Ready + data_location = "cell_data" + + # 2D or 3D + dims_size = len(dimensions) + assert dims_size == 2 or dims_size == 3 + + # compute extent between dimensions and slices + extent = metadata.dims_extent(dimensions, slices) + n_points = 1 + + # Increase extent by 1 since we add points to put data on cells + for i in range(3): + if extent[i * 2 + 1] > 0: + extent[i * 2 + 1] += 1 + + # Extract point count + for i in range(3): + n_points *= extent[i * 2 + 1] - extent[i * 2] + 1 + + # Points + vtk_points = vtkPoints() + vtk_points.SetDataTypeToDouble() + vtk_points.Allocate(n_points) + + # debug + # print(f"{dimensions=}") + # print(f"{extent=}") + # print(f"{n_points=}") + # print(f"{slices=}") + + # Check if direct coord mapping + if dims_size == 2: # 2D + print("#" * 60) + print("structured::generate_mesh_points_data_on_cell") + print("We are in 2D mode\n" * 5) + print("#" * 60) + x_array = cell_center_to_point( + slice_array(metadata.longitude, metadata.xr_dataset, slices) + ) + y_array = cell_center_to_point( + slice_array(metadata.latitude, metadata.xr_dataset, slices) + ) + if y_array.size == 1: + y_array = [y_array] + if 
x_array.size == 1: + x_array = [x_array] + z = metadata.vertical_bias + for j in range(extent[3] + 1): + lat = y_array[j] + for i in range(extent[1] + 1): + lon = x_array[i] + point_insert(vtk_points, spherical, lon, lat, z) + else: # 3D + x_array = cell_center_to_point( + slice_array(metadata.longitude, metadata.xr_dataset, slices) + ) + y_array = cell_center_to_point( + slice_array(metadata.latitude, metadata.xr_dataset, slices) + ) + z_array = cell_center_to_point( + slice_array(metadata.vertical, metadata.xr_dataset, slices) + ) + if z_array.size == 1: + z_array = [z_array] + if y_array.size == 1: + y_array = [y_array] + if x_array.size == 1: + x_array = [x_array] + for k in range(extent[5] + 1): + z = z_array[k] + for j in range(extent[3] + 1): + lat = y_array[j] + for i in range(extent[1] + 1): + lon = x_array[i] + point_insert(vtk_points, spherical, lon, lat, z) + + # Mesh + mesh = vtkStructuredGrid() + mesh.points = vtk_points + mesh.extent = extent + + return mesh, data_location diff --git a/pan3d/xarray/cf/mesh/uniform.py b/pan3d/xarray/cf/mesh/uniform.py new file mode 100644 index 0000000..3826de9 --- /dev/null +++ b/pan3d/xarray/cf/mesh/uniform.py @@ -0,0 +1,59 @@ +from vtkmodules.vtkCommonDataModel import vtkImageData +from ..coords.convert import slice_array + + +def generate_mesh(metadata, dimensions, time_index, slices): + """ + - [X] Initial implementation + - [X] Support range slicing + - [X] Support index slicing + - [ ] Automatic testing + + Testing process: + 1. load xarray tutorial dataset: air_temperature + 2. Switch projection to Euclidean + 3. Play with range sliders + 4. 
Switch one range slider to a cut + """ + data_location = "cell_data" + + # data to capture + origin = [0, 0, 0] + spacing = [1, 1, 1] + extent = [0, 0, 0, 0, 0, 0] + + # extract information from dimensions + for idx in range(len(dimensions)): + name = dimensions[-(1 + idx)] + + array = slice_array(name, metadata.xr_dataset, slices) + # Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size] + # Use size as end of extent as we are adding 1 point for map data on cell + extent[idx * 2 + 1] = array.size + + # No spacing just origin (maybe) + if array.size == 1: + origin[idx] = float(array) + all_array = metadata.xr_dataset[name] + if all_array.size > 1: + all_array = all_array[:2].values + spacing[idx] = float(all_array[1] - all_array[0]) + origin[idx] -= 0.5 * spacing[idx] + continue + + # axis origin/spacing + axis_spacing = (array[-1] - array[0]) / (array.size - 1) + axis_origin = float(array[0]) - axis_spacing * 0.5 + + # update global origin/spacing + origin[idx] = float(axis_origin) + spacing[idx] = float(axis_spacing) + + # print(f"{origin=}") + # print(f"{spacing=}") + # print(f"{extent=}") + + # Configure mesh + mesh = vtkImageData(origin=origin, spacing=spacing, extent=extent) + + return mesh, data_location diff --git a/pan3d/xarray/cf/mesh/unstructured.py b/pan3d/xarray/cf/mesh/unstructured.py new file mode 100644 index 0000000..c11cc70 --- /dev/null +++ b/pan3d/xarray/cf/mesh/unstructured.py @@ -0,0 +1,4 @@ +def generate_mesh(metadata, dimensions, spherical, slices): + print(" => unstructured: cell_data") + raise NotImplementedError("unstructured::generate_mesh") + return False, "cell_data" diff --git a/pan3d/xarray/cf/reader.py b/pan3d/xarray/cf/reader.py new file mode 100644 index 0000000..cabc911 --- /dev/null +++ b/pan3d/xarray/cf/reader.py @@ -0,0 +1,508 @@ +from typing import List, Optional + +import gc +import xarray as xr +import numpy as np +import pandas as pd + +from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase +from 
vtkmodules.vtkCommonCore import vtkVariant +from vtkmodules.vtkCommonExecutionModel import vtkStreamingDemandDrivenPipeline +from vtkmodules.vtkCommonDataModel import vtkDataObject +from vtkmodules.vtkFiltersCore import vtkArrayCalculator +from vtkmodules.util import numpy_support + +from pan3d.xarray.cf.coords.meta import MetaArrayMapping +from pan3d.xarray.cf.constants import Projection + +# ----------------------------------------------------------------------------- +# Helper functions +# ----------------------------------------------------------------------------- + + +def is_time_array(xarray, name, values): + if values.dtype.type == np.datetime64 or values.dtype.type == np.timedelta64: + un = np.datetime_data(values.dtype) + # unit = ns and 1 base unit in a spep + if un[0] == "ns" and un[1] == 1 and name in xarray.coords.keys(): + return True + return False + + +# ----------------------------------------------------------------------------- + + +def attr_value_to_vtk(value): + if np.issubdtype(type(value), np.integer): + return vtkVariant(int(value)) + + if np.issubdtype(type(value), np.floating): + return vtkVariant(float(value)) + + if isinstance(value, np.ndarray): + return vtkVariant(numpy_support.numpy_to_vtk(value)) + + return vtkVariant(value) + + +# ----------------------------------------------------------------------------- + + +def get_time_labels(times): + return [pd.to_datetime(time).strftime("%Y-%m-%d %H:%M:%S") for time in times] + + +# ----------------------------------------------------------------------------- +# VTK Algorithms +# ----------------------------------------------------------------------------- +class vtkXArrayCFSource(VTKPythonAlgorithmBase): + """vtk source for converting XArray into a VTK mesh""" + + def __init__( + self, + input: Optional[xr.Dataset] = None, + arrays: Optional[List[str]] = None, + ): + """ + Create vtkXArraySource + + Parameters: + input (xr.Dataset): Provide an XArray to use as input. 
The load() method will replace it. + arrays (list[str]): List of field to load onto the generated VTK mesh. + """ + VTKPythonAlgorithmBase.__init__( + self, + nInputPorts=0, + nOutputPorts=1, + outputType="vtkDataObject", + ) + # Data source + self._input = input + self._pipeline = None + self._computed = {} + self._data_origin = None + + # Data sub-selection + self._array_names = set(arrays or []) + self._t_index = 0 + self._slices = None + + # vtk internal vars + self._arrays = {} + + # projection + self._proj_mode = Projection.SPHERICAL + self._proj_vertical_bias = 6378137 # earth radius in meter + self._proj_vertical_scale = ( + 100 # increase z scaling to see something compare to earth radius + ) + + # Create reader if xarray available + self._mapping = MetaArrayMapping(self._input) + + # ------------------------------------------------------------------------- + # Information + # ------------------------------------------------------------------------- + + def __str__(self): + return f"VTK XArray CF (Python) reader\n{self._mapping}" + + # ------------------------------------------------------------------------- + # Data input + # ------------------------------------------------------------------------- + + @property + def input(self): + """return current input XArray""" + return self._input + + @input.setter + def input(self, xarray_dataset: xr.Dataset): + """update input with a new XArray""" + self._slices = None + self._array_names.clear() + self._input = xarray_dataset + self._mapping.update(self._input) + self.t_index = 0 + self.Modified() + + # ------------------------------------------------------------------------- + # Array selectors + # ------------------------------------------------------------------------- + + @property + def x(self): + """return the name that is currently mapped to the X axis""" + return self._mapping.longitude + + @property + def x_size(self): + """return the size of the coordinate used for the X axis""" + if self.x is None: + 
return 0 + return int(self._input[self.x].size) + + @property + def y(self): + """return the name that is currently mapped to the Y axis""" + return self._mapping.latitude + + @property + def y_size(self): + """return the size of the coordinate used for the Y axis""" + if self.y is None: + return 0 + return int(self._input[self.y].size) + + @property + def z(self): + """return the name that is currently mapped to the Z axis""" + return self._mapping.vertical + + @property + def z_size(self): + """return the size of the coordinate used for the Z axis""" + if self.z is None: + return 0 + return int(self._input[self.z].size) + + @property + def t(self): + """return the name that is currently mapped to the time axis""" + return self._mapping.time + + @property + def slice_extents(self): + """ + return a dictionary for the field dimensions with + the corresponding extent [0, size-1]. + + For some dataset, it is possible that the extent does not have + a direct mapping with the coordinate system. + """ + fields = self.arrays + if not fields: + return {} + dims = self._mapping.timeless_dimensions(fields[0]) + # ensure order to match x, y, z + dims = dims[::-1] + + # print(f"slice_extents::{dims=}") + result = { + coord_name: [0, self.input[coord_name].size - 1] for coord_name in dims + } + # print(f" => {result}") + return result + + @property + def available_coords(self): + """List available coordinates arrays that are 1D""" + # !!! Do we use that ??? 
+ if self._input is None: + return [] + + return [k for k, v in self._input.coords.items() if len(v.shape) == 1] + + # ------------------------------------------------------------------------- + # Projection management + # ------------------------------------------------------------------------- + + @property + def projection(self): + return self._proj_mode + + @projection.setter + def projection(self, mode=None): + if mode in Projection and self._proj_mode != mode: + if isinstance(mode, str): + mode = Projection(mode) + self._proj_mode = mode + self.Modified() + + @property + def vertical_bias(self): + return self._proj_vertical_bias + + @vertical_bias.setter + def vertical_bias(self, v): + if self._proj_vertical_bias != v: + self._proj_vertical_bias = v + self._mapping.vertical_bias = v + self.Modified() + + @property + def vertical_scale(self): + return self._proj_vertical_scale + + @vertical_scale.setter + def vertical_scale(self, v): + if self._proj_vertical_scale != v: + self._proj_vertical_scale = v + self._mapping.vertical_scale = v + self.Modified() + + # ------------------------------------------------------------------------- + # Data sub-selection + # ------------------------------------------------------------------------- + + @property + def t_index(self): + """return the current selected time index""" + return self._t_index + + @t_index.setter + def t_index(self, t_index: int): + """update the current selected time index""" + if t_index != self._t_index: + self._t_index = t_index + self.Modified() + + @property + def t_size(self): + """return the size of the coordinate used for the time""" + if self.t is None: + return 0 + return int(self._input[self.t].size) + + @property + def t_labels(self): + """return a list of string that match the various time values available""" + if self.t is None: + return [] + + t_array = self._input[self.t] + t_type = t_array.dtype + if np.issubdtype(t_type, np.datetime64): + return get_time_labels(t_array.values) + 
return [str(t) for t in t_array.values] + + @property + def arrays(self): + """return the list of arrays that are currently selected to be added to the generated VTK mesh""" + return list(self._array_names) + + @arrays.setter + def arrays(self, array_names: List[str]): + """update the list of arrays to load on the generated VTK mesh""" + new_names = set(array_names or []) + if new_names != self._array_names: + self._array_names = new_names + self.Modified() + + @property + def available_arrays(self): + """List all available data fields for the `arrays` option""" + if self._input is None: + return [] + + all_list = [] + for arrays in self._mapping.data_arrays.values(): + if self._array_names: + if self._array_names.intersection(arrays): + # only add compatible arrays + all_list.extend(arrays) + else: + # when no array selected, add all of them + all_list.extend(arrays) + + return all_list + + @property + def slices(self): + """return the current slicing information which include axes crop/cut and time selection""" + result = dict(self._slices or {}) + if self.t is not None: + result[self.t] = self.t_index + return result + + @slices.setter + def slices(self, v): + """update the slicing of the data along axes""" + if v != self._slices: + self._slices = v + # print(f"{self._slices=}") + if "time" in v: + self.t_index = v.get("time", 0) + + self.Modified() + + # ------------------------------------------------------------------------- + # add-on logic + # ------------------------------------------------------------------------- + + @property + def computed(self): + """return the current description of the computed/derived fields on the VTK mesh""" + return self._computed + + @computed.setter + def computed(self, v): + """ + update the computed/derived fields to add on the VTK mesh + + The layout of the dictionary provided should be as follow: + - key: name of the field to be added + - value: formula to apply for the given field name. 
The syntax is captured in the document (https://docs.paraview.org/en/latest/UsersGuide/filteringData.html#calculator) + + Then additional keys need to be provided to describe your formula dependencies: + `_use_scalars` and `_use_vectors` which should be a list of string matching the name of the fields you are using in your expression. + + + Please find below an example: + + ``` + { + "_use_scalars": ["u", "v"], # (u,v) needed for "vec" and "m2" + "vec": "(u * iHat) + (v * jHat)", # 2D vector + "m2": "u*u + v*v", + } + ``` + """ + if self._computed != v: + self._computed = v or {} + self._pipeline = None + scalar_arrays = self._computed.get("_use_scalars", []) + vector_arrays = self._computed.get("_use_vectors", []) + + for output_name, func in self._computed.items(): + if output_name[0] == "_": + continue + filter = vtkArrayCalculator( + result_array_name=output_name, + function=func, + ) + + # register array dependencies + for scalar_array in scalar_arrays: + filter.AddScalarArrayName(scalar_array) + for vector_array in vector_arrays: + filter.AddVectorArrayName(vector_array) + + if self._pipeline is None: + self._pipeline = filter + else: + self._pipeline = self._pipeline >> filter + + self.Modified() + + def load(self, data_info): + """ + create a new XArray input with the `data_origin` and `dataset_config` information. 
+ + Here is an example of the layout of the parameter + + ``` + { + "data_origin": { + "source": "url", # one of [file, url, xarray, pangeo, esgf] + "id": "https://ncsa.osn.xsede.org/Pangeo/pangeo-forge/noaa-coastwatch-geopolar-sst-feedstock/noaa-coastwatch-geopolar-sst.zarr", + "order": "C" # (optional) order to use in numpy + }, + "dataset_config": { + "x": "lon", # (optional) coord name for X + "y": "lat", # (optional) coord name for Y + "z": null, # (optional) coord name for Z + "t": "time", # (optional) coord name for time + "slices": { # (optional) array slicing + "lon": [ + 1000, + 6000, + 20 + ], + "lat": [ + 500, + 3000, + 20 + ], + "time": 5 + }, + "t_index": 5, # (optional) selected time index + "arrays": [ # (optional) names of arrays to load onto VTK mesh. + "analysed_sst" # If missing no array will be loaded + ] # onto the mesh. + } + } + ``` + """ + if "data_origin" not in data_info: + raise ValueError("Only state with data_origin can be loaded") + + from pan3d import catalogs + + self._data_origin = data_info["data_origin"] + self.input = catalogs.load_dataset( + self._data_origin["source"], self._data_origin["id"] + ) + + dataset_config = data_info.get("dataset_config") + if dataset_config is None: + self.arrays = self.available_arrays + else: + self.slices = dataset_config.get("slices") + self.t_index = dataset_config.get("t_index", 0) + self.arrays = dataset_config.get("arrays", [self.available_arrays[0]]) + + @property + def state(self): + """return current state that can be reused in a load() later on""" + if self._data_origin is None: + raise RuntimeError( + "No state available without data origin. Need to use the load method to set the data origin." 
+ ) + + return { + "data_origin": self._data_origin, + "dataset_config": { + k: getattr(self, k) + for k in ["x", "y", "z", "t", "slices", "t_index", "arrays"] + }, + } + + # ------------------------------------------------------------------------- + # Algorithm + # ------------------------------------------------------------------------- + + def RequestDataObject(self, request, inInfo, outInfo): + output = self._mapping.get_vtk_mesh_type(self._proj_mode, self.arrays) + outInfo.GetInformationObject(0).Set(vtkDataObject.DATA_OBJECT(), output) + # print(f"RequestDataObject::{output.GetClassName()=}") + return 1 + + def RequestInformation(self, request, inInfo, outInfo): + whole_extent = self._mapping.get_vtk_whole_extent( + self._proj_mode, self.arrays, self._slices + ) + # print(f"RequestInformation::{whole_extent=}") + outInfo.GetInformationObject(0).Set( + vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), + *whole_extent, + ) + + return 1 + + def RequestData(self, request, inInfo, outInfo): + """implementation of the vtk algorithm for generating the VTK mesh""" + # Use open data_array handle to fetch data at + # desired Level of Detail + # print("RequestData") + mesh = self._mapping.get_vtk_mesh( + time_index=self.t_index, + projection=self._proj_mode, + fields=self.arrays, + slices=self._slices, + ) + if mesh is not None: + # print(f"{mesh.extent=}") + # print(f"{mesh.GetClassName()=}") + pdo = self.GetOutputData(outInfo, 0) + + # Compute derived quantity + if self._pipeline is not None: + mesh = self._pipeline(mesh) + pdo.ShallowCopy(mesh) + else: + pdo.ShallowCopy(mesh) + + gc.collect() + + return 1 diff --git a/pan3d/xarray/cf/usecases/__init__.py b/pan3d/xarray/cf/usecases/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pan3d/xarray/cf/usecases/rom.py b/pan3d/xarray/cf/usecases/rom.py new file mode 100644 index 0000000..4f8d38a --- /dev/null +++ b/pan3d/xarray/cf/usecases/rom.py @@ -0,0 +1,85 @@ +import math +from xarray.tutorial import 
open_dataset +from vtkmodules.vtkCommonCore import vtkPoints, vtkMath +from vtkmodules.vtkCommonDataModel import vtkStructuredGrid +from vtkmodules.vtkIOLegacy import vtkDataSetWriter + +from ..coords.parametric_vertical import get_formula + + +class Formula_ocean_s_coordinate_g2: + def __init__(self, xr_dataset, s, c, eta, depth, depth_c): + self.xr_dataset = xr_dataset + self.s = self.xr_dataset[s].values + self.c = self.xr_dataset[c].values + self.eta = self.xr_dataset[eta].values + self.depth = self.xr_dataset[depth].values + self.depth_c = self.xr_dataset[depth_c].values + + def __call__(self, n, k, j, i): + return self.eta[n, j, i] + (self.eta[n, j, i] + self.depth[j, i]) * ( + self.depth_c * self.s[k] + self.depth[j, i] * self.c[k] + ) / (self.depth_c + self.depth[j, i]) + + +def main(): + ds = open_dataset("ROMS_example") + + t = 0 + salt = ds.salt # (ocean_time, s_rho, eta_rho, xi_rho) + lon_rho = ds.lon_rho.values # (eta_rho, xi_rho) + lat_rho = ds.lat_rho.values # (eta_rho, xi_rho) + + # 1D coords + eta_rho = ds.eta_rho + xi_rho = ds.xi_rho + s_rho = ( + ds.s_rho + ) # (s_rho) | ocean_s_coordinate_g2 | formula_terms= "s: s_rho C: Cs_r eta: zeta depth: h depth_c: hc" + + formula = get_formula(ds, "s_rho") + + # ocean_s_coordinate_g2 + # z(n,k,j,i) = eta(n,j,i) + (eta(n,j,i) + depth(j,i)) * S(k,j,i) + # S(k,j,i) = (depth_c * s(k) + depth(j,i) * C(k)) / (depth_c + depth(j,i)) + # formula_terms = "s: var1 C: var2 eta: var3 depth: var4 depth_c: var5" + # - s: s_rho + # - C: Cs_r + # - eta: zeta + # - depth: h + # - depth_c: hc + + earth_radius = 6378137 # in meters + bias = earth_radius + scale = 100 + + n_points = s_rho.size * eta_rho.size * xi_rho.size + points = vtkPoints() + points.SetDataTypeToDouble() + points.Allocate(n_points) + + for k in range(s_rho.size): + for j in range(eta_rho.size): + for i in range(xi_rho.size): + lon = vtkMath.RadiansFromDegrees(lon_rho[j, i]) + lat = vtkMath.RadiansFromDegrees(lat_rho[j, i]) + h = bias + scale * formula(t, 
k, j, i) + points.InsertNextPoint( + h * math.cos(lon) * math.cos(lat), + h * math.sin(lon) * math.cos(lat), + h * math.sin(lat), + ) + + mesh = vtkStructuredGrid() + mesh.SetExtent(0, xi_rho.size - 1, 0, eta_rho.size - 1, 0, s_rho.size - 1) + mesh.points = points + mesh.point_data["salt"] = salt[t].values.ravel() + + writer = vtkDataSetWriter() + writer.SetInputData(mesh) + writer.SetFileName("rom_salt.vtk") + writer.Write() + + +if __name__ == "__main__": + main() diff --git a/pan3d/xarray/cf/utils.py b/pan3d/xarray/cf/utils.py new file mode 100644 index 0000000..44ac7fb --- /dev/null +++ b/pan3d/xarray/cf/utils.py @@ -0,0 +1,472 @@ +import numpy as np +from . import constants +from .coords import coords_mesh_rectilinear, coords_mesh_spherical + + +def to_bounds(xr_dataset, coord_name): + xr_array = xr_dataset[coord_name] + if xr_array.dims == (coord_name,): + bound_array_name = xr_array.attrs.get("bounds") + if bound_array_name: + # TODO + # xr_dataset[bound_array_name] + return bound_array_name + elif np.issubdtype(xr_array.dtype, np.number): + origin = float(xr_array.values[0]) + spacing = (float(xr_array.values[-1]) - origin) / (xr_array.size - 1) + + if xr_array.attrs.get("positive", "") == "down": + spacing *= -1 + + bounds = np.zeros(xr_array.size + 1, dtype=np.double) + bounds[0] = float(xr_array.values[0]) - 0.5 * spacing + for i in range(1, xr_array.size): + center = 0.5 * float(xr_array.values[i - 1] + xr_array.values[i]) + bounds[i] = center + bounds[-1] = float(xr_array.values[-1]) + 0.5 * spacing + + return bounds + + # fake coordinates + return np.linspace( + -0.5, xr_array.size - 0.5, num=xr_array.size + 1, dtype=np.double + ) + + +class Coordinates: + def __init__(self, xr_dataset, vertical_scale=1, vertical_bias=0): + self.xr_dataset = xr_dataset + self.vertical_scale = vertical_scale + self.vertical_bias = vertical_bias + self.longitude = None + self.latitude = None + self.vertical = None + self.time = None + + for coord_name in 
xr_dataset.coords: + unit = constants.Units.extract(xr_dataset[coord_name]) + if unit == constants.Units.LONGITUDE_UNITS: + self.longitude = coord_name + elif unit == constants.Units.LATITUDE_UNITS: + self.latitude = coord_name + elif unit == constants.Units.VERTICAL_UNITS: + self.vertical = coord_name + elif unit == constants.Units.TIME_UNITS: + self.time = coord_name + else: + print(f"Skip coord {coord_name}") + + @property + def extent(self): + x_size = self.longitude_array.size if self.longitude else 0 + y_size = self.latitude_array.size if self.latitude else 0 + z_size = self.vertical_array.size if self.vertical else 0 + + return [ + 0, + x_size, + 0, + y_size, + 0, + z_size, + ] + + @property + def origin(self): + print(self.longitude, "=>", self.longitude_array.dims) + print(self.latitude, "=>", self.latitude_array.dims) + x = self.longitude_array.values[0] if self.longitude else 0 + y = self.latitude_array.values[0] if self.latitude else 0 + z = self.vertical_array.values[0] if self.vertical else 0 + return [x, y, z] + + @property + def spacing(self): + origin = self.origin + sx, sy, sz = 1, 1, 1 + if self.longitude: + coord = self.longitude_array + sx = (float(coord.values[-1]) - origin[0]) / (coord.size - 1) + if coord.attrs.get("positive", "") == "down": + sx *= -1 + if self.latitude: + coord = self.latitude_array + sy = (float(coord.values[-1]) - origin[1]) / (coord.size - 1) + if coord.attrs.get("positive", "") == "down": + sy *= -1 + if self.vertical: + coord = self.vertical_array + sz = (float(coord.values[-1]) - origin[2]) / (coord.size - 1) + if coord.attrs.get("positive", "") == "down": + sz *= -1 + + return [sx, sy, sz] + + @property + def longitude_array(self): + if self.longitude is None: + return None + + return self.xr_dataset[self.longitude] + + @property + def longitude_bounds(self): + if self.longitude is None: + return None + + return to_bounds(self.xr_dataset, self.longitude) + + @property + def latitude_array(self): + if self.latitude 
is None: + return None + + return self.xr_dataset[self.latitude] + + @property + def latitude_bounds(self): + if self.latitude is None: + return None + + return to_bounds(self.xr_dataset, self.latitude) + + @property + def vertical_array(self): + if self.vertical is None: + return None + + return self.xr_dataset[self.vertical] + + @property + def vertical_bounds(self): + if self.vertical is None: + return None + + return to_bounds(self.xr_dataset, self.vertical) + + @property + def time_array(self): + if self.time is None: + return None + + return self.xr_dataset[self.time] + + +class XArrayDataSetCFHelper: + def __init__(self, xr_dataset): + self.xr_dataset = xr_dataset + self._active_array_names = {} + self._active_dimensions = None + self._cached_dimensions_info = {} + self._x = None + self._y = None + self._z = None + self._t = None + + @property + def available_arrays(self) -> list[str]: + if self._active_dimensions is None: + return [ + n + for n in self.xr_dataset.data_vars.keys() + if "bnd" not in n and "bound" not in n + ] + + # List only arrays with the same active dimensions + return [ + n + for n in self.xr_dataset.data_vars.keys() + if self.xr_dataset[n].dims == self._active_dimensions + ] + + @property + def array_selection(self) -> set[str]: + """return the list of arrays that are currently selected to be added to the generated VTK mesh""" + return self._active_array_names + + @array_selection.setter + def array_selection(self, array_names=None): + """update the list of arrays to load on the generated VTK mesh""" + if array_names is None: + array_names = [] + + # Filter with only valid arrays + allowed_names = set(self.available_arrays) + new_names = set([n for n in array_names if n in allowed_names]) + + # Do we have a change + if new_names != self._active_array_names: + if len(new_names): + # Check compatibility + self._active_dimensions = None + compatible_set = set() + for name in new_names: + if self._active_dimensions is None: + 
self._active_dimensions = self.xr_dataset[name].dims + compatible_set.add(name) + elif self.xr_dataset[name].dims == self._active_dimensions: + compatible_set.add(name) + + self._active_array_names = compatible_set + else: + # no selection + self._active_array_names = new_names + self._active_dimensions = None + + # Update dimension mapping + self._compute_dimensions_information() + + return True + + return False + + def _compute_dimensions_information(self): + # reset + self._cached_dimensions_info.clear() + self._x = None + self._y = None + self._z = None + self._t = None + self._t_index = 0 + + # look deeper if possible + if self._active_dimensions is None: + return + + # extract field dimension information + for dim_name in self._active_dimensions: + info = constants.DimensionInformation(self.xr_dataset, dim_name) + self._cached_dimensions_info[dim_name] = info + + # track dimension orientation + if info.unit == constants.Units.LONGITUDE_UNITS: + self._x = dim_name + if info.unit == constants.Units.LATITUDE_UNITS: + self._y = dim_name + if info.unit == constants.Units.VERTICAL_UNITS: + self._z = dim_name + if info.unit == constants.Units.TIME_UNITS: + self._t = dim_name + + # extract coord dimension from dataset + for coord_name in self.xr_dataset.coords: + if coord_name not in self._cached_dimensions_info: + info = constants.DimensionInformation(self.xr_dataset, coord_name) + self._cached_dimensions_info[coord_name] = info + + # track dimension orientation + if info.unit == constants.Units.LONGITUDE_UNITS: + self._x = coord_name + if info.unit == constants.Units.LATITUDE_UNITS: + self._y = coord_name + if info.unit == constants.Units.VERTICAL_UNITS: + self._z = coord_name + + @property + def x(self): + """return the name that is currently mapped to the X axis""" + return self._x + + @property + def x_size(self): + """return the size of the coordinate used for the X axis""" + if self._x is None: + return 0 + return int(self.xr_dataset[self._x].size) + + 
@property + def x_info(self): + """return the X coordinate information if available""" + if self._x is None: + return None + return self._cached_dimensions_info.get(self._x) + + @property + def y(self): + """return the name that is currently mapped to the Y axis""" + return self._y + + @property + def y_size(self): + """return the size of the coordinate used for the Y axis""" + if self._y is None: + return 0 + return int(self.xr_dataset[self._y].size) + + @property + def y_info(self): + """return the Y coordinate information if available""" + if self._y is None: + return None + return self._cached_dimensions_info.get(self._y) + + @property + def z(self): + """return the name that is currently mapped to the Z axis""" + return self._z + + @property + def z_size(self): + """return the size of the coordinate used for the Z axis""" + if self._z is None: + return 0 + return int(self.xr_dataset[self._z].size) + + @property + def z_info(self): + """return the Z coordinate information if available""" + if self._z is None: + return None + return self._cached_dimensions_info.get(self._z) + + @property + def t(self): + """return the name that is currently mapped to the time axis""" + return self._t + + @property + def t_size(self): + """return the size of the coordinate used for the time axis""" + if self._t is None: + return 0 + return int(self.xr_dataset[self._t].size) + + @property + def t_info(self): + """return the T coordinate information if available""" + if self._t is None: + return None + return self._cached_dimensions_info.get(self._t) + + def get_info(self, name): + info = self._cached_dimensions_info.get(name) + if info is None: + info = self._cached_dimensions_info.setdefault( + name, constants.DimensionInformation(self.xr_dataset, name) + ) + return info + + @property + def mesh(self): + if not self._cached_dimensions_info: + print("no cache") + return None + + if len(self.array_selection) == 0: + print("no field") + return None + + # Most coordinate variables 
are defined by a variable the same name as the + # dimension they describe. Those are handled elsewhere. This class handles + # dependent variables that define coordinates that are not the same name as + # any dimension. This is only done when the coordinates cannot be expressed + # as a 1D table lookup from dimension index. This occurs in only two places + # in the CF convention. First, it happens for 2D coordinate variables with + # 4-sided cells. This is basically when the grid is a 2D curvilinear grid. + # Each i,j topological point can be placed anywhere in space. Second, it + # happens for multi-dimensional coordinate variables with p-sided cells. + # These are unstructured collections of polygons. + + # we need at least a 2D surface + if self.x is None and self.y is None: + print("no (x,y)") + return None + + if len(self.x_info.dims) != len(self.y_info.dims): + msg = f"Number of dimensions in different coordinate arrays do not match (x:{self.x_info.dims}, y:{self.y_info.dims})" + raise ValueError(msg) + + field_name = next(iter(self.array_selection)) + coord_mode = constants.CoordinateTypes.get_coordinate_type( + self.xr_dataset, field_name, use_spherical=True + ) + print(f"{coord_mode=}") + mesh_type = constants.MeshTypes.from_coord_type(coord_mode) + mesh = mesh_type.new() + + # ------------------------------------------------- + # Generate mesh structure + # ------------------------------------------------- + if mesh_type == constants.MeshTypes.VTK_IMAGE_DATA: + coords_mesh_rectilinear.add_imagedata(mesh, Coordinates(self.xr_dataset)) + elif mesh_type == constants.MeshTypes.VTK_RECTILINEAR_GRID: + if coord_mode in { + constants.CoordinateTypes.EUCLIDEAN_PSIDED_CELLS, + constants.CoordinateTypes.SPHERICAL_PSIDED_CELLS, + }: + # There is no sensible way to store p-sided cells in a structured grid. + # Just fake some coordinates (related to ParaView bug #11543). 
+ coords_mesh_rectilinear.fake_rectilinear(mesh, self.xr_dataset) + else: + coords_mesh_rectilinear.add_rectilinear(mesh, self.xr_dataset) + elif mesh_type in { + constants.MeshTypes.VTK_STRUCTURED_GRID, + constants.MeshTypes.VTK_UNSTRUCTURED_GRID, + }: + if coord_mode in { + constants.CoordinateTypes.UNIFORM_RECTILINEAR, + constants.CoordinateTypes.NONUNIFORM_RECTILINEAR, + }: + if mesh_type == constants.MeshTypes.VTK_STRUCTURED_GRID: + coords_mesh_rectilinear.add_1d_structured(mesh, self.xr_dataset) + else: + coords_mesh_rectilinear.add_1d_unstructured(mesh, self.xr_dataset) + elif coord_mode == constants.CoordinateTypes.REGULAR_SPHERICAL: + if mesh_type == constants.MeshTypes.VTK_STRUCTURED_GRID: + coords_mesh_spherical.add_1d_structured( + mesh, Coordinates(self.xr_dataset) + ) + else: + coords_mesh_spherical.add_1d_unstructured(mesh, self.xr_dataset) + elif coord_mode in { + constants.CoordinateTypes.EUCLIDEAN_2D, + constants.CoordinateTypes.EUCLIDEAN_4SIDED_CELLS, + }: + if mesh_type == constants.MeshTypes.VTK_STRUCTURED_GRID: + coords_mesh_rectilinear.add_2d_structured(mesh, self.xr_dataset) + else: + coords_mesh_rectilinear.add_2d_unstructured(mesh, self.xr_dataset) + elif coord_mode in { + constants.CoordinateTypes.SPHERICAL_2D, + constants.CoordinateTypes.SPHERICAL_4SIDED_CELLS, + }: + + coords_mesh_spherical.add_2d_structured(mesh, self.xr_dataset) + elif coord_mode in { + constants.CoordinateTypes.EUCLIDEAN_PSIDED_CELLS, + constants.CoordinateTypes.SPHERICAL_PSIDED_CELLS, + }: + # There is no sensible way to store p-sided cells in a structured grid. + # Just fake some coordinates (ParaView bug #11543). 
+ coords_mesh_rectilinear.fake_structured(mesh, self.xr_dataset) + else: + msg = f"Unknown coordinate type {coord_mode}" + raise ValueError(msg) + elif mesh_type == constants.MeshTypes.VTK_UNSTRUCTURED_GRID: + if coord_mode in { + constants.CoordinateTypes.UNIFORM_RECTILINEAR, + constants.CoordinateTypes.NONUNIFORM_RECTILINEAR, + }: + coords_mesh_rectilinear.add_1d_unstructured(mesh, self.xr_dataset) + + # ------------------------------------------------- + # Add fields to mesh + # ------------------------------------------------- + for field_name in self.array_selection: + array = self.xr_dataset[field_name] + + # slice array with current time + if self.t: + array = array.isel({self.t: self._t_index}) + + if coord_mode.use_point_data: + mesh.point_data[field_name] = array.values.ravel(order="C") + else: + mesh.cell_data[field_name] = array.values.ravel(order="C") + + # ------------------------------------------------- + # Add time metadata + # ------------------------------------------------- + # TODO + + return mesh diff --git a/pan3d/xarray/vtk.py b/pan3d/xarray/vtk.py new file mode 100644 index 0000000..37f1855 --- /dev/null +++ b/pan3d/xarray/vtk.py @@ -0,0 +1,535 @@ +from typing import List, Optional + +import xarray as xr +import numpy as np +import pandas as pd + +from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase +from vtkmodules.vtkCommonCore import vtkVariant +from vtkmodules.vtkCommonExecutionModel import vtkStreamingDemandDrivenPipeline +from vtkmodules.vtkCommonDataModel import vtkImageData, vtkDataObject +from vtkmodules.vtkFiltersCore import vtkArrayCalculator +from vtkmodules.util import numpy_support, xarray_support + +VTK_DATASETS = {"vtkImageData": vtkImageData} + +# ----------------------------------------------------------------------------- +# Helper functions +# ----------------------------------------------------------------------------- + + +def is_time_array(xarray, name, values): + if values.dtype.type == np.datetime64 or 
values.dtype.type == np.timedelta64: + un = np.datetime_data(values.dtype) + # unit = ns and 1 base unit in a spep + if un[0] == "ns" and un[1] == 1 and name in xarray.coords.keys(): + return True + return False + + +# ----------------------------------------------------------------------------- + + +def attr_value_to_vtk(value): + if np.issubdtype(type(value), np.integer): + return vtkVariant(int(value)) + + if np.issubdtype(type(value), np.floating): + return vtkVariant(float(value)) + + if isinstance(value, np.ndarray): + return vtkVariant(numpy_support.numpy_to_vtk(value)) + + return vtkVariant(value) + + +# ----------------------------------------------------------------------------- + + +def get_time_labels(times): + return [pd.to_datetime(time).strftime("%Y-%m-%d %H:%M:%S") for time in times] + + +# ----------------------------------------------------------------------------- +# VTK Algorithms +# ----------------------------------------------------------------------------- + + +class vtkXArraySource(VTKPythonAlgorithmBase): + """vtk source for converting XArray into a VTK mesh""" + + def __init__( + self, + input: Optional[xr.Dataset] = None, + arrays: Optional[List[str]] = None, + ): + """ + Create vtkXArraySource + + Parameters: + input (xr.Dataset): Provide an XArray to use as input. The load() method will replace it. + arrays (list[str]): List of field to load onto the generated VTK mesh. 
+ """ + VTKPythonAlgorithmBase.__init__( + self, + nInputPorts=0, + nOutputPorts=1, + outputType="vtkDataObject", + ) + # Data source + self._input = input + self._pipeline = None + self._computed = {} + self._data_origin = None + + # Array name selectors + self._x = None + self._y = None + self._z = None + self._t = None + + # Data sub-selection + self._array_names = set(arrays or []) + self._t_index = 0 + self._slices = None + + # vtk internal vars + self._arrays = {} + self._accessor = None + self._reader = None + + # Create reader if xarray available + if self._input: + self._build_reader() + + # ------------------------------------------------------------------------- + # Reader Setup + # ------------------------------------------------------------------------- + + def _build_reader(self): + # reset + self._array_names.clear() + self._arrays = {} # to prevent garbage collection + self._x = None + self._y = None + self._z = None + self._t = None + + # vtk reader + self._accessor = xarray_support.vtkXArrayAccessor() + self._reader = xarray_support.vtkNetCDFCFReader(accessor=self._accessor) + + # XArray binding + xarray = self._input + accessor = self._accessor + reader = self._reader + time_names = [] + + # map dimensions + dim_keys = list(xarray.sizes.keys()) + dim_values = [xarray.sizes[k] for k in dim_keys] + dim_name_to_index = {k: i for i, k in enumerate(dim_keys)} + + accessor.SetDim(dim_keys) + accessor.SetDimLen(dim_values) + + # map variables + var_keys = [*xarray.data_vars.keys(), *xarray.coords.keys()] + var_is_coord = [0] * len(xarray.data_vars) + [1] * len(xarray.coords) + var_name_to_index = {k: i for i, k in enumerate(var_keys)} + + accessor.SetVar(var_keys, var_is_coord) + for i, k in enumerate(var_keys): + # need contiguous array (noop if true, otherwise copy) + v_data = np.ascontiguousarray(xarray[k].values) + + # Capture time array names + if is_time_array(xarray, k, v_data): + time_names.append(k) + + # Convert time data + if 
v_data.dtype.char == "O": + # object array, assume cftime (copy as double array) + v_data = xarray_support.ndarray_cftime_toordinal(v_data).astype( + np.float64 + ) + time_names.append(k) + + # save array + self._arrays[k] = v_data + + # register data to accessor + accessor.SetVarValue(i, v_data) + accessor.SetVarType(i, xarray_support.get_nc_type(v_data.dtype)) + accessor.SetVarDims(i, [dim_name_to_index[name] for name in xarray[k].dims]) + accessor.SetVarCoords( + i, [var_name_to_index[name] for name in xarray[k].coords] + ) + + # handle attributes + for attr_k, attr_v in xarray[k].attrs.items(): + accessor.SetAtt(i, attr_k, attr_value_to_vtk(attr_v)) + + # map time array name to reader + if len(time_names) >= 1: + for name in time_names: + if accessor.IsCOARDSCoordinate(name): + reader.SetTimeDimensionName(name) + break + + # Extract coordinate mapping + reader.UpdateInformation() + self._x = reader.GetLongitudeDimensionName() + self._y = reader.GetLatitudeDimensionName() + self._z = reader.GetVerticalDimensionName() + self._t = reader.GetTimeDimensionName() + + # ------------------------------------------------------------------------- + # Information + # ------------------------------------------------------------------------- + + def __str__(self): + return """VTK NetCDF/XArray reader""" + + # ------------------------------------------------------------------------- + # Data input + # ------------------------------------------------------------------------- + + @property + def input(self): + """return current input XArray""" + return self._input + + @input.setter + def input(self, xarray_dataset: xr.Dataset): + """update input with a new XArray""" + self._input = xarray_dataset + self._build_reader() + self.Modified() + + # ------------------------------------------------------------------------- + # Array selectors + # ------------------------------------------------------------------------- + + @property + def x(self): + """return the name that is currently 
mapped to the X axis""" + return self._x + + @property + def x_size(self): + """return the size of the coordinate used for the X axis""" + if self._x is None: + return 0 + return int(self._input[self._x].size) + + @property + def y(self): + """return the name that is currently mapped to the Y axis""" + return self._y + + @property + def y_size(self): + """return the size of the coordinate used for the Y axis""" + if self._y is None: + return 0 + return int(self._input[self._y].size) + + @property + def z(self): + """return the name that is currently mapped to the Z axis""" + return self._z + + @property + def z_size(self): + """return the size of the coordinate used for the Z axis""" + if self._z is None: + return 0 + return int(self._input[self._z].size) + + @property + def t(self): + """return the name that is currently mapped to the time axis""" + return self._t + + @property + def slice_extents(self): + """return a dictionary for the X, Y, Z dimensions with the corresponding extent [0, size-1]""" + return { + coord_name: [0, self.input[coord_name].size - 1] + for coord_name in [self.x, self.y, self.z] + if coord_name is not None + } + + @property + def available_coords(self): + """List available coordinates arrays that have are 1D""" + if self._input is None: + return [] + + return [k for k, v in self._input.coords.items() if len(v.shape) == 1] + + # ------------------------------------------------------------------------- + # Data sub-selection + # ------------------------------------------------------------------------- + + @property + def t_index(self): + """return the current selected time index""" + return self._t_index + + @t_index.setter + def t_index(self, t_index: int): + """update the current selected time index""" + if t_index != self._t_index: + self._t_index = t_index + self.Modified() + + @property + def t_size(self): + """return the size of the coordinate used for the time""" + if self._t is None: + return 0 + return 
int(self._input[self._t].size) + + @property + def t_labels(self): + """return a list of string that match the various time values available""" + if self._t is None: + return [] + + t_array = self._input[self._t] + t_type = t_array.dtype + if np.issubdtype(t_type, np.datetime64): + return get_time_labels(t_array.values) + return [str(t) for t in t_array.values] + + @property + def arrays(self): + """return the list of arrays that are currently selected to be added to the generated VTK mesh""" + return list(self._array_names) + + @arrays.setter + def arrays(self, array_names: List[str]): + """update the list of arrays to load on the generated VTK mesh""" + new_names = set(array_names or []) + if new_names != self._array_names: + self._array_names = new_names + + for name in self.available_arrays: + active = 1 if name in self._array_names else 0 + self._reader.SetVariableArrayStatus(name, active) + + self.Modified() + + @property + def available_arrays(self): + """List all available data fields for the `arrays` option""" + if self._input is None or self._reader is None: + return [] + + vtk_str_array = self._reader.GetAllVariableArrayNames() + return [ + vtk_str_array.GetValue(i) for i in range(vtk_str_array.GetNumberOfValues()) + ] + + @property + def slices(self): + """return the current slicing information which include axes crop/cut and time selection""" + result = dict(self._slices or {}) + if self.t is not None: + result[self.t] = self.t_index + return result + + @slices.setter + def slices(self, v): + """update the slicing of the data along axes""" + if v != self._slices: + self._slices = v + # FIXME !!! 
update accessor + # Ask Dan + # self.Modified() + # raise NotImplementedError() + print("set slices not implemented", v) + if "time" in v: + self.t_index = v.get("time", 0) + + # ------------------------------------------------------------------------- + # add-on logic + # ------------------------------------------------------------------------- + + @property + def computed(self): + """return the current description of the computed/derived fields on the VTK mesh""" + return self._computed + + @computed.setter + def computed(self, v): + """ + update the computed/derived fields to add on the VTK mesh + + The layout of the dictionary provided should be as follow: + - key: name of the field to be added + - value: formula to apply for the given field name. The syntax is captured in the document (https://docs.paraview.org/en/latest/UsersGuide/filteringData.html#calculator) + + Then additional keys need to be provided to describe your formula dependencies: + `_use_scalars` and `_use_vectors` which should be a list of string matching the name of the fields you are using in your expression. 
+ + + Please find below an example: + + ``` + { + "_use_scalars": ["u", "v"], # (u,v) needed for "vec" and "m2" + "vec": "(u * iHat) + (v * jHat)", # 2D vector + "m2": "u*u + v*v", + } + ``` + """ + if self._computed != v: + self._computed = v or {} + self._pipeline = None + scalar_arrays = self._computed.get("_use_scalars", []) + vector_arrays = self._computed.get("_use_vectors", []) + + for output_name, func in self._computed.items(): + if output_name[0] == "_": + continue + filter = vtkArrayCalculator( + result_array_name=output_name, + function=func, + ) + + # register array dependencies + for scalar_array in scalar_arrays: + filter.AddScalarArrayName(scalar_array) + for vector_array in vector_arrays: + filter.AddVectorArrayName(vector_array) + + if self._pipeline is None: + self._pipeline = filter + else: + self._pipeline = self._pipeline >> filter + + self.Modified() + + def load(self, data_info): + """ + create a new XArray input with the `data_origin` and `dataset_config` information. + + Here is an example of the layout of the parameter + + ``` + { + "data_origin": { + "source": "url", # one of [file, url, xarray, pangeo, esgf] + "id": "https://ncsa.osn.xsede.org/Pangeo/pangeo-forge/noaa-coastwatch-geopolar-sst-feedstock/noaa-coastwatch-geopolar-sst.zarr", + "order": "C" # (optional) order to use in numpy + }, + "dataset_config": { + "x": "lon", # (optional) coord name for X + "y": "lat", # (optional) coord name for Y + "z": null, # (optional) coord name for Z + "t": "time", # (optional) coord name for time + "slices": { # (optional) array slicing + "lon": [ + 1000, + 6000, + 20 + ], + "lat": [ + 500, + 3000, + 20 + ], + "time": 5 + }, + "t_index": 5, # (optional) selected time index + "arrays": [ # (optional) names of arrays to load onto VTK mesh. + "analysed_sst" # If missing no array will be loaded + ] # onto the mesh. 
+ } + } + ``` + """ + if "data_origin" not in data_info: + raise ValueError("Only state with data_origin can be loaded") + + from pan3d import catalogs + + self._data_origin = data_info["data_origin"] + self.input = catalogs.load_dataset( + self._data_origin["source"], self._data_origin["id"] + ) + self._build_reader() + + dataset_config = data_info.get("dataset_config") + if dataset_config is None: + self.arrays = self.available_arrays + else: + # self.slices = dataset_config.get("slices") # FIXME: not implemented yet + self.t_index = dataset_config.get("t_index", 0) + self.arrays = dataset_config.get("arrays", self.available_arrays) + + @property + def state(self): + """return current state that can be reused in a load() later on""" + if self._data_origin is None: + raise RuntimeError( + "No state available without data origin. Need to use the load method to set the data origin." + ) + + return { + "data_origin": self._data_origin, + "dataset_config": { + k: getattr(self, k) + for k in ["x", "y", "z", "t", "slices", "t_index", "arrays"] + }, + } + + # ------------------------------------------------------------------------- + # Algorithm + # ------------------------------------------------------------------------- + + def RequestDataObject(self, request, inInfo, outInfo): + output = vtkImageData() + if self._reader: + self._reader.UpdateDataObject() + output = self._reader.GetOutputDataObject(0).NewInstance() + + outInfo.GetInformationObject(0).Set(vtkDataObject.DATA_OBJECT(), output) + return 1 + + def RequestInformation(self, request, inInfo, outInfo): + if self._reader: + self._reader.UpdateInformation() + info = self._reader.GetOutputInformation(0) + whole_extent = info.Get(vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT()) + outInfo.GetInformationObject(0).Set( + vtkStreamingDemandDrivenPipeline.WHOLE_EXTENT(), *whole_extent + ) + + return 1 + + def RequestData(self, request, inInfo, outInfo): + """implementation of the vtk algorithm for generating the VTK 
mesh""" + # Use open data_array handle to fetch data at + # desired Level of Detail + if self._reader is not None: + pdo = self.GetOutputData(outInfo, 0) + + if self.t_size: + t = self._arrays[self.t][self.t_index] + self._reader.UpdateTimeStep(t) + + mesh = self._reader() + + # Compute derived quantity + if self._pipeline is not None: + mesh = self._pipeline(mesh) + pdo.ShallowCopy(mesh) + else: + pdo.ShallowCopy(mesh) + + return 1 + return 0