Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export algorithm,
is_adaptive,
node_types,
node_type,
node_id_relation,
node_kinds,
table_name,
sql_table_name,
Expand All @@ -53,6 +54,46 @@ function sql_table_name(table_type::Type{<:Table})::String
return string(node_type(table_type), " / ", table_name(table_type))
end

"""
    node_id_relation(::Type{<:Table}) -> Symbol

Describe how the node_ids of a table must relate to the node_ids in the
Node table for its node type.

- `:equal` – table node_ids must exactly match the Node table (default)
- `:partition` – tables in the same partition group must be pairwise disjoint
  and their union must equal the Node table
- `:subset` – table node_ids must be a subset of the Node table
"""
node_id_relation(::Type{<:Table}) = :equal

# Static/Time pairs form a partition of all node_ids
for node in (
    :Pump,
    :Outlet,
    :LevelBoundary,
    :FlowBoundary,
    :TabulatedRatingCurve,
    :PidControl,
    :UserDemand,
    :LevelDemand,
    :FlowDemand,
)
    for table in (:Static, :Time)
        @eval node_id_relation(::Type{Schema.$node.$table}) = :partition
    end
end

# Subset tables: node_ids are allowed to be a subset
for (node, table) in (
    (:Basin, :Static),
    (:Basin, :Time),
    (:Basin, :Subgrid),
    (:Basin, :SubgridTime),
    (:Observation, :Time),
)
    @eval node_id_relation(::Type{Schema.$node.$table}) = :subset
end

"[:Basin, Terminal, ...]"
const node_types::Vector{Symbol} = filter(
name -> getfield(Schema, name) isa Module && name !== :Schema,
Expand Down
2 changes: 2 additions & 0 deletions python/ribasim/ribasim/geometry/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


class BasinAreaSchema(_GeoBaseSchema):
_node_id_relation: str = "subset"
fid: Index[Int32] = pa.Field(default=0, check_name=True)
node_id: Series[Int32] = pa.Field(nullable=False, default=0)
geometry: GeoSeries[MultiPolygon] = pa.Field(default=None, nullable=True)
Expand All @@ -20,6 +21,7 @@ def convert_to_multi(cls, series):


class FlowBoundaryAreaSchema(_GeoBaseSchema):
_node_id_relation: str = "subset"
fid: Index[Int32] = pa.Field(default=0, check_name=True)
node_id: Series[Int32] = pa.Field(nullable=False, default=0)
geometry: GeoSeries[MultiPolygon] = pa.Field(default=None, nullable=True)
Expand Down
82 changes: 82 additions & 0 deletions python/ribasim/ribasim/geometry/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,87 @@
node_ids.update(table._node_ids())
return node_ids

def _validate_node_ids(self) -> None:
    """Check that the node_ids of every data table agree with the Node table.

    Each table schema declares a ``_node_id_relation`` describing how its
    node_ids must relate to the full node_id set of this node type:

    - ``"equal"``: the table contains exactly the Node table ids.
    - ``"partition"``: the partition tables are pairwise disjoint and
      together cover the Node table ids exactly.
    - ``"subset"``: the table may only contain ids from the Node table.

    Raises
    ------
    ValueError
        If any table violates its declared relation; all violations are
        collected and reported in a single message.
    """
    node_table = self.node
    if node_table is None or node_table.df is None or node_table.df.empty:
        return

    expected: set[int] = set(node_table.df.index)
    node_type = type(self).__name__

    problems: list[str] = []
    partitions: list[tuple[str, set[int]]] = []

    for field in self._fields():
        table = getattr(self, field)
        if field == "node" or not isinstance(table, TableModel) or table.df is None:
            continue

        ids = table._node_ids()
        if not ids:
            continue

        # Schemas without an explicit relation default to exact equality.
        relation = getattr(table.tableschema(), "_node_id_relation", "equal")

        if relation == "partition":
            # Partition tables are judged as a group after the loop.
            partitions.append((field, ids))
            continue

        if relation == "subset":
            surplus = ids - expected
            if surplus:
                problems.append(f"{node_type}/{field}: unexpected node_ids {surplus}")
            continue

        if relation == "equal":
            if ids == expected:
                continue
            parts = []
            if expected - ids:
                parts.append(f"missing node_ids {expected - ids}")
            if ids - expected:
                parts.append(f"unexpected node_ids {ids - expected}")
            problems.append(f"{node_type}/{field}: {'; '.join(parts)}")

    if partitions:
        # An id may appear in at most one table of the partition group.
        for i, (name_a, ids_a) in enumerate(partitions):
            for name_b, ids_b in partitions[i + 1 :]:
                shared = ids_a & ids_b
                if shared:
                    problems.append(
                        f"{node_type}: node_ids {shared} found in both "
                        f"{name_a} and {name_b}"
                    )

        # Together the partition tables must cover the Node table exactly.
        combined: set[int] = set().union(*(ids for _, ids in partitions))
        if combined != expected:
            names = ", ".join(name for name, _ in partitions)
            parts = []
            if expected - combined:
                parts.append(f"missing node_ids {expected - combined}")
            if combined - expected:
                parts.append(f"unexpected node_ids {combined - expected}")
            problems.append(
                f"{node_type} partition ({names}): {'; '.join(parts)}"
            )

    if problems:
        raise ValueError("Node ID validation failed:\n" + "\n".join(problems))

def read(
self,
internal: bool = True,
Expand Down Expand Up @@ -323,6 +404,7 @@
external : bool, optional
Write the NetCDF input files. Default is True.
"""
# here
for table in self._tables():
if (internal and table.is_internal) or (external and table.is_external):
table.write()
Expand Down
8 changes: 7 additions & 1 deletion python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,10 @@ def ensure_listen_links(self) -> None:
self.link.df = GeoDataFrame[LinkSchema](_concat([df_link, table_to_append]))

def _validate_model(self) -> None:
"""Validate that all nodes satisfy their neighbor-count bounds for every link type."""
"""Validate that all nodes satisfy their neighbor-count bounds for every link type.

Also validates that node_ids in data tables are consistent with the Node table.
"""
df_link = self.link.df
df_node = self.node.df
assert df_link is not None
Expand All @@ -597,6 +600,9 @@ def _validate_model(self) -> None:
f"Minimum {link_type} inneighbor or outneighbor unsatisfied"
)

for node_model in self._nodes():
node_model._validate_node_ids()

def _has_valid_neighbor_amount(
self,
df_graph: pd.DataFrame,
Expand Down
Loading