diff --git a/src/mdio/__init__.py b/src/mdio/__init__.py
index 5fed389c8..59b834c4a 100644
--- a/src/mdio/__init__.py
+++ b/src/mdio/__init__.py
@@ -2,16 +2,17 @@
 
 from importlib import metadata
 
-from mdio.api.io import open_mdio
-from mdio.api.io import to_mdio
-from mdio.converters import mdio_to_segy
-from mdio.converters import segy_to_mdio
-
 try:
     __version__ = metadata.version("multidimio")
 except metadata.PackageNotFoundError:
     __version__ = "unknown"
 
+from mdio.api.create import create_empty
+from mdio.api.create import create_empty_like
+from mdio.api.io import open_mdio
+from mdio.api.io import to_mdio
+from mdio.converters.mdio import mdio_to_segy
+from mdio.converters.segy import segy_to_mdio
 
 __all__ = [
     "__version__",
@@ -19,4 +20,6 @@
     "to_mdio",
     "mdio_to_segy",
     "segy_to_mdio",
+    "create_empty",
+    "create_empty_like",
 ]
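Reviewer note: the re-export surface after this hunk can be smoke-checked with a couple of lines. This is a sketch assuming the package is installed (the distribution name is `multidimio`, the import name `mdio`):

```python
# Smoke check of the names re-exported by the new __init__.py.
import mdio

assert callable(mdio.create_empty)
assert callable(mdio.create_empty_like)
print(mdio.__version__)  # "unknown" when the 'multidimio' dist metadata is absent
```

The import reordering also means `__version__` is resolved before the API modules are imported.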
diff --git a/src/mdio/api/create.py b/src/mdio/api/create.py
new file mode 100644
index 000000000..a3f7342dc
--- /dev/null
+++ b/src/mdio/api/create.py
@@ -0,0 +1,168 @@
+"""Creating MDIO v1 datasets."""
+
+from __future__ import annotations
+
+from datetime import UTC
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+from mdio.api.io import _normalize_path
+from mdio.api.io import open_mdio
+from mdio.api.io import to_mdio
+from mdio.builder.template_registry import TemplateRegistry
+from mdio.builder.xarray_builder import to_xarray_dataset
+from mdio.converters.segy import populate_dim_coordinates
+from mdio.converters.type_converter import to_structured_type
+from mdio.core.grid import Grid
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from segy.schema import HeaderSpec
+    from upath import UPath
+    from xarray import Dataset as xr_Dataset
+
+    from mdio.builder.schemas import Dataset
+    from mdio.builder.templates.base import AbstractDatasetTemplate
+    from mdio.core.dimension import Dimension
+
+
+def create_empty(
+    mdio_template: AbstractDatasetTemplate | str,
+    dimensions: list[Dimension],
+    output_path: UPath | Path | str | None,
+    headers: HeaderSpec | None = None,
+    overwrite: bool = False,
+) -> xr_Dataset:
+    """Create an empty MDIO v1 dataset with known dimensions.
+
+    Args:
+        mdio_template: The MDIO template or template name to use to define the dataset structure.
+        dimensions: The dimensions of the MDIO file.
+        output_path: The universal path for the output MDIO v1 file.
+            If None, the dataset is not written to disk.
+        headers: The SEG-Y trace headers that are important to the Dataset. Defaults to None.
+        overwrite: Whether to overwrite the output file if it already exists. Defaults to False.
+
+    Returns:
+        The output MDIO dataset.
+
+    Raises:
+        FileExistsError: If the output location already exists and overwrite is False.
+    """
+    output_path = _normalize_path(output_path) if output_path is not None else None
+
+    if not overwrite and output_path is not None and output_path.exists():
+        err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
+        raise FileExistsError(err)
+
+    header_dtype = to_structured_type(headers.dtype) if headers else None
+    grid = Grid(dims=dimensions)
+    if isinstance(mdio_template, str):
+        # A template name was passed in; get a unit-unaware template from the registry
+        mdio_template = TemplateRegistry().get(mdio_template)
+    # Build the dataset using the template
+    mdio_ds: Dataset = mdio_template.build_dataset(name=mdio_template.name, sizes=grid.shape, header_dtype=header_dtype)
+
+    # Convert to an xarray dataset
+    xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds)
+
+    # Populate coordinates using the grid.
+    # For empty datasets, only the dimension coordinates are populated.
+    drop_vars_delayed = []
+    xr_dataset, drop_vars_delayed = populate_dim_coordinates(xr_dataset, grid, drop_vars_delayed=drop_vars_delayed)
+
+    if headers:
+        # Since headers were provided, the user wants to export to SEG-Y later.
+        # Add a dummy segy_file_header variable used for that export.
+        xr_dataset["segy_file_header"] = ((), "")
+
+    # Create the Zarr store with the correct structure but with empty arrays
+    if output_path is not None:
+        to_mdio(xr_dataset, output_path=output_path, mode="w", compute=False)
+
+    # Write the dimension coordinates and the trace mask
+    xr_dataset = xr_dataset[drop_vars_delayed + ["trace_mask"]]
+
+    if output_path is not None:
+        to_mdio(xr_dataset, output_path=output_path, mode="r+", compute=True)
+
+    return xr_dataset
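For review context, a hedged usage sketch of `create_empty` as defined above. The registered template name "PostStack3DTime" is taken from the tests further down; the dimension ranges and the output path are illustrative assumptions, not values fixed by this diff:

```python
from mdio import create_empty
from mdio.core import Dimension

# Assumed example grid: a small post-stack 3D time cube.
dims = [
    Dimension(name="inline", coords=range(1, 11)),
    Dimension(name="crossline", coords=range(1, 21)),
    Dimension(name="time", coords=range(0, 1000, 4)),
]

# Passing a template name resolves it via TemplateRegistry; headers=None means
# no "headers"/"segy_file_header" variables are created.
ds = create_empty(
    mdio_template="PostStack3DTime",
    dimensions=dims,
    output_path="out/empty.mdio",  # illustrative path; None skips writing to disk
    overwrite=True,
)
print(ds.sizes)  # expected: {'inline': 10, 'crossline': 20, 'time': 250}
```

Note the function writes the store in two passes: first the full structure with `compute=False`, then only the dimension coordinates and `trace_mask` with `compute=True`.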
+
+
+def create_empty_like(
+    input_path: UPath | Path | str,
+    output_path: UPath | Path | str | None,
+    keep_coordinates: bool = False,
+    overwrite: bool = False,
+) -> xr_Dataset:
+    """Create an empty MDIO v1 file with the same structure as an existing one.
+
+    Args:
+        input_path: The path of the input MDIO file.
+        output_path: The path of the output MDIO file.
+            If None, the output will not be written to disk.
+        keep_coordinates: Whether to keep the coordinates in the output file.
+        overwrite: Whether to overwrite the output file if it exists.
+
+    Returns:
+        The output MDIO dataset.
+
+    Raises:
+        FileExistsError: If the output location already exists and overwrite is False.
+    """
+    input_path = _normalize_path(input_path)
+    output_path = _normalize_path(output_path) if output_path is not None else None
+
+    if not overwrite and output_path is not None and output_path.exists():
+        err = f"Output location '{output_path.as_posix()}' exists. Set `overwrite=True` if intended."
+        raise FileExistsError(err)
+
+    ds = open_mdio(input_path)
+
+    # Create a copy with the same structure but without the data and,
+    # optionally, without the coordinates
+    ds_output = ds.copy(data=None).reset_coords(drop=not keep_coordinates)
+
+    # Dataset metadata
+    # Keep the name (which matches the name of the template used) and the original API version:
+    # ds_output.attrs["name"]
+    # ds_output.attrs["apiVersion"]
+    ds_output.attrs["createdOn"] = str(datetime.now(UTC))
+
+    # Coordinates
+    if not keep_coordinates:
+        for coord_name in ds_output.coords:
+            ds_output[coord_name].attrs.pop("unitsV1", None)
+
+    # MDIO attributes
+    attr = ds_output.attrs["attributes"]
+    if attr is not None:
+        attr.pop("gridOverrides", None)  # An empty dataset should not have gridOverrides
+        # Keep the original values for the following attributes:
+        # attr["defaultVariableName"]
+        # attr["surveyType"]
+        # attr["gatherType"]
+
+    # All traces should be marked as dead in an empty dataset
+    if "trace_mask" in ds_output.variables:
+        ds_output["trace_mask"][:] = False
+
+    # Data variable
+    var_name = attr["defaultVariableName"]
+    var = ds_output[var_name]
+    var.attrs.pop("statsV1", None)
+    if not keep_coordinates:
+        var.attrs.pop("unitsV1", None)
+
+    # SEG-Y file header
+    if "segy_file_header" in ds_output.variables:
+        segy_file_header = ds_output["segy_file_header"]
+        if segy_file_header is not None:
+            segy_file_header.attrs.pop("textHeader", None)
+            segy_file_header.attrs.pop("binaryHeader", None)
+            segy_file_header.attrs.pop("rawBinaryHeader", None)
+
+    if output_path is not None:
+        to_mdio(ds_output, output_path=output_path, mode="w", compute=True)
+
+    return ds_output
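A matching sketch for `create_empty_like`. Both paths are hypothetical, and the final assertion assumes the input dataset carries a `trace_mask` variable (which the function resets to all False):

```python
from mdio import create_empty_like

# Clone the structure of an existing MDIO v1 store: same dims, variables, and
# attributes, but no stats, no gridOverrides, and (optionally) no coordinates.
empty_ds = create_empty_like(
    input_path="surveys/teapot.mdio",         # hypothetical existing store
    output_path="surveys/teapot_empty.mdio",  # hypothetical output
    keep_coordinates=False,  # also strips unitsV1 from the coordinates
    overwrite=True,
)
assert not empty_ds["trace_mask"].values.any()  # all traces marked dead
```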
+""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING + +import numpy as np +import pytest +from segy.standards import get_segy_standard + +from mdio.builder.schemas.v1.units import LengthUnitEnum +from mdio.builder.schemas.v1.units import LengthUnitModel +from mdio.builder.schemas.v1.units import SpeedUnitEnum +from mdio.builder.schemas.v1.units import SpeedUnitModel +from mdio.builder.schemas.v1.units import TimeUnitEnum +from mdio.builder.schemas.v1.units import TimeUnitModel + +if TYPE_CHECKING: + from pathlib import Path + + from xarray import Dataset as xr_Dataset + + +from tests.integration.testing_helpers import UNITS_METER +from tests.integration.testing_helpers import UNITS_METERS_PER_SECOND +from tests.integration.testing_helpers import UNITS_MILLISECOND +from tests.integration.testing_helpers import UNITS_NONE +from tests.integration.testing_helpers import get_teapot_segy_spec +from tests.integration.testing_helpers import get_values +from tests.integration.testing_helpers import validate_xr_variable + +from mdio import __version__ +from mdio.api.create import create_empty +from mdio.api.create import create_empty_like +from mdio.api.io import open_mdio +from mdio.api.io import to_mdio +from mdio.builder.schemas.v1.stats import CenteredBinHistogram +from mdio.builder.schemas.v1.stats import SummaryStatistics +from mdio.builder.template_registry import TemplateRegistry +from mdio.builder.templates.seismic_3d_poststack import Seismic3DPostStackTemplate +from mdio.converters.mdio import mdio_to_segy +from mdio.converters.segy import segy_to_mdio +from mdio.core import Dimension + + +class PostStack3DVelocityMetricTemplate(Seismic3DPostStackTemplate): + """Custom template that uses 'velocity' as the default variable name instead of 'amplitude'.""" + + @property + def _default_variable_name(self) -> str: + """Override the default variable name.""" + return "velocity" + + def __init__(self, data_domain: str) -> None: + super().__init__(data_domain) + self._units.update( + { + "time": UNITS_MILLISECOND, + "cdp_x": UNITS_METER, + "cdp_y": UNITS_METER, + "velocity": UNITS_METERS_PER_SECOND, + } + ) + + @property + def _name(self) -> str: + """Override the name of the template.""" + domain_suffix = self._data_domain.capitalize() + return f"PostStack3DVelocity{domain_suffix}" + + +class TestCreateEmptyMdio: + """Tests for create_empty_mdio function.""" + + @classmethod + def _create_empty_mdio(cls, create_headers: bool, output_path: Path, overwrite: bool = True) -> xr_Dataset: + """Create a temporary empty MDIO file for testing.""" + # Create the grid with the specified dimensions + dims = [ + Dimension(name="inline", coords=range(1, 346, 1)), + Dimension(name="crossline", coords=range(1, 189, 1)), + Dimension(name="time", coords=range(0, 3002, 2)), + ] + + # If later on, we want to export to SEG-Y, we need to provide the trace header spec. + # The HeaderSpec can be either standard or customized. 
diff --git a/tests/integration/test_z_create_empty.py b/tests/integration/test_z_create_empty.py
new file mode 100644
index 000000000..c788de446
--- /dev/null
+++ b/tests/integration/test_z_create_empty.py
@@ -0,0 +1,417 @@
+"""Tests for the create_empty and create_empty_like functions.
+
+This set of tests has to run after the segy_roundtrip_teapot tests have run because
+the teapot dataset is used as the input for the create_empty_like test.
+
+NOTE: The only reliable way to ensure the test order (including the case when the
+tests are run in parallel) is to use the alphabetical order of the test names.
+"""
+
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+from segy.standards import get_segy_standard
+
+from mdio.builder.schemas.v1.units import LengthUnitEnum
+from mdio.builder.schemas.v1.units import LengthUnitModel
+from mdio.builder.schemas.v1.units import SpeedUnitEnum
+from mdio.builder.schemas.v1.units import SpeedUnitModel
+from mdio.builder.schemas.v1.units import TimeUnitEnum
+from mdio.builder.schemas.v1.units import TimeUnitModel
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from xarray import Dataset as xr_Dataset
+
+
+from tests.integration.testing_helpers import UNITS_METER
+from tests.integration.testing_helpers import UNITS_METERS_PER_SECOND
+from tests.integration.testing_helpers import UNITS_MILLISECOND
+from tests.integration.testing_helpers import UNITS_NONE
+from tests.integration.testing_helpers import get_teapot_segy_spec
+from tests.integration.testing_helpers import get_values
+from tests.integration.testing_helpers import validate_xr_variable
+
+from mdio import __version__
+from mdio.api.create import create_empty
+from mdio.api.create import create_empty_like
+from mdio.api.io import open_mdio
+from mdio.api.io import to_mdio
+from mdio.builder.schemas.v1.stats import CenteredBinHistogram
+from mdio.builder.schemas.v1.stats import SummaryStatistics
+from mdio.builder.template_registry import TemplateRegistry
+from mdio.builder.templates.seismic_3d_poststack import Seismic3DPostStackTemplate
+from mdio.converters.mdio import mdio_to_segy
+from mdio.converters.segy import segy_to_mdio
+from mdio.core import Dimension
+
+
+class PostStack3DVelocityMetricTemplate(Seismic3DPostStackTemplate):
+    """Custom template that uses 'velocity' as the default variable name instead of 'amplitude'."""
+
+    @property
+    def _default_variable_name(self) -> str:
+        """Override the default variable name."""
+        return "velocity"
+
+    def __init__(self, data_domain: str) -> None:
+        super().__init__(data_domain)
+        self._units.update(
+            {
+                "time": UNITS_MILLISECOND,
+                "cdp_x": UNITS_METER,
+                "cdp_y": UNITS_METER,
+                "velocity": UNITS_METERS_PER_SECOND,
+            }
+        )
+
+    @property
+    def _name(self) -> str:
+        """Override the name of the template."""
+        domain_suffix = self._data_domain.capitalize()
+        return f"PostStack3DVelocity{domain_suffix}"
+
+
+class TestCreateEmptyMdio:
+    """Tests for the create_empty function."""
+
+    @classmethod
+    def _create_empty_mdio(cls, create_headers: bool, output_path: Path, overwrite: bool = True) -> xr_Dataset:
+        """Create a temporary empty MDIO file for testing."""
+        # Create the grid with the specified dimensions
+        dims = [
+            Dimension(name="inline", coords=range(1, 346, 1)),
+            Dimension(name="crossline", coords=range(1, 189, 1)),
+            Dimension(name="time", coords=range(0, 3002, 2)),
+        ]
+
+        # If we later want to export to SEG-Y, we need to provide the trace header spec.
+        # The HeaderSpec can be either standard or customized.
+        headers = get_teapot_segy_spec().trace.header if create_headers else None
+        # Create an empty MDIO v1 metric post-stack 3D time velocity dataset
+        return create_empty(
+            mdio_template=PostStack3DVelocityMetricTemplate(data_domain="time"),
+            dimensions=dims,
+            output_path=output_path,
+            headers=headers,
+            overwrite=overwrite,
+        )
+
+    @classmethod
+    def validate_teapot_dataset_metadata(cls, ds: xr_Dataset, is_velocity: bool) -> None:
+        """Validate the dataset metadata."""
+        if is_velocity:
+            assert ds.name == "PostStack3DVelocityTime"
+        else:
+            assert ds.name == "PostStack3DTime"
+
+        # Check basic metadata attributes
+        expected_attrs = {
+            "apiVersion": __version__,
+            "name": ds.name,
+        }
+        actual_attrs_json = ds.attrs
+
+        # 'createdOn' changes on every run, so only check that it exists;
+        # compare the remaining attributes by value
+        for key, value in expected_attrs.items():
+            assert key in actual_attrs_json
+            assert actual_attrs_json[key] == value
+        assert "createdOn" in actual_attrs_json
+
+        # Validate template attributes
+        attributes = ds.attrs["attributes"]
+        assert attributes is not None
+        assert len(attributes) == 3
+        # Validate all attributes provided by the abstract template
+        if is_velocity:
+            assert attributes["defaultVariableName"] == "velocity"
+        else:
+            assert attributes["defaultVariableName"] == "amplitude"
+        assert attributes["surveyType"] == "3D"
+        assert attributes["gatherType"] == "stacked"
+
+    @classmethod
+    def validate_teapot_dataset_variables(
+        cls, ds: xr_Dataset, header_dtype: np.dtype | None, is_velocity: bool
+    ) -> None:
+        """Validate an empty MDIO dataset structure and content."""
+        # Check that the dataset has the expected shape
+        assert ds.sizes == {"inline": 345, "crossline": 188, "time": 1501}
+
+        # Validate the dimension coordinate variables
+        validate_xr_variable(ds, "inline", {"inline": 345}, UNITS_NONE, np.int32, False, range(1, 346), get_values)
+        validate_xr_variable(
+            ds, "crossline", {"crossline": 188}, UNITS_NONE, np.int32, False, range(1, 189), get_values
+        )
+        validate_xr_variable(
+            ds, "time", {"time": 1501}, UNITS_MILLISECOND, np.int32, False, range(0, 3002, 2), get_values
+        )
+
+        # Validate the non-dimensional coordinate variables (should be empty for an empty dataset)
+        validate_xr_variable(ds, "cdp_x", {"inline": 345, "crossline": 188}, UNITS_METER, np.float64)
+        validate_xr_variable(ds, "cdp_y", {"inline": 345, "crossline": 188}, UNITS_METER, np.float64)
+
+        if header_dtype is not None:
+            # Validate the headers (should be empty for an empty dataset).
+            # Infer the dtype from segy_spec and ignore endianness.
+            header_dtype = header_dtype.newbyteorder("native")
+            validate_xr_variable(ds, "headers", {"inline": 345, "crossline": 188}, UNITS_NONE, header_dtype)
+            # The "segy_file_header" is optional
+            if "segy_file_header" in ds.variables:
+                validate_xr_variable(ds, "segy_file_header", dims={}, units=UNITS_NONE, data_type=np.dtype("U1"))
+        else:
+            assert "headers" not in ds.variables
+            assert "segy_file_header" not in ds.variables
+
+        # Validate the trace mask (should be all False for an empty dataset)
+        validate_xr_variable(ds, "trace_mask", {"inline": 345, "crossline": 188}, UNITS_NONE, np.bool_)
+        trace_mask = ds["trace_mask"].values
+        assert not np.any(trace_mask), "Expected all `False` values in `trace_mask` but found `True`."
+
+        # Validate the velocity or amplitude data (should be empty)
+        if is_velocity:
+            validate_xr_variable(
+                ds, "velocity", {"inline": 345, "crossline": 188, "time": 1501}, UNITS_METERS_PER_SECOND, np.float32
+            )
+        else:
+            validate_xr_variable(
+                ds, "amplitude", {"inline": 345, "crossline": 188, "time": 1501}, UNITS_NONE, np.float32
+            )
+
+    @pytest.fixture(scope="class")
+    def mdio_with_headers(self, empty_mdio_dir: Path) -> Path:
+        """Create a temporary empty MDIO file for testing.
+
+        This fixture is scoped to the class level, so it will be executed only once
+        and shared across all test methods in the class.
+        """
+        empty_mdio: Path = empty_mdio_dir / "mdio_with_headers.mdio"
+        xr_dataset = self._create_empty_mdio(create_headers=True, output_path=empty_mdio)
+        assert xr_dataset is not None
+        return empty_mdio
+
+    @pytest.fixture(scope="class")
+    def mdio_no_headers(self, empty_mdio_dir: Path) -> Path:
+        """Create a temporary empty MDIO file for testing.
+
+        This fixture is scoped to the class level, so it will be executed only once
+        and shared across all test methods in the class.
+        """
+        empty_mdio: Path = empty_mdio_dir / "mdio_no_headers.mdio"
+        xr_dataset = self._create_empty_mdio(create_headers=False, output_path=empty_mdio)
+        assert xr_dataset is not None
+        return empty_mdio
+
+    def test_dataset_metadata(self, mdio_with_headers: Path) -> None:
+        """Test dataset metadata for an empty MDIO file."""
+        ds = open_mdio(mdio_with_headers)
+        self.validate_teapot_dataset_metadata(ds, is_velocity=True)
+
+    def test_variables(self, mdio_with_headers: Path, mdio_no_headers: Path) -> None:
+        """Test variable validation for empty MDIO files, with and without headers."""
+        ds = open_mdio(mdio_with_headers)
+        header_dtype = get_teapot_segy_spec().trace.header.dtype
+        self.validate_teapot_dataset_variables(ds, header_dtype=header_dtype, is_velocity=True)
+
+        ds = open_mdio(mdio_no_headers)
+        self.validate_teapot_dataset_variables(ds, header_dtype=None, is_velocity=True)
+
+    def test_overwrite_behavior(self, empty_mdio_dir: Path) -> None:
+        """Test the overwrite parameter behavior in create_empty."""
+        empty_mdio = empty_mdio_dir / "empty.mdio"
+        empty_mdio.mkdir(parents=True, exist_ok=True)
+        garbage_file = empty_mdio / "garbage.txt"
+        garbage_file.write_text("This is garbage data that should be overwritten")
+        garbage_dir = empty_mdio / "garbage_dir"
+        garbage_dir.mkdir(exist_ok=True)
+        (garbage_dir / "nested_garbage.txt").write_text("More garbage")
+
+        # Verify the directory exists with garbage data
+        assert empty_mdio.exists()
+        assert garbage_file.exists()
+        assert garbage_dir.exists()
+
+        # Creating the MDIO with overwrite=False should raise FileExistsError
+        with pytest.raises(FileExistsError, match="Output location.*exists"):
+            self._create_empty_mdio(create_headers=True, output_path=empty_mdio, overwrite=False)
+
+        # Creating the MDIO with overwrite=True should succeed and overwrite the garbage
+        xr_dataset = self._create_empty_mdio(create_headers=True, output_path=empty_mdio, overwrite=True)
+        assert xr_dataset is not None
+
+        # Validate that the MDIO file can be loaded correctly using the helper function
+        ds = open_mdio(empty_mdio)
+        self.validate_teapot_dataset_metadata(ds, is_velocity=True)
+        header_dtype = get_teapot_segy_spec().trace.header.dtype
+        self.validate_teapot_dataset_variables(ds, header_dtype=header_dtype, is_velocity=True)
+
+        # Verify the garbage data was overwritten (should not exist)
+        assert not garbage_file.exists(), "Garbage file should have been overwritten"
+        assert not garbage_dir.exists(), "Garbage directory should have been overwritten"
+ ds.time.attrs["unitsV1"] = TimeUnitModel(time=TimeUnitEnum.MILLISECOND).model_dump(mode="json") + + ds.cdp_x.attrs["unitsV1"] = LengthUnitModel(length=LengthUnitEnum.FOOT).model_dump(mode="json") + ds.cdp_x.attrs["unitsV1"] = LengthUnitModel(length=LengthUnitEnum.FOOT).model_dump(mode="json") + + velocity.attrs["unitsV1"] = SpeedUnitModel(speed=SpeedUnitEnum.FEET_PER_SECOND).model_dump(mode="json") + + # 6) Populate dataset's segy trace headers, if those were created (required only if we want to export to SEG-Y) + if "headers" in ds.variables: + # Both the structured "headers" and the dummy "segy_file_header" variables are + # required to enable SEG-Y to MDIO conversion + + # Populate the structured trace "headers" variable + ds["headers"].values["inline"] = inline_grid + ds["headers"].values["crossline"] = xline_grid + # coordinate_scalar: + # Scalar to be applied to all coordinates specified in Standard Trace Header bytes + # 73–88 and to bytes Trace Header 181–188 to give the real value. Scalar = 1, + # ±10, ±100, ±1000, or ±10,000. If positive, scalar is used as a multiplier; if + # negative, scalar is used as divisor. A value of zero is assumed to be a scalar + # value of 1. + ds["headers"].values["coordinate_scalar"][:] = np.int16(-100) + ds["headers"].values["cdp_x"][:] = np.int32(ds.cdp_x * 100) + ds["headers"].values["cdp_y"][:] = np.int32(ds.cdp_y * 100) + + # Fill its metadata (.attrs) with 'textHeader' and 'binaryHeader'. + ds["segy_file_header"].attrs.update( + { + "textHeader": "\n".join( + [ + "C01 BYTES 13-16: CROSSLINE " + " " * 47, + "C02 BYTES 17-20: INLINE " + " " * 47, + "C03 BYTES 71-74: COORDINATE SCALAR " + " " * 47, + "C04 BYTES 181-184: CDP X " + " " * 47, + "C05 BYTES 185-188: CDP Y " + " " * 47, + *(f"C{i:02d}" + " " * 77 for i in range(6, 41)), + ] + ), + "binaryHeader": { + "data_sample_format": 1, + "sample_interval": int(ds.time[1] - ds.time[0]), + "samples_per_trace": ds.time.size, + "segy_revision_major": 0, + "segy_revision_minor": 0, + }, + } + ) + + # 7) Create dataset's custom attributes (optional) + ds.attrs["attributes"]["createdBy"] = "John Doe" + + # 8) Export to MDIO + output_path_mdio = mdio_with_headers.parent / "populated_empty.mdio" + to_mdio(ds, output_path=output_path_mdio, mode="w", compute=True) + + # 9) Convert the populated empty MDIO to SEG-Y + if "headers" in ds.variables: + # Select the SEG-Y standard to use for the conversion + custom_segy_spec = get_segy_standard(1.0) + # Customize to use the same HeaderSpec that was used to create the empty MDIO + custom_segy_spec.trace.header = get_teapot_segy_spec().trace.header + # Convert the MDIO file to SEG-Y + mdio_to_segy( + segy_spec=custom_segy_spec, + input_path=output_path_mdio, + output_path=mdio_with_headers.parent / "populated_empty.sgy", + ) + + def test_create_empty_like(self, segy_input: Path, empty_mdio_dir: Path) -> None: + """Create an empty MDIO file like the input MDIO file. + + This test has to run after the segy_roundtrip_teapot tests have run because + its uses 'teapot_mdio_tmp' created by the segy_roundtrip_teapot tests as the input. 
+ """ + mdio_input_tmp = empty_mdio_dir / "create_empty_like_input.mdio" + mdio_output_tmp = empty_mdio_dir / "create_empty_like_output.mdio" + + mdio_unit_aware_template = TemplateRegistry().get("PostStack3DTime") + mdio_unit_aware_template.add_units({"time": UNITS_MILLISECOND}) + mdio_unit_aware_template.add_units({"cdp_x": UNITS_METER}) + mdio_unit_aware_template.add_units({"cdp_y": UNITS_METER}) + + segy_to_mdio( + segy_spec=get_teapot_segy_spec(), + mdio_template=mdio_unit_aware_template, + input_path=segy_input, + output_path=mdio_input_tmp, + overwrite=True, + ) + + ds = create_empty_like( + input_path=mdio_input_tmp, + output_path=mdio_output_tmp, + keep_coordinates=True, + overwrite=True, + ) + assert ds is not None + + self.validate_teapot_dataset_metadata(ds, is_velocity=False) + header_dtype = get_teapot_segy_spec().trace.header.dtype + self.validate_teapot_dataset_variables(ds, header_dtype=header_dtype, is_velocity=False) diff --git a/tests/integration/testing_helpers.py b/tests/integration/testing_helpers.py index c871ba6db..df7233207 100644 --- a/tests/integration/testing_helpers.py +++ b/tests/integration/testing_helpers.py @@ -5,6 +5,37 @@ import numpy as np import xarray as xr from numpy.typing import DTypeLike +from segy.schema import HeaderField +from segy.schema import ScalarType +from segy.schema.segy import SegySpec +from segy.standards import get_segy_standard + +from mdio.builder.schemas.v1.units import AllUnitModel +from mdio.builder.schemas.v1.units import LengthUnitEnum +from mdio.builder.schemas.v1.units import LengthUnitModel +from mdio.builder.schemas.v1.units import SpeedUnitEnum +from mdio.builder.schemas.v1.units import SpeedUnitModel +from mdio.builder.schemas.v1.units import TimeUnitEnum +from mdio.builder.schemas.v1.units import TimeUnitModel + +UNITS_NONE = None +UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER) +UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND) +UNITS_MILLISECOND = TimeUnitModel(time=TimeUnitEnum.MILLISECOND) +UNITS_METERS_PER_SECOND = SpeedUnitModel(speed=SpeedUnitEnum.METERS_PER_SECOND) +UNITS_FOOT = LengthUnitModel(length=LengthUnitEnum.FOOT) +UNITS_FEET_PER_SECOND = SpeedUnitModel(speed=SpeedUnitEnum.FEET_PER_SECOND) + + +def get_teapot_segy_spec() -> SegySpec: + """Return the customized SEG-Y specification for the teapot dome dataset.""" + teapot_fields = [ + HeaderField(name="inline", byte=17, format=ScalarType.INT32), + HeaderField(name="crossline", byte=13, format=ScalarType.INT32), + HeaderField(name="cdp_x", byte=81, format=ScalarType.INT32), + HeaderField(name="cdp_y", byte=85, format=ScalarType.INT32), + ] + return get_segy_standard(1.0).customize(trace_header_fields=teapot_fields) def get_values(arr: xr.DataArray) -> np.ndarray: @@ -49,3 +80,51 @@ def validate_variable( # noqa PLR0913 if expected_values is not None and actual_value_generator is not None: actual_values = actual_value_generator(arr) assert np.array_equal(expected_values, actual_values) + + +def validate_xr_variable( # noqa PLR0913 + dataset: xr.Dataset, + name: str, + dims: dict[int], + units: AllUnitModel, + data_type: np.dtype, + has_stats: bool = False, + expected_values: range | None = None, + actual_value_generator: Callable[[xr.DataArray], np.ndarray] | None = None, +) -> None: + """Validate the properties of a variable in an Xarray dataset.""" + v = dataset[name] + assert v is not None + assert v.sizes == dims + if hasattr(data_type, "fields") and data_type.fields is not None: + # The following assertion will fail because of 
+def validate_xr_variable(  # noqa PLR0913
+    dataset: xr.Dataset,
+    name: str,
+    dims: dict[str, int],
+    units: AllUnitModel | None,
+    data_type: np.dtype,
+    has_stats: bool = False,
+    expected_values: range | None = None,
+    actual_value_generator: Callable[[xr.DataArray], np.ndarray] | None = None,
+) -> None:
+    """Validate the properties of a variable in an Xarray dataset."""
+    v = dataset[name]
+    assert v is not None
+    assert v.sizes == dims
+    if hasattr(data_type, "fields") and data_type.fields is not None:
+        # A direct `assert data_type == v.dtype` would fail because of differences
+        # in the field offsets, so compare the field names and types instead.
+
+        # Compare field names
+        expected_names = list(data_type.names)
+        actual_names = list(v.dtype.names)
+        assert expected_names == actual_names
+
+        # Compare field types
+        expected_types = [data_type[name] for name in data_type.names]
+        actual_types = [v.dtype[name] for name in v.dtype.names]
+        assert expected_types == actual_types
+    else:
+        assert data_type == v.dtype
+
+    stats = v.attrs.get("statsV1", None)
+    if has_stats:
+        assert stats is not None, "StatsV1 should not be empty for dataset variables with stats"
+    else:
+        assert stats is None, "StatsV1 should be empty for dataset variables without stats"
+
+    if units is not None:
+        units_v1 = v.attrs.get("unitsV1", None)
+        assert units_v1 is not None, "UnitsV1 should not be empty for dataset variables with units"
+        assert units_v1 == units.model_dump(mode="json")
+    else:
+        assert "unitsV1" not in v.attrs, "UnitsV1 should not exist for unit-unaware variables"
+
+    if expected_values is not None and actual_value_generator is not None:
+        actual_values = actual_value_generator(v)
+        assert np.array_equal(expected_values, actual_values)
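A standalone illustration of why `validate_xr_variable` compares structured dtypes field by field rather than asserting `data_type == v.dtype` directly. The dtypes below are made up for the example (not the actual teapot header layout):

```python
import numpy as np

# Two structured dtypes with identical fields; the second carries padding
# (explicit offsets and a larger itemsize), so direct equality fails.
packed = np.dtype([("inline", "i4"), ("crossline", "i4")])
padded = np.dtype(
    {"names": ["inline", "crossline"], "formats": ["i4", "i4"], "offsets": [0, 8], "itemsize": 12}
)

assert packed != padded  # offsets/itemsize differ
assert list(packed.names) == list(padded.names)  # field names still match
assert [packed[n] for n in packed.names] == [padded[n] for n in padded.names]  # field types match
```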