diff --git a/src/mdio/builder/dataset_builder.py b/src/mdio/builder/dataset_builder.py index 1cc51598..b2560e1d 100644 --- a/src/mdio/builder/dataset_builder.py +++ b/src/mdio/builder/dataset_builder.py @@ -150,12 +150,32 @@ def add_coordinate( # noqa: PLR0913 msg = "Adding coordinate with the same name twice is not allowed" raise ValueError(msg) - # Validate that all referenced dimensions are already defined + # Resolve referenced dimensions strictly, except allow a single substitution with 'trace' if present. named_dimensions = [] + trace_dim = _get_named_dimension(self._dimensions, "trace") + resolved_dim_names: list[str] = [] + trace_used = False + missing_dims: list[str] = [] for dim_name in dimensions: nd = _get_named_dimension(self._dimensions, dim_name) + if nd is not None: + if dim_name not in resolved_dim_names: + resolved_dim_names.append(dim_name) + continue + if trace_dim is not None and not trace_used and "trace" not in resolved_dim_names: + resolved_dim_names.append("trace") + trace_used = True + else: + missing_dims.append(dim_name) + + if missing_dims: + msg = f"Pre-existing dimension named {missing_dims[0]!r} is not found" + raise ValueError(msg) + + for resolved_name in resolved_dim_names: + nd = _get_named_dimension(self._dimensions, resolved_name) if nd is None: - msg = f"Pre-existing dimension named {dim_name!r} is not found" + msg = f"Pre-existing dimension named {resolved_name!r} is not found" raise ValueError(msg) named_dimensions.append(nd) @@ -174,7 +194,7 @@ def add_coordinate( # noqa: PLR0913 self.add_variable( name=coord.name, long_name=coord.long_name, - dimensions=dimensions, # dimension names (list[str]) + dimensions=tuple(resolved_dim_names), # resolved dimension names data_type=coord.data_type, compressor=compressor, coordinates=[name], # Use the coordinate name as a reference diff --git a/src/mdio/builder/templates/base.py b/src/mdio/builder/templates/base.py index 50f775dc..544e290f 100644 --- a/src/mdio/builder/templates/base.py +++ b/src/mdio/builder/templates/base.py @@ -82,6 +82,9 @@ def build_dataset( Returns: Dataset: The constructed dataset + + Raises: + ValueError: If coordinate already exists from subclass override. """ self._dim_sizes = sizes @@ -90,6 +93,20 @@ def build_dataset( self._builder = MDIODatasetBuilder(name=name, attributes=attributes) self._add_dimensions() self._add_coordinates() + # Ensure any coordinates declared on the template but not added by _add_coordinates + # are materialized with generic defaults. This handles coordinates added by grid overrides. + for coord_name in self.coordinate_names: + try: + self._builder.add_coordinate( + name=coord_name, + dimensions=self.spatial_dimension_names, + data_type=ScalarType.FLOAT64, + compressor=compressors.Blosc(cname=compressors.BloscCname.zstd), + metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(coord_name)), + ) + except ValueError as exc: # coordinate may already exist + if "same name twice" not in str(exc): + raise self._add_variables() self._add_trace_mask() @@ -241,14 +258,21 @@ def _add_coordinates(self) -> None: ) # Add non-dimension coordinates + # Note: coordinate_names may be modified at runtime by grid overrides, + # so we need to handle dynamic additions gracefully for name in self.coordinate_names: - self._builder.add_coordinate( - name=name, - dimensions=self.spatial_dimension_names, - data_type=ScalarType.FLOAT64, - compressor=compressors.Blosc(cname=compressors.BloscCname.zstd), - metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(name)), - ) + try: + self._builder.add_coordinate( + name=name, + dimensions=self.spatial_dimension_names, + data_type=ScalarType.FLOAT64, + compressor=compressors.Blosc(cname=compressors.BloscCname.zstd), + metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(name)), + ) + except ValueError as exc: + # Coordinate may already exist from subclass override + if "same name twice" not in str(exc): + raise def _add_trace_mask(self) -> None: """Add trace mask variables.""" diff --git a/src/mdio/converters/segy.py b/src/mdio/converters/segy.py index e2fd6b35..ec8d01b7 100644 --- a/src/mdio/converters/segy.py +++ b/src/mdio/converters/segy.py @@ -134,6 +134,70 @@ def grid_density_qc(grid: Grid, num_traces: int) -> None: raise GridTraceSparsityError(grid.shape, num_traces, msg) +def _update_template_from_grid_overrides( + template: AbstractDatasetTemplate, + grid_overrides: dict[str, Any] | None, + segy_dimensions: list[Dimension], + full_chunk_shape: tuple[int, ...], + chunk_size: tuple[int, ...], +) -> None: + """Update template attributes to match grid plan results after grid overrides. + + This function modifies the template in-place to reflect changes from grid overrides: + - Updates chunk shape if it changed due to overrides + - Updates dimension names if they changed due to overrides + - Adds non-binned dimensions as coordinates for NonBinned override + + Args: + template: The template to update + grid_overrides: Grid override configuration + segy_dimensions: Dimensions returned from grid planning + full_chunk_shape: Original template chunk shape + chunk_size: Chunk size returned from grid planning + """ + # Update template to match grid_plan results after grid overrides + # Extract actual spatial dimensions from segy_dimensions (excluding vertical dimension) + actual_spatial_dims = tuple(dim.name for dim in segy_dimensions[:-1]) + + # Align chunk_size with actual dimensions - truncate if dimensions were filtered out + num_actual_spatial = len(actual_spatial_dims) + num_chunk_spatial = len(chunk_size) - 1 # Exclude vertical dimension chunk + if num_actual_spatial != num_chunk_spatial: + # Truncate chunk_size to match actual dimensions + chunk_size = chunk_size[:num_actual_spatial] + (chunk_size[-1],) + + if full_chunk_shape != chunk_size: + logger.debug( + "Adjusting template chunk shape from %s to %s to match grid after overrides", + full_chunk_shape, + chunk_size, + ) + template._var_chunk_shape = chunk_size + + # Update dimensions if they don't match grid_plan results + if template.spatial_dimension_names != actual_spatial_dims: + logger.debug( + "Adjusting template dimensions from %s to %s to match grid after overrides", + template.spatial_dimension_names, + actual_spatial_dims, + ) + template._dim_names = actual_spatial_dims + (template.trace_domain,) + + # If using NonBinned override, expose non-binned dims as logical coordinates on the template instance + if grid_overrides and "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides: + non_binned_dims = tuple(grid_overrides["non_binned_dims"]) + if non_binned_dims: + logger.debug( + "NonBinned grid override: exposing non-binned dims as coordinates: %s", + non_binned_dims, + ) + # Append any missing names; keep existing order and avoid duplicates + existing = set(template.coordinate_names) + to_add = tuple(n for n in non_binned_dims if n not in existing) + if to_add: + template._logical_coord_names = template._logical_coord_names + to_add + + def _scan_for_headers( segy_file_kwargs: SegyFileArguments, segy_file_info: SegyFileInfo, @@ -143,7 +207,11 @@ def _scan_for_headers( """Extract trace dimensions and index headers from the SEG-Y file. This is an expensive operation. - It scans the SEG-Y file in chunks by using ProcessPoolExecutor + It scans the SEG-Y file in chunks by using ProcessPoolExecutor. + + Note: + If grid_overrides are applied to the template before calling this function, + the chunk_size returned from get_grid_plan should match the template's chunk shape. """ full_chunk_shape = template.full_chunk_shape segy_dimensions, chunk_size, segy_headers = get_grid_plan( @@ -154,13 +222,15 @@ def _scan_for_headers( chunksize=full_chunk_shape, grid_overrides=grid_overrides, ) - if full_chunk_shape != chunk_size: - # TODO(Dmitriy): implement grid overrides - # https://github.com/TGSAI/mdio-python/issues/585 - # The returned 'chunksize' is used only for grid_overrides. We will need to use it when full - # support for grid overrides is implemented - err = "Support for changing full_chunk_shape in grid overrides is not yet implemented" - raise NotImplementedError(err) + + _update_template_from_grid_overrides( + template=template, + grid_overrides=grid_overrides, + segy_dimensions=segy_dimensions, + full_chunk_shape=full_chunk_shape, + chunk_size=chunk_size, + ) + return segy_dimensions, segy_headers diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py index bdb0b81b..f0e8a1d3 100644 --- a/src/mdio/segy/geometry.py +++ b/src/mdio/segy/geometry.py @@ -24,6 +24,8 @@ from numpy.typing import NDArray from segy.arrays import HeaderArray + from mdio.builder.templates.base import AbstractDatasetTemplate + logger = logging.getLogger(__name__) @@ -267,7 +269,8 @@ def analyze_non_indexed_headers(index_headers: HeaderArray, dtype: DTypeLike = n header_names = [] for header_key in index_headers.dtype.names: if header_key != "trace": - unique_headers[header_key] = np.sort(np.unique(index_headers[header_key])) + unique_vals = np.sort(np.unique(index_headers[header_key])) + unique_headers[header_key] = unique_vals header_names.append(header_key) total_depth += 1 @@ -302,6 +305,7 @@ def transform( self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int], + template: AbstractDatasetTemplate, # noqa: ARG002 ) -> NDArray: """Perform the grid transform.""" @@ -378,11 +382,35 @@ def transform( self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int], + template: AbstractDatasetTemplate, ) -> NDArray: """Perform the grid transform.""" self.validate(index_headers, grid_overrides) - return analyze_non_indexed_headers(index_headers) + # Filter out coordinate fields, keep only dimensions for trace indexing + coord_fields = set(template.coordinate_names) if template else set() + + # For NonBinned: non_binned_dims should be excluded from trace indexing grouping + # because they become coordinates indexed by the trace dimension, not grouping keys. + # The trace index should count all traces per remaining dimension combination. + non_binned_dims = set(grid_overrides.get("non_binned_dims", [])) if grid_overrides else set() + + dim_fields = [ + name + for name in index_headers.dtype.names + if name != "trace" and name not in coord_fields and name not in non_binned_dims + ] + + # Create trace indices on dimension fields only + dim_headers = index_headers[dim_fields] if dim_fields else index_headers + dim_headers_with_trace = analyze_non_indexed_headers(dim_headers) + + # Add trace field back to full headers + if dim_headers_with_trace is not None and "trace" in dim_headers_with_trace.dtype.names: + trace_values = np.array(dim_headers_with_trace["trace"]) + index_headers = rfn.append_fields(index_headers, "trace", trace_values, usemask=False) + + return index_headers def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]: """Insert dimension "trace" to the sample-1 dimension.""" @@ -403,19 +431,51 @@ def transform_chunksize( class NonBinned(DuplicateIndex): - """Automatically index traces in a single specified axis - trace.""" + """Handle non-binned dimensions by converting them to a trace dimension with coordinates. + + This override takes dimensions that are not regularly sampled (non-binned) and converts + them into a single 'trace' dimension. The original non-binned dimensions become coordinates + indexed by the trace dimension. + + Example: + Template with dimensions [shot_point, cable, channel, azimuth, offset, sample] + and non_binned_dims=['azimuth', 'offset'] becomes: + - dimensions: [shot_point, cable, channel, trace, sample] + - coordinates: azimuth and offset with dimensions [shot_point, cable, channel, trace] + + Attributes: + required_keys: No required keys for this override. + required_parameters: Set containing 'chunksize' and 'non_binned_dims'. + """ required_keys = None - required_parameters = {"chunksize"} + required_parameters = {"chunksize", "non_binned_dims"} + + def validate(self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int]) -> None: + """Validate if this transform should run on the type of data.""" + self.check_required_params(grid_overrides) + + # Validate that non_binned_dims is a list + non_binned_dims = grid_overrides.get("non_binned_dims", []) + if not isinstance(non_binned_dims, list): + msg = f"non_binned_dims must be a list, got {type(non_binned_dims)}" + raise ValueError(msg) + + # Validate that all non-binned dimensions exist in headers + missing_dims = set(non_binned_dims) - set(index_headers.dtype.names) + if missing_dims: + msg = f"Non-binned dimensions {missing_dims} not found in index headers" + raise ValueError(msg) def transform_chunksize( self, chunksize: Sequence[int], grid_overrides: dict[str, bool | int], ) -> Sequence[int]: - """Perform the transform of chunksize.""" + """Insert chunksize for trace dimension at N-1 position.""" new_chunks = list(chunksize) - new_chunks.insert(-1, grid_overrides["chunksize"]) + trace_chunksize = grid_overrides["chunksize"] + new_chunks.insert(-1, trace_chunksize) return tuple(new_chunks) @@ -434,6 +494,7 @@ def transform( self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int], + template: AbstractDatasetTemplate, # noqa: ARG002 ) -> NDArray: """Perform the grid transform.""" self.validate(index_headers, grid_overrides) @@ -471,6 +532,7 @@ def transform( self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int], + template: AbstractDatasetTemplate, # noqa: ARG002 ) -> NDArray: """Perform the grid transform.""" self.validate(index_headers, grid_overrides) @@ -528,6 +590,9 @@ def get_allowed_parameters(self) -> set: parameters.update(command.required_parameters) + # Add optional parameters that are not strictly required but are valid + parameters.add("non_binned_dims") + return parameters def run( @@ -536,6 +601,7 @@ def run( index_names: Sequence[str], grid_overrides: dict[str, bool], chunksize: Sequence[int] | None = None, + template: AbstractDatasetTemplate | None = None, ) -> tuple[HeaderArray, tuple[str], tuple[int]]: """Run grid overrides and return result.""" for override in grid_overrides: @@ -546,7 +612,7 @@ def run( raise GridOverrideUnknownError(override) function = self.commands[override].transform - index_headers = function(index_headers, grid_overrides=grid_overrides) + index_headers = function(index_headers, grid_overrides=grid_overrides, template=template) function = self.commands[override].transform_index_names index_names = function(index_names) diff --git a/src/mdio/segy/utilities.py b/src/mdio/segy/utilities.py index 195a02c8..d1e42416 100644 --- a/src/mdio/segy/utilities.py +++ b/src/mdio/segy/utilities.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) -def get_grid_plan( # noqa: C901, PLR0913 +def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915 segy_file_kwargs: SegyFileArguments, segy_file_info: SegyFileInfo, chunksize: tuple[int, ...] | None, @@ -61,15 +61,23 @@ def get_grid_plan( # noqa: C901, PLR0913 # Keep only dimension and non-dimension coordinates excluding the vertical axis horizontal_dimensions = template.spatial_dimension_names horizontal_coordinates = horizontal_dimensions + template.coordinate_names + # Exclude calculated dimensions - they don't exist in SEG-Y headers + calculated_dims = set(template.calculated_dimension_names) - # Remove any to be computed fields + # Remove any to be computed fields - preserve order by using list comprehension instead of set operations computed_fields = set(template.calculated_dimension_names) - horizontal_coordinates = tuple(set(horizontal_coordinates) - computed_fields) + horizontal_coordinates = tuple(c for c in horizontal_coordinates if c not in computed_fields) + + # Ensure non_binned_dims are included in the headers to parse, even if not in template + if grid_overrides and "non_binned_dims" in grid_overrides: + for dim in grid_overrides["non_binned_dims"]: + if dim not in horizontal_coordinates: + horizontal_coordinates = horizontal_coordinates + (dim,) headers_subset = parse_headers( segy_file_kwargs=segy_file_kwargs, num_traces=segy_file_info.num_traces, - subset=horizontal_coordinates, + subset=tuple(c for c in horizontal_coordinates if c not in calculated_dims), ) # Handle grid overrides. @@ -79,8 +87,35 @@ def get_grid_plan( # noqa: C901, PLR0913 horizontal_coordinates, chunksize=chunksize, grid_overrides=grid_overrides, + template=template, ) + # After grid overrides, determine final spatial dimensions and their chunk sizes + non_binned_dims = set() + if "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides: + non_binned_dims = set(grid_overrides["non_binned_dims"]) + + # Create mapping from dimension name to original chunk size for easy lookup + original_spatial_dims = list(template.spatial_dimension_names) + original_chunks = list(template.full_chunk_shape[:-1]) # Exclude vertical (sample/time) dimension + dim_to_chunk = dict(zip(original_spatial_dims, original_chunks, strict=True)) + + # Final spatial dimensions: keep trace and original dims, exclude non-binned dims + final_spatial_dims = [] + final_spatial_chunks = [] + for name in horizontal_coordinates: + if name in non_binned_dims: + continue # Skip dimensions that became coordinates + if name == "trace": + # Special handling for trace dimension + chunk_val = int(grid_overrides.get("chunksize", 1)) if "NonBinned" in grid_overrides else 1 + final_spatial_dims.append(name) + final_spatial_chunks.append(chunk_val) + elif name in dim_to_chunk: + # Use original chunk size for known dimensions + final_spatial_dims.append(name) + final_spatial_chunks.append(dim_to_chunk[name]) + if len(computed_fields) > 0 and not computed_fields.issubset(headers_subset.dtype.names): err = ( f"Required computed fields {sorted(computed_fields)} for template {template.name} " @@ -88,8 +123,38 @@ def get_grid_plan( # noqa: C901, PLR0913 ) raise ValueError(err) + # Create dimensions from final_spatial_dims plus any computed fields that were added by grid overrides + all_dimension_names = list(final_spatial_dims) + added_computed_fields = [] + for computed_field in computed_fields: + if computed_field in headers_subset.dtype.names and computed_field not in all_dimension_names: + # Insert in template order + if computed_field in template.spatial_dimension_names: + insert_idx = template.spatial_dimension_names.index(computed_field) + # Find position in all_dimension_names that corresponds to this template position + actual_idx = min(insert_idx, len(all_dimension_names)) + all_dimension_names.insert(actual_idx, computed_field) + # Track where we inserted and what chunk size it should have + template_chunk_idx = template.spatial_dimension_names.index(computed_field) + chunk_val = template.full_chunk_shape[template_chunk_idx] + added_computed_fields.append((actual_idx, chunk_val)) + else: + all_dimension_names.append(computed_field) + added_computed_fields.append((len(all_dimension_names) - 1, 1)) + + # Build chunksize including chunks for computed fields + if added_computed_fields: + chunk_list = list(final_spatial_chunks) + for insert_idx, chunk_val in sorted(added_computed_fields, reverse=True): + chunk_list.insert(insert_idx, chunk_val) + chunksize = tuple(chunk_list + [template.full_chunk_shape[-1]]) + else: + chunksize = tuple(final_spatial_chunks + [template.full_chunk_shape[-1]]) + dimensions = [] - for dim_name in horizontal_dimensions: + for dim_name in all_dimension_names: + if dim_name not in headers_subset.dtype.names: + continue dim_unique = np.unique(headers_subset[dim_name]) dimensions.append(Dimension(coords=dim_unique, name=dim_name)) diff --git a/tests/integration/test_import_streamer_grid_overrides.py b/tests/integration/test_import_streamer_grid_overrides.py index c90d8c8c..7077ba9f 100644 --- a/tests/integration/test_import_streamer_grid_overrides.py +++ b/tests/integration/test_import_streamer_grid_overrides.py @@ -7,7 +7,6 @@ import dask import numpy as np -import numpy.testing as npt import pytest import xarray.testing as xrt from tests.integration.conftest import get_segy_mock_4d_spec @@ -28,12 +27,11 @@ os.environ["MDIO__IMPORT__SAVE_SEGY_FILE_HEADER"] = "true" -# TODO(Altay): Finish implementing these grid overrides. -# https://github.com/TGSAI/mdio-python/issues/612 -@pytest.mark.skip(reason="NonBinned and HasDuplicates haven't been properly implemented yet.") -@pytest.mark.parametrize("grid_override", [{"NonBinned": True}, {"HasDuplicates": True}]) +@pytest.mark.parametrize( + "grid_override", [{"NonBinned": True, "chunksize": 4, "non_binned_dims": ["channel"]}, {"HasDuplicates": True}] +) @pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.C]) -class TestImport4DNonReg: # pragma: no cover - tests is skipped +class TestImport4DNonReg: """Test for 4D segy import with grid overrides.""" def test_import_4d_segy( # noqa: PLR0913 @@ -67,16 +65,27 @@ def test_import_4d_segy( # noqa: PLR0913 assert ds["segy_file_header"].attrs["binaryHeader"]["samples_per_trace"] == num_samples assert ds.attrs["attributes"]["gridOverrides"] == grid_override - assert npt.assert_array_equal(ds["shot_point"], shots) + xrt.assert_duckarray_equal(ds["shot_point"], shots) xrt.assert_duckarray_equal(ds["cable"], cables) - # assert grid.select_dim("trace") == Dimension(range(1, np.amax(receivers_per_cable) + 1), "trace") + # Both HasDuplicates and NonBinned should create a trace dimension expected = list(range(1, np.amax(receivers_per_cable) + 1)) xrt.assert_duckarray_equal(ds["trace"], expected) times_expected = list(range(0, num_samples, 1)) xrt.assert_duckarray_equal(ds["time"], times_expected) + # Check trace chunk size based on grid override + trace_chunks = ds["amplitude"].chunksizes.get("trace", None) + if trace_chunks is not None: + if "NonBinned" in grid_override: + # NonBinned uses specified chunksize for trace dimension + expected_chunksize = grid_override.get("chunksize", 1) + assert all(chunk == expected_chunksize for chunk in trace_chunks) + else: + # HasDuplicates uses chunksize of 1 for trace dimension + assert all(chunk == 1 for chunk in trace_chunks) + @pytest.mark.parametrize("grid_override", [{"AutoChannelWrap": True}, None]) @pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.A, StreamerShotGeometryType.B]) diff --git a/tests/unit/test_segy_grid_overrides.py b/tests/unit/test_segy_grid_overrides.py index bebf6be8..250d6a49 100644 --- a/tests/unit/test_segy_grid_overrides.py +++ b/tests/unit/test_segy_grid_overrides.py @@ -103,10 +103,10 @@ def test_duplicates(self, mock_streamer_headers: dict[str, npt.NDArray]) -> None def test_non_binned(self, mock_streamer_headers: dict[str, npt.NDArray]) -> None: """Test the NonBinned Grid Override command.""" index_names = ("shot_point", "cable") - grid_overrides = {"NonBinned": True, "chunksize": 4} + grid_overrides = {"NonBinned": True, "chunksize": 4, "non_binned_dims": ["channel"]} - # Remove channel header - streamer_headers = mock_streamer_headers[list(index_names)] + # Keep channel header for non-binned processing + streamer_headers = mock_streamer_headers chunksize = (4, 4, 8) new_headers, new_names, new_chunks = run_override( @@ -123,7 +123,9 @@ def test_non_binned(self, mock_streamer_headers: dict[str, npt.NDArray]) -> None assert_array_equal(dims[0].coords, SHOTS) assert_array_equal(dims[1].coords, CABLES) - assert_array_equal(dims[2].coords, RECEIVERS) + # Trace coords are the unique channel values (1-20) + expected_trace_coords = np.arange(1, 21, dtype="int32") + assert_array_equal(dims[2].coords, expected_trace_coords) class TestStreamerGridOverrides: