Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ Bug fixes
- In the API for backends, support dimensions that express their preferred chunk sizes
as a tuple of integers. (:issue:`6333`, :pull:`6334`)
By `Stan West <https://github.com/stanwest>`_.
- Fix bug in :py:func:`where` when passing non-xarray objects with ``keep_attrs=True``. (:issue:`6444`, :pull:`6461`)
By `Sam Levang <https://github.com/slevang>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
3 changes: 1 addition & 2 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1825,11 +1825,10 @@ def where(cond, x, y, keep_attrs=None):
"""
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

if keep_attrs is True:
# keep the attributes of x, the second parameter, by default to
# be consistent with the `where` method of `DataArray` and `Dataset`
keep_attrs = lambda attrs, context: attrs[1]
keep_attrs = lambda attrs, context: getattr(x, "attrs", {})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is attrs here for a scalar? The issue is that you're using x.attrs which could be a dataset while presumably attrs has Variable attributes when apply_ufunc is iterating through variables in a dataset.. We want attrs[1] or {} I think with some handling for errors.

PS: Sorry for the not very throrough review earlier

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think something like this should work (with getitem defined in xarray.core.utils):

_default = object()
def getitem(sequence, item, default=_default):
    try:
        return sequence[item]
    except IndexError:
        if default is _default:
            raise
        return default

keep_attrs = lambda attrs, context: getitem(attrs, 1, {})


# alignment for three arguments is complicated, so don't support it yet
return apply_ufunc(
Expand Down
4 changes: 4 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1928,6 +1928,10 @@ def test_where_attrs() -> None:
expected = xr.DataArray([1, 0], dims="x", attrs={"attr": "x"})
assert_identical(expected, actual)

# ensure keep_attrs can handle scalar values
actual = xr.where(cond, 1, 0, keep_attrs=True)
assert actual.attrs == {}


@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("use_datetime", [True, False])
Expand Down