diff --git a/src/mdio/converters/segy.py b/src/mdio/converters/segy.py index 7914f410..1726a761 100644 --- a/src/mdio/converters/segy.py +++ b/src/mdio/converters/segy.py @@ -156,7 +156,11 @@ def _scan_for_headers( """Extract trace dimensions and index headers from the SEG-Y file. This is an expensive operation. - It scans the SEG-Y file in chunks by using ProcessPoolExecutor + It scans the SEG-Y file in chunks by using ProcessPoolExecutor. + + Note: + If grid_overrides are applied to the template before calling this function, + the chunk_size returned from get_grid_plan should match the template's chunk shape. """ full_chunk_shape = template.full_chunk_shape segy_dimensions, chunk_size, segy_headers = get_grid_plan( @@ -167,13 +171,26 @@ def _scan_for_headers( chunksize=full_chunk_shape, grid_overrides=grid_overrides, ) + + # Update template to match grid_plan results after grid overrides if full_chunk_shape != chunk_size: - # TODO(Dmitriy): implement grid overrides - # https://github.com/TGSAI/mdio-python/issues/585 - # The returned 'chunksize' is used only for grid_overrides. We will need to use it when full - # support for grid overrides is implemented - err = "Support for changing full_chunk_shape in grid overrides is not yet implemented" - raise NotImplementedError(err) + logger.debug( + "Adjusting template chunk shape from %s to %s to match grid after overrides", + full_chunk_shape, + chunk_size, + ) + template._var_chunk_shape = chunk_size + + # Update dimensions if they don't match grid_plan results + actual_spatial_dims = tuple(dim.name for dim in segy_dimensions[:-1]) + if template.spatial_dimension_names != actual_spatial_dims: + logger.debug( + "Adjusting template dimensions from %s to %s to match grid after overrides", + template.spatial_dimension_names, + actual_spatial_dims, + ) + template._dim_names = actual_spatial_dims + (template.trace_domain,) + return segy_dimensions, segy_headers diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py index ed41e42e..39ed6dca 100644 --- a/src/mdio/segy/geometry.py +++ b/src/mdio/segy/geometry.py @@ -24,6 +24,8 @@ from numpy.typing import NDArray from segy.arrays import HeaderArray + from mdio.builder.templates.base import AbstractDatasetTemplate + logger = logging.getLogger(__name__) @@ -267,7 +269,8 @@ def analyze_non_indexed_headers(index_headers: HeaderArray, dtype: DTypeLike = n header_names = [] for header_key in index_headers.dtype.names: if header_key != "trace": - unique_headers[header_key] = np.sort(np.unique(index_headers[header_key])) + unique_vals = np.sort(np.unique(index_headers[header_key])) + unique_headers[header_key] = unique_vals header_names.append(header_key) total_depth += 1 @@ -302,6 +305,7 @@ def transform( self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int], + template: AbstractDatasetTemplate, # noqa: ARG002 ) -> NDArray: """Perform the grid transform.""" @@ -378,11 +382,25 @@ def transform( self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int], + template: AbstractDatasetTemplate, ) -> NDArray: """Perform the grid transform.""" self.validate(index_headers, grid_overrides) - return analyze_non_indexed_headers(index_headers) + # Filter out coordinate fields, keep only dimensions for trace indexing + coord_fields = set(template.coordinate_names) if template else set() + dim_fields = [name for name in index_headers.dtype.names if name != "trace" and name not in coord_fields] + + # Create trace indices on dimension fields only + dim_headers = index_headers[dim_fields] if dim_fields else index_headers + dim_headers_with_trace = analyze_non_indexed_headers(dim_headers) + + # Add trace field back to full headers + if dim_headers_with_trace is not None and "trace" in dim_headers_with_trace.dtype.names: + trace_values = np.array(dim_headers_with_trace["trace"]) + index_headers = rfn.append_fields(index_headers, "trace", trace_values, usemask=False) + + return index_headers def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]: """Insert dimension "trace" to the sample-1 dimension.""" @@ -434,6 +452,7 @@ def transform( self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int], + template: AbstractDatasetTemplate, # noqa: ARG002 ) -> NDArray: """Perform the grid transform.""" self.validate(index_headers, grid_overrides) @@ -471,6 +490,7 @@ def transform( self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int], + template: AbstractDatasetTemplate, # noqa: ARG002 ) -> NDArray: """Perform the grid transform.""" self.validate(index_headers, grid_overrides) @@ -532,6 +552,7 @@ def run( index_names: Sequence[str], grid_overrides: dict[str, bool], chunksize: Sequence[int] | None = None, + template: AbstractDatasetTemplate | None = None, ) -> tuple[HeaderArray, tuple[str], tuple[int]]: """Run grid overrides and return result.""" for override in grid_overrides: @@ -542,7 +563,7 @@ def run( raise GridOverrideUnknownError(override) function = self.commands[override].transform - index_headers = function(index_headers, grid_overrides=grid_overrides) + index_headers = function(index_headers, grid_overrides=grid_overrides, template=template) function = self.commands[override].transform_index_names index_names = function(index_names) diff --git a/src/mdio/segy/utilities.py b/src/mdio/segy/utilities.py index f6d76bc6..289e7f94 100644 --- a/src/mdio/segy/utilities.py +++ b/src/mdio/segy/utilities.py @@ -71,10 +71,19 @@ def get_grid_plan( # noqa: C901, PLR0913 horizontal_coordinates, chunksize=chunksize, grid_overrides=grid_overrides, + template=template, ) + # Use the spatial dimension names from horizontal_coordinates (which may have been modified by grid overrides) + # Extract only the dimension names (not including non-dimension coordinates) + # After grid overrides, trace might have been added to horizontal_coordinates + transformed_spatial_dims = [ + name for name in horizontal_coordinates if name in horizontal_dimensions or name == "trace" + ] dimensions = [] - for dim_name in horizontal_dimensions: + for dim_name in transformed_spatial_dims: + if dim_name not in headers_subset.dtype.names: + continue dim_unique = np.unique(headers_subset[dim_name]) dimensions.append(Dimension(coords=dim_unique, name=dim_name)) diff --git a/tests/integration/test_import_streamer_grid_overrides.py b/tests/integration/test_import_streamer_grid_overrides.py index d05070f5..56f592e6 100644 --- a/tests/integration/test_import_streamer_grid_overrides.py +++ b/tests/integration/test_import_streamer_grid_overrides.py @@ -7,7 +7,6 @@ import dask import numpy as np -import numpy.testing as npt import pytest import xarray.testing as xrt from tests.integration.conftest import get_segy_mock_4d_spec @@ -28,12 +27,12 @@ os.environ["MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"] = "true" -# TODO(Altay): Finish implementing these grid overrides. +# TODO(BrianMichell): Add non-binned back # https://github.com/TGSAI/mdio-python/issues/612 -@pytest.mark.skip(reason="NonBinned and HasDuplicates haven't been properly implemented yet.") -@pytest.mark.parametrize("grid_override", [{"NonBinned": True}, {"HasDuplicates": True}]) +# @pytest.mark.parametrize("grid_override", [{"NonBinned": True, "chunksize": 4}, {"HasDuplicates": True}]) +@pytest.mark.parametrize("grid_override", [{"HasDuplicates": True}]) @pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.C]) -class TestImport4DNonReg: # pragma: no cover - tests is skipped +class TestImport4DNonReg: """Test for 4D segy import with grid overrides.""" def test_import_4d_segy( # noqa: PLR0913 @@ -67,7 +66,7 @@ def test_import_4d_segy( # noqa: PLR0913 assert ds["segy_file_header"].attrs["binaryHeader"]["samples_per_trace"] == num_samples assert ds.attrs["attributes"]["gridOverrides"] == grid_override - assert npt.assert_array_equal(ds["shot_point"], shots) + xrt.assert_duckarray_equal(ds["shot_point"], shots) xrt.assert_duckarray_equal(ds["cable"], cables) # assert grid.select_dim("trace") == Dimension(range(1, np.amax(receivers_per_cable) + 1), "trace")