Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/mdio/builder/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,32 @@ def add_coordinate( # noqa: PLR0913
msg = "Adding coordinate with the same name twice is not allowed"
raise ValueError(msg)

# Validate that all referenced dimensions are already defined
# Resolve referenced dimensions strictly, except allow a single substitution with 'trace' if present.
named_dimensions = []
trace_dim = _get_named_dimension(self._dimensions, "trace")
resolved_dim_names: list[str] = []
trace_used = False
missing_dims: list[str] = []
for dim_name in dimensions:
nd = _get_named_dimension(self._dimensions, dim_name)
if nd is not None:
if dim_name not in resolved_dim_names:
resolved_dim_names.append(dim_name)
continue
if trace_dim is not None and not trace_used and "trace" not in resolved_dim_names:
resolved_dim_names.append("trace")
trace_used = True
else:
missing_dims.append(dim_name)

if missing_dims:
msg = f"Pre-existing dimension named {missing_dims[0]!r} is not found"
raise ValueError(msg)

for resolved_name in resolved_dim_names:
nd = _get_named_dimension(self._dimensions, resolved_name)
if nd is None:
msg = f"Pre-existing dimension named {dim_name!r} is not found"
msg = f"Pre-existing dimension named {resolved_name!r} is not found"
raise ValueError(msg)
named_dimensions.append(nd)

Expand All @@ -174,7 +194,7 @@ def add_coordinate( # noqa: PLR0913
self.add_variable(
name=coord.name,
long_name=coord.long_name,
dimensions=dimensions, # dimension names (list[str])
dimensions=tuple(resolved_dim_names), # resolved dimension names
data_type=coord.data_type,
compressor=compressor,
coordinates=[name], # Use the coordinate name as a reference
Expand Down
38 changes: 31 additions & 7 deletions src/mdio/builder/templates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def build_dataset(

Returns:
Dataset: The constructed dataset

Raises:
ValueError: If coordinate already exists from subclass override.
"""
self._dim_sizes = sizes

Expand All @@ -90,6 +93,20 @@ def build_dataset(
self._builder = MDIODatasetBuilder(name=name, attributes=attributes)
self._add_dimensions()
self._add_coordinates()
# Ensure any coordinates declared on the template but not added by _add_coordinates
# are materialized with generic defaults. This handles coordinates added by grid overrides.
for coord_name in self.coordinate_names:
try:
self._builder.add_coordinate(
name=coord_name,
dimensions=self.spatial_dimension_names,
data_type=ScalarType.FLOAT64,
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(coord_name)),
)
except ValueError as exc: # coordinate may already exist
if "same name twice" not in str(exc):
raise
self._add_variables()
self._add_trace_mask()

Expand Down Expand Up @@ -241,14 +258,21 @@ def _add_coordinates(self) -> None:
)

# Add non-dimension coordinates
# Note: coordinate_names may be modified at runtime by grid overrides,
# so we need to handle dynamic additions gracefully
for name in self.coordinate_names:
self._builder.add_coordinate(
name=name,
dimensions=self.spatial_dimension_names,
data_type=ScalarType.FLOAT64,
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(name)),
)
try:
self._builder.add_coordinate(
name=name,
dimensions=self.spatial_dimension_names,
data_type=ScalarType.FLOAT64,
compressor=compressors.Blosc(cname=compressors.BloscCname.zstd),
metadata=CoordinateMetadata(units_v1=self.get_unit_by_key(name)),
)
except ValueError as exc:
# Coordinate may already exist from subclass override
if "same name twice" not in str(exc):
raise

def _add_trace_mask(self) -> None:
"""Add trace mask variables."""
Expand Down
86 changes: 78 additions & 8 deletions src/mdio/converters/segy.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,70 @@ def grid_density_qc(grid: Grid, num_traces: int) -> None:
raise GridTraceSparsityError(grid.shape, num_traces, msg)


def _update_template_from_grid_overrides(
template: AbstractDatasetTemplate,
grid_overrides: dict[str, Any] | None,
segy_dimensions: list[Dimension],
full_chunk_shape: tuple[int, ...],
chunk_size: tuple[int, ...],
) -> None:
"""Update template attributes to match grid plan results after grid overrides.

This function modifies the template in-place to reflect changes from grid overrides:
- Updates chunk shape if it changed due to overrides
- Updates dimension names if they changed due to overrides
- Adds non-binned dimensions as coordinates for NonBinned override

Args:
template: The template to update
grid_overrides: Grid override configuration
segy_dimensions: Dimensions returned from grid planning
full_chunk_shape: Original template chunk shape
chunk_size: Chunk size returned from grid planning
"""
# Update template to match grid_plan results after grid overrides
# Extract actual spatial dimensions from segy_dimensions (excluding vertical dimension)
actual_spatial_dims = tuple(dim.name for dim in segy_dimensions[:-1])

# Align chunk_size with actual dimensions - truncate if dimensions were filtered out
num_actual_spatial = len(actual_spatial_dims)
num_chunk_spatial = len(chunk_size) - 1 # Exclude vertical dimension chunk
if num_actual_spatial != num_chunk_spatial:
# Truncate chunk_size to match actual dimensions
chunk_size = chunk_size[:num_actual_spatial] + (chunk_size[-1],)

if full_chunk_shape != chunk_size:
logger.debug(
"Adjusting template chunk shape from %s to %s to match grid after overrides",
full_chunk_shape,
chunk_size,
)
template._var_chunk_shape = chunk_size

# Update dimensions if they don't match grid_plan results
if template.spatial_dimension_names != actual_spatial_dims:
logger.debug(
"Adjusting template dimensions from %s to %s to match grid after overrides",
template.spatial_dimension_names,
actual_spatial_dims,
)
template._dim_names = actual_spatial_dims + (template.trace_domain,)

# If using NonBinned override, expose non-binned dims as logical coordinates on the template instance
if grid_overrides and "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides:
non_binned_dims = tuple(grid_overrides["non_binned_dims"])
if non_binned_dims:
logger.debug(
"NonBinned grid override: exposing non-binned dims as coordinates: %s",
non_binned_dims,
)
# Append any missing names; keep existing order and avoid duplicates
existing = set(template.coordinate_names)
to_add = tuple(n for n in non_binned_dims if n not in existing)
if to_add:
template._logical_coord_names = template._logical_coord_names + to_add


def _scan_for_headers(
segy_file_kwargs: SegyFileArguments,
segy_file_info: SegyFileInfo,
Expand All @@ -143,7 +207,11 @@ def _scan_for_headers(
"""Extract trace dimensions and index headers from the SEG-Y file.

This is an expensive operation.
It scans the SEG-Y file in chunks by using ProcessPoolExecutor
It scans the SEG-Y file in chunks by using ProcessPoolExecutor.

Note:
If grid_overrides are applied to the template before calling this function,
the chunk_size returned from get_grid_plan should match the template's chunk shape.
"""
full_chunk_shape = template.full_chunk_shape
segy_dimensions, chunk_size, segy_headers = get_grid_plan(
Expand All @@ -154,13 +222,15 @@ def _scan_for_headers(
chunksize=full_chunk_shape,
grid_overrides=grid_overrides,
)
if full_chunk_shape != chunk_size:
# TODO(Dmitriy): implement grid overrides
# https://github.com/TGSAI/mdio-python/issues/585
# The returned 'chunksize' is used only for grid_overrides. We will need to use it when full
# support for grid overrides is implemented
err = "Support for changing full_chunk_shape in grid overrides is not yet implemented"
raise NotImplementedError(err)

_update_template_from_grid_overrides(
template=template,
grid_overrides=grid_overrides,
segy_dimensions=segy_dimensions,
full_chunk_shape=full_chunk_shape,
chunk_size=chunk_size,
)

return segy_dimensions, segy_headers


Expand Down
80 changes: 73 additions & 7 deletions src/mdio/segy/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from numpy.typing import NDArray
from segy.arrays import HeaderArray

from mdio.builder.templates.base import AbstractDatasetTemplate


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -267,7 +269,8 @@ def analyze_non_indexed_headers(index_headers: HeaderArray, dtype: DTypeLike = n
header_names = []
for header_key in index_headers.dtype.names:
if header_key != "trace":
unique_headers[header_key] = np.sort(np.unique(index_headers[header_key]))
unique_vals = np.sort(np.unique(index_headers[header_key]))
unique_headers[header_key] = unique_vals
header_names.append(header_key)
total_depth += 1

Expand Down Expand Up @@ -302,6 +305,7 @@ def transform(
self,
index_headers: HeaderArray,
grid_overrides: dict[str, bool | int],
template: AbstractDatasetTemplate, # noqa: ARG002
) -> NDArray:
"""Perform the grid transform."""

Expand Down Expand Up @@ -378,11 +382,35 @@ def transform(
self,
index_headers: HeaderArray,
grid_overrides: dict[str, bool | int],
template: AbstractDatasetTemplate,
) -> NDArray:
"""Perform the grid transform."""
self.validate(index_headers, grid_overrides)

return analyze_non_indexed_headers(index_headers)
# Filter out coordinate fields, keep only dimensions for trace indexing
coord_fields = set(template.coordinate_names) if template else set()

# For NonBinned: non_binned_dims should be excluded from trace indexing grouping
# because they become coordinates indexed by the trace dimension, not grouping keys.
# The trace index should count all traces per remaining dimension combination.
non_binned_dims = set(grid_overrides.get("non_binned_dims", [])) if grid_overrides else set()

dim_fields = [
name
for name in index_headers.dtype.names
if name != "trace" and name not in coord_fields and name not in non_binned_dims
]

# Create trace indices on dimension fields only
dim_headers = index_headers[dim_fields] if dim_fields else index_headers
dim_headers_with_trace = analyze_non_indexed_headers(dim_headers)

# Add trace field back to full headers
if dim_headers_with_trace is not None and "trace" in dim_headers_with_trace.dtype.names:
trace_values = np.array(dim_headers_with_trace["trace"])
index_headers = rfn.append_fields(index_headers, "trace", trace_values, usemask=False)

return index_headers

def transform_index_names(self, index_names: Sequence[str]) -> Sequence[str]:
"""Insert dimension "trace" to the sample-1 dimension."""
Expand All @@ -403,19 +431,51 @@ def transform_chunksize(


class NonBinned(DuplicateIndex):
"""Automatically index traces in a single specified axis - trace."""
"""Handle non-binned dimensions by converting them to a trace dimension with coordinates.

This override takes dimensions that are not regularly sampled (non-binned) and converts
them into a single 'trace' dimension. The original non-binned dimensions become coordinates
indexed by the trace dimension.

Example:
Template with dimensions [shot_point, cable, channel, azimuth, offset, sample]
and non_binned_dims=['azimuth', 'offset'] becomes:
- dimensions: [shot_point, cable, channel, trace, sample]
- coordinates: azimuth and offset with dimensions [shot_point, cable, channel, trace]

Attributes:
required_keys: No required keys for this override.
required_parameters: Set containing 'chunksize' and 'non_binned_dims'.
"""

required_keys = None
required_parameters = {"chunksize"}
required_parameters = {"chunksize", "non_binned_dims"}

def validate(self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int]) -> None:
"""Validate if this transform should run on the type of data."""
self.check_required_params(grid_overrides)

# Validate that non_binned_dims is a list
non_binned_dims = grid_overrides.get("non_binned_dims", [])
if not isinstance(non_binned_dims, list):
msg = f"non_binned_dims must be a list, got {type(non_binned_dims)}"
raise ValueError(msg)

# Validate that all non-binned dimensions exist in headers
missing_dims = set(non_binned_dims) - set(index_headers.dtype.names)
if missing_dims:
msg = f"Non-binned dimensions {missing_dims} not found in index headers"
raise ValueError(msg)

def transform_chunksize(
self,
chunksize: Sequence[int],
grid_overrides: dict[str, bool | int],
) -> Sequence[int]:
"""Perform the transform of chunksize."""
"""Insert chunksize for trace dimension at N-1 position."""
new_chunks = list(chunksize)
new_chunks.insert(-1, grid_overrides["chunksize"])
trace_chunksize = grid_overrides["chunksize"]
new_chunks.insert(-1, trace_chunksize)
return tuple(new_chunks)


Expand All @@ -434,6 +494,7 @@ def transform(
self,
index_headers: HeaderArray,
grid_overrides: dict[str, bool | int],
template: AbstractDatasetTemplate, # noqa: ARG002
) -> NDArray:
"""Perform the grid transform."""
self.validate(index_headers, grid_overrides)
Expand Down Expand Up @@ -471,6 +532,7 @@ def transform(
self,
index_headers: HeaderArray,
grid_overrides: dict[str, bool | int],
template: AbstractDatasetTemplate, # noqa: ARG002
) -> NDArray:
"""Perform the grid transform."""
self.validate(index_headers, grid_overrides)
Expand Down Expand Up @@ -528,6 +590,9 @@ def get_allowed_parameters(self) -> set:

parameters.update(command.required_parameters)

# Add optional parameters that are not strictly required but are valid
parameters.add("non_binned_dims")

return parameters

def run(
Expand All @@ -536,6 +601,7 @@ def run(
index_names: Sequence[str],
grid_overrides: dict[str, bool],
chunksize: Sequence[int] | None = None,
template: AbstractDatasetTemplate | None = None,
) -> tuple[HeaderArray, tuple[str], tuple[int]]:
"""Run grid overrides and return result."""
for override in grid_overrides:
Expand All @@ -546,7 +612,7 @@ def run(
raise GridOverrideUnknownError(override)

function = self.commands[override].transform
index_headers = function(index_headers, grid_overrides=grid_overrides)
index_headers = function(index_headers, grid_overrides=grid_overrides, template=template)

function = self.commands[override].transform_index_names
index_names = function(index_names)
Expand Down
Loading