From d3a7544a5f9ce3d1ab93f6abe78d3704cc01beb6 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 16 Oct 2024 17:36:03 -0400 Subject: [PATCH 1/2] Support alternative names for the root node in DataTree.from_dict `DataTree.from_dict` now supports indicating the root node with "", ".", "/" or "./". This makes the handling of the root node a bit more consistent with handling of other nested paths. For example, we can use relative paths, which allows for removing special cases for the root from some internal uses of from_dict. --- xarray/core/datatree.py | 32 +++++++++++++++++++++----------- xarray/tests/test_datatree.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 19bfd7130d5..4c3d2149c65 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1104,10 +1104,12 @@ def from_dict( d : dict-like A mapping from path names to xarray.Dataset or DataTree objects. - Path names are to be given as unix-like path. If path names containing more than one - part are given, new tree nodes will be constructed as necessary. + Path names are to be given as unix-like path. If path names + containing more than one part are given, new tree nodes will be + constructed as necessary. - To assign data to the root node of the tree use "/" as the path. + To assign data to the root node of the tree use "", ".", "/" or "./" + as the path. name : Hashable | None, optional Name for the root node of the tree. Default is None. @@ -1119,17 +1121,26 @@ def from_dict( ----- If your dictionary is nested you will need to flatten it before using this method. """ - - # First create the root node + # Find any values corresponding to the root d_cast = dict(d) - root_data = d_cast.pop("/", None) + root_data = None + for key in ("", ".", "/", "./"): + if key in d_cast: + if root_data is not None: + raise ValueError( + "multiple entries found corresponding to the root node" + ) + root_data = d_cast.pop(key) + + # Create the root node if isinstance(root_data, DataTree): obj = root_data.copy() elif root_data is None or isinstance(root_data, Dataset): obj = cls(name=name, dataset=root_data, children=None) else: raise TypeError( - f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}' + f'root node data (at "", ".", "/" or "./") must be a Dataset ' + f"or DataTree, got {type(root_data)}" ) def depth(item) -> int: @@ -1141,11 +1152,10 @@ def depth(item) -> int: # Sort keys by depth so as to insert nodes from root first (see GH issue #9276) for path, data in sorted(d_cast.items(), key=depth): # Create and set new node - node_name = NodePath(path).name if isinstance(data, DataTree): new_node = data.copy() elif isinstance(data, Dataset) or data is None: - new_node = cls(name=node_name, dataset=data) + new_node = cls(dataset=data) else: raise TypeError(f"invalid values: {data}") obj._set_item( @@ -1683,7 +1693,7 @@ def reduce( numeric_only=numeric_only, **kwargs, ) - path = "/" if node is self else node.relative_to(self) + path = node.relative_to(self) result[path] = node_result return type(self).from_dict(result, name=self.name) @@ -1718,7 +1728,7 @@ def _selective_indexing( # with a scalar) can also create scalar coordinates, which # need to be explicitly removed. del node_result.coords[k] - path = "/" if node is self else node.relative_to(self) + path = node.relative_to(self) result[path] = node_result return type(self).from_dict(result, name=self.name) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 9c11cde3bbb..b46597ea1aa 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -883,6 +883,37 @@ def test_array_values(self) -> None: with pytest.raises(TypeError): DataTree.from_dict(data) # type: ignore[arg-type] + def test_relative_paths(self) -> None: + tree = DataTree.from_dict({".": None, "foo": None, "./bar": None, "x/y": None}) + paths = [node.path for node in tree.subtree] + assert paths == [ + "/", + "/foo", + "/bar", + "/x/y", + ] + + def test_root_keys(self): + ds = Dataset({"x": 1}) + expected = DataTree(dataset=ds) + + actual = DataTree.from_dict({"": ds}) + assert_identical(actual, expected) + + actual = DataTree.from_dict({".": ds}) + assert_identical(actual, expected) + + actual = DataTree.from_dict({"/": ds}) + assert_identical(actual, expected) + + actual = DataTree.from_dict({"./": ds}) + assert_identical(actual, expected) + + with pytest.raises( + ValueError, match="multiple entries found corresponding to the root node" + ): + DataTree.from_dict({"": ds, "/": ds}) + class TestDatasetView: def test_view_contents(self) -> None: From ac4be24a45ffd1b6ccd8c817db1b4ba39e08528d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 16 Oct 2024 17:44:27 -0400 Subject: [PATCH 2/2] fixes --- xarray/core/datatree.py | 1 + xarray/tests/test_datatree.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 4c3d2149c65..5657af076e1 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1135,6 +1135,7 @@ def from_dict( # Create the root node if isinstance(root_data, DataTree): obj = root_data.copy() + obj.name = name elif root_data is None or isinstance(root_data, Dataset): obj = cls(name=name, dataset=root_data, children=None) else: diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index b46597ea1aa..1b6c9c6de0c 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -890,6 +890,7 @@ def test_relative_paths(self) -> None: "/", "/foo", "/bar", + "/x", "/x/y", ] @@ -914,6 +915,16 @@ def test_root_keys(self): ): DataTree.from_dict({"": ds, "/": ds}) + def test_name(self): + tree = DataTree.from_dict({"/": None}, name="foo") + assert tree.name == "foo" + + tree = DataTree.from_dict({"/": DataTree()}, name="foo") + assert tree.name == "foo" + + tree = DataTree.from_dict({"/": DataTree(name="bar")}, name="foo") + assert tree.name == "foo" + class TestDatasetView: def test_view_contents(self) -> None: