diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 10ebcc07664..e5032c3729a 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -7,17 +7,22 @@ Dataset.plot._____ """ import functools +from distutils.version import LooseVersion import numpy as np import pandas as pd +from ..core.alignment import broadcast from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _adjust_legend_subtitles, _assert_valid_xy, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, + _is_numeric, + _legend_add_subtitle, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_1dplot, @@ -26,8 +31,132 @@ get_axis, import_matplotlib_pyplot, label_from_attrs, + legend_elements, ) +# copied from seaborn +_MARKERSIZE_RANGE = np.array([18.0, 72.0]) + + +def _infer_scatter_metadata(darray, x, z, hue, hue_style, size): + def _determine_array(darray, name, array_style): + """Find and determine what type of array it is.""" + array = darray[name] + array_is_numeric = _is_numeric(array.values) + + if array_style is None: + array_style = "continuous" if array_is_numeric else "discrete" + elif array_style not in ["discrete", "continuous"]: + raise ValueError( + f"The style '{array_style}' is not valid, " + "valid options are None, 'discrete' or 'continuous'." + ) + + array_label = label_from_attrs(array) + + return array, array_style, array_label + + # Add nice looking labels: + out = dict(ylabel=label_from_attrs(darray)) + out.update( + { + k: label_from_attrs(darray[v]) if v in darray.coords else None + for k, v in [("xlabel", x), ("zlabel", z)] + } + ) + + # Add styles and labels for the dataarrays: + for type_, a, style in [("hue", hue, hue_style), ("size", size, None)]: + tp, stl, lbl = f"{type_}", f"{type_}_style", f"{type_}_label" + if a: + out[tp], out[stl], out[lbl] = _determine_array(darray, a, style) + else: + out[tp], out[stl], out[lbl] = None, None, None + + return out + + +# copied from seaborn +def _parse_size(data, norm, width): + """ + Determine what type of data it is. Then normalize it to width. + + If the data is categorical, normalize it to numbers. + """ + plt = import_matplotlib_pyplot() + + if data is None: + return None + + data = data.values.ravel() + + if not _is_numeric(data): + # Data is categorical. + # Use pd.unique instead of np.unique because that keeps + # the order of the labels: + levels = pd.unique(data) + numbers = np.arange(1, 1 + len(levels)) + else: + levels = numbers = np.sort(np.unique(data)) + + min_width, max_width = width + # width_range = min_width, max_width + + if norm is None: + norm = plt.Normalize() + elif isinstance(norm, tuple): + norm = plt.Normalize(*norm) + elif not isinstance(norm, plt.Normalize): + err = "``size_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + norm.clip = True + if not norm.scaled(): + norm(np.asarray(numbers)) + # limits = norm.vmin, norm.vmax + + scl = norm(numbers) + widths = np.asarray(min_width + scl * (max_width - min_width)) + if scl.mask.any(): + widths[scl.mask] = 0 + sizes = dict(zip(levels, widths)) + + return pd.Series(sizes) + + +def _infer_scatter_data( + darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10) +): + # Broadcast together all the chosen variables: + to_broadcast = dict(y=darray) + to_broadcast.update( + {k: darray[v] for k, v in dict(x=x, z=z).items() if v is not None} + ) + to_broadcast.update( + {k: darray[v] for k, v in dict(hue=hue, size=size).items() if v in darray.dims} + ) + broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) + + # Normalize hue and size and create lookup tables: + for type_, mapping, norm, width in [ + ("hue", None, None, [0, 1]), + ("size", size_mapping, size_norm, size_range), + ]: + broadcasted_type = broadcasted.get(type_, None) + if broadcasted_type is not None: + if mapping is None: + mapping = _parse_size(broadcasted_type, norm, width) + + broadcasted[type_] = broadcasted_type.copy( + data=np.reshape( + mapping.loc[broadcasted_type.values.ravel()].values, + broadcasted_type.shape, + ) + ) + broadcasted[f"{type_}_to_label"] = pd.Series(mapping.index, index=mapping) + + return broadcasted + def _infer_line_data(darray, x, y, hue): @@ -435,6 +564,291 @@ def hist( return primitive +def scatter( + darray, + *args, + row=None, + col=None, + figsize=None, + aspect=None, + size=None, + ax=None, + hue=None, + hue_style=None, + x=None, + z=None, + xincrease=None, + yincrease=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + add_legend=None, + add_colorbar=None, + cbar_kwargs=None, + cbar_ax=None, + vmin=None, + vmax=None, + norm=None, + infer_intervals=None, + center=None, + levels=None, + robust=None, + colors=None, + extend=None, + cmap=None, + _labels=True, + **kwargs, +): + """ + Scatter plot a DataArray along some coordinates. + + Parameters + ---------- + darray : DataArray + Dataarray to plot. + x, y : str + Variable names for x, y axis. + hue: str, optional + Variable by which to color scattered points + hue_style: str, optional + Can be either 'discrete' (legend) or 'continuous' (color bar). + markersize: str, optional + scatter only. Variable by which to vary size of scattered points. + size_norm: optional + Either None or 'Norm' instance to normalize the 'markersize' variable. + add_guide: bool, optional + Add a guide that depends on hue_style + - for "discrete", build a legend. + This is the default for non-numeric `hue` variables. + - for "continuous", build a colorbar + row : str, optional + If passed, make row faceted plots on this dimension name + col : str, optional + If passed, make column faceted plots on this dimension name + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + ax : matplotlib axes object, optional + If None, uses the current axis. Not applicable when using facets. + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only applies + to FacetGrid plotting. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. + vmin, vmax : float, optional + Values to anchor the colormap, otherwise they are inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting one of these values will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + cmap : str or colormap, optional + The mapping from data values to color space. Either a + matplotlib colormap name or object. If not provided, this will + be either ``viridis`` (if the function infers a sequential + dataset) or ``RdBu_r`` (if the function infers a diverging + dataset). When `Seaborn` is installed, ``cmap`` may also be a + `seaborn` color palette. If ``cmap`` is seaborn color palette + and the plot type is not ``contour`` or ``contourf``, ``levels`` + must also be specified. + colors : color-like or list of color-like, optional + A single color or a list of colors. If the plot type is not ``contour`` + or ``contourf``, the ``levels`` argument is required. + center : float, optional + The value at which to center the colormap. Passing this value implies + use of a diverging colormap. Setting it to ``False`` prevents use of a + diverging colormap. + robust : bool, optional + If True and ``vmin`` or ``vmax`` are absent, the colormap range is + computed with 2nd and 98th percentiles instead of the extreme values. + extend : {"neither", "both", "min", "max"}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, extend is inferred from vmin, vmax and the data limits. + levels : int or list-like object, optional + Split the colormap (cmap) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + **kwargs : optional + Additional keyword arguments to matplotlib + """ + plt = import_matplotlib_pyplot() + + # Handle facetgrids first + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + subplot_kws = dict(projection="3d") if z is not None else None + return _easy_facetgrid( + darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs + ) + + # Further + _is_facetgrid = kwargs.pop("_is_facetgrid", False) + if _is_facetgrid: + # Why do I need to pop these here? + kwargs.pop("y", None) + kwargs.pop("args", None) + kwargs.pop("add_labels", None) + + _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) + size_norm = kwargs.pop("size_norm", None) + size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid + cmap_params = kwargs.pop("cmap_params", None) + + figsize = kwargs.pop("figsize", None) + subplot_kws = dict() + if z is not None and ax is None: + # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa + + subplot_kws.update(projection="3d") + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + # Using 30, 30 minimizes rotation of the plot. Making it easier to + # build on your intuition from 2D plots: + if LooseVersion(plt.matplotlib.__version__) < "3.5.0": + ax.view_init(azim=30, elev=30) + else: + # https://github.com/matplotlib/matplotlib/pull/19873 + ax.view_init(azim=30, elev=30, vertical_axis="y") + else: + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + + _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes) + + add_guide = kwargs.pop("add_guide", None) + if add_legend is not None: + pass + elif add_guide is None or add_guide is True: + add_legend = True if _data["hue_style"] == "discrete" else False + elif add_legend is None: + add_legend = False + + if add_colorbar is not None: + pass + elif add_guide is None or add_guide is True: + add_colorbar = True if _data["hue_style"] == "continuous" else False + else: + add_colorbar = False + + # need to infer size_mapping with full dataset + _data.update( + _infer_scatter_data( + darray, + x, + z, + hue, + _sizes, + size_norm, + size_mapping, + _MARKERSIZE_RANGE, + ) + ) + + cmap_params_subset = {} + if _data["hue"] is not None: + kwargs.update(c=_data["hue"].values.ravel()) + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + scatter, _data["hue"].values, **locals() + ) + + # subset that can be passed to scatter, hist2d + cmap_params_subset = { + vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] + } + + if _data["size"] is not None: + kwargs.update(s=_data["size"].values.ravel()) + + if LooseVersion(plt.matplotlib.__version__) < "3.5.0": + # Plot the data. 3d plots has the z value in upward direction + # instead of y. To make jumping between 2d and 3d easy and intuitive + # switch the order so that z is shown in the depthwise direction: + axis_order = ["x", "z", "y"] + else: + # Switching axis order not needed in 3.5.0, can also simplify the code + # that uses axis_order: + # https://github.com/matplotlib/matplotlib/pull/19873 + axis_order = ["x", "y", "z"] + + primitive = ax.scatter( + *[ + _data[v].values.ravel() + for v in axis_order + if _data.get(v, None) is not None + ], + **cmap_params_subset, + **kwargs, + ) + + # Set x, y, z labels: + i = 0 + set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", None)] + for v in axis_order: + if _data.get(f"{v}label", None) is not None: + set_label[i](_data[f"{v}label"]) + i += 1 + + if add_legend: + + def to_label(data, key, x): + """Map prop values back to its original values.""" + if key in data: + # Use reindex to be less sensitive to float errors. + # Return as numpy array since legend_elements + # seems to require that: + return data[key].reindex(x, method="nearest").to_numpy() + else: + return x + + handles, labels = [], [] + for subtitle, prop, func in [ + ( + _data["hue_label"], + "colors", + functools.partial(to_label, _data, "hue_to_label"), + ), + ( + _data["size_label"], + "sizes", + functools.partial(to_label, _data, "size_to_label"), + ), + ]: + if subtitle: + # Get legend handles and labels that displays the + # values correctly. Order might be different because + # legend_elements uses np.unique instead of pd.unique, + # FacetGrid.add_legend might have troubles with this: + hdl, lbl = legend_elements(primitive, prop, num="auto", func=func) + hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) + handles += hdl + labels += lbl + legend = ax.legend(handles, labels, framealpha=0.5) + _adjust_legend_subtitles(legend) + + if add_colorbar and _data["hue_label"]: + if _data["hue_style"] == "discrete": + raise NotImplementedError("Cannot create a colorbar for non numerics.") + cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = _data["hue_label"] + _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + + return primitive + + # MUST run before any 2d plotting functions are defined since # _plot2d decorator adds them as methods here. class _PlotMethods: @@ -468,6 +882,10 @@ def line(self, *args, **kwargs): def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) + @functools.wraps(scatter) + def _scatter(self, *args, **kwargs): + return scatter(self._da, *args, **kwargs) + def override_signature(f): def wrapper(func): diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index db85a5908c0..1643a295d3f 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -881,3 +881,231 @@ def _get_nice_quiver_magnitude(u, v): mean = np.mean(np.hypot(u.values, v.values)) magnitude = ticker.tick_values(0, mean)[-2] return magnitude + + +# Copied from matplotlib, tweaked so func can return strings. +# https://github.com/matplotlib/matplotlib/issues/19555 +def legend_elements( + self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs +): + """ + Create legend handles and labels for a PathCollection. + + Each legend handle is a `.Line2D` representing the Path that was drawn, + and each label is a string what each Path represents. + + This is useful for obtaining a legend for a `~.Axes.scatter` plot; + e.g.:: + + scatter = plt.scatter([1, 2, 3], [4, 5, 6], c=[7, 2, 3]) + plt.legend(*scatter.legend_elements()) + + creates three legend elements, one for each color with the numerical + values passed to *c* as the labels. + + Also see the :ref:`automatedlegendcreation` example. + + + Parameters + ---------- + prop : {"colors", "sizes"}, default: "colors" + If "colors", the legend handles will show the different colors of + the collection. If "sizes", the legend will show the different + sizes. To set both, use *kwargs* to directly edit the `.Line2D` + properties. + num : int, None, "auto" (default), array-like, or `~.ticker.Locator` + Target number of elements to create. + If None, use all unique elements of the mappable array. If an + integer, target to use *num* elements in the normed range. + If *"auto"*, try to determine which option better suits the nature + of the data. + The number of created elements may slightly deviate from *num* due + to a `~.ticker.Locator` being used to find useful locations. + If a list or array, use exactly those elements for the legend. + Finally, a `~.ticker.Locator` can be provided. + fmt : str, `~matplotlib.ticker.Formatter`, or None (default) + The format or formatter to use for the labels. If a string must be + a valid input for a `~.StrMethodFormatter`. If None (the default), + use a `~.ScalarFormatter`. + func : function, default: ``lambda x: x`` + Function to calculate the labels. Often the size (or color) + argument to `~.Axes.scatter` will have been pre-processed by the + user using a function ``s = f(x)`` to make the markers visible; + e.g. ``size = np.log10(x)``. Providing the inverse of this + function here allows that pre-processing to be inverted, so that + the legend labels have the correct values; e.g. ``func = lambda + x: 10**x``. + **kwargs + Allowed keyword arguments are *color* and *size*. E.g. it may be + useful to set the color of the markers if *prop="sizes"* is used; + similarly to set the size of the markers if *prop="colors"* is + used. Any further parameters are passed onto the `.Line2D` + instance. This may be useful to e.g. specify a different + *markeredgecolor* or *alpha* for the legend handles. + + Returns + ------- + handles : list of `.Line2D` + Visual representation of each element of the legend. + labels : list of str + The string labels for elements of the legend. + """ + import warnings + + import matplotlib as mpl + + mlines = mpl.lines + + handles = [] + labels = [] + + if prop == "colors": + arr = self.get_array() + if arr is None: + warnings.warn( + "Collection without array used. Make sure to " + "specify the values to be colormapped via the " + "`c` argument." + ) + return handles, labels + _size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) + + def _get_color_and_size(value): + return self.cmap(self.norm(value)), _size + + elif prop == "sizes": + arr = self.get_sizes() + _color = kwargs.pop("color", "k") + + def _get_color_and_size(value): + return _color, np.sqrt(value) + + else: + raise ValueError( + "Valid values for `prop` are 'colors' or " + f"'sizes'. You supplied '{prop}' instead." + ) + + # Get the unique values and their labels: + values = np.unique(arr) + label_values = np.asarray(func(values)) + label_values_are_numeric = np.issubdtype(label_values.dtype, np.number) + + # Handle the label format: + if fmt is None and label_values_are_numeric: + fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) + elif fmt is None and not label_values_are_numeric: + fmt = mpl.ticker.StrMethodFormatter("{x}") + elif isinstance(fmt, str): + fmt = mpl.ticker.StrMethodFormatter(fmt) + fmt.create_dummy_axis() + + if num == "auto": + num = 9 + if len(values) <= num: + num = None + + if label_values_are_numeric: + label_values_min = label_values.min() + label_values_max = label_values.max() + fmt.set_bounds(label_values_min, label_values_max) + + if num is not None: + # Labels are numerical but larger than the target + # number of elements, reduce to target using matplotlibs + # ticker classes: + if isinstance(num, mpl.ticker.Locator): + loc = num + elif np.iterable(num): + loc = mpl.ticker.FixedLocator(num) + else: + num = int(num) + loc = mpl.ticker.MaxNLocator( + nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] + ) + + # Get nicely spaced label_values: + label_values = loc.tick_values(label_values_min, label_values_max) + + # Remove extrapolated label_values: + cond = (label_values >= label_values_min) & ( + label_values <= label_values_max + ) + label_values = label_values[cond] + + # Get the corresponding values by creating a linear interpolant + # with small step size: + values_interp = np.linspace(values.min(), values.max(), 256) + label_values_interp = func(values_interp) + ix = np.argsort(label_values_interp) + values = np.interp(label_values, label_values_interp[ix], values_interp[ix]) + elif num is not None and not label_values_are_numeric: + # Labels are not numerical so modifying label_values is not + # possible, instead filter the array with nicely distributed + # indexes: + if type(num) == int: + loc = mpl.ticker.LinearLocator(num) + else: + raise ValueError("`num` only supports integers for non-numeric labels.") + + ind = loc.tick_values(0, len(label_values) - 1).astype(int) + label_values = label_values[ind] + values = values[ind] + + # Some formatters requires set_locs: + if hasattr(fmt, "set_locs"): + fmt.set_locs(label_values) + + # Default settings for handles, add or override with kwargs: + kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) + kw.update(kwargs) + + for val, lab in zip(values, label_values): + color, size = _get_color_and_size(val) + h = mlines.Line2D( + [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw + ) + handles.append(h) + labels.append(fmt(lab)) + + return handles, labels + + +def _legend_add_subtitle(handles, labels, text, func): + """Add a subtitle to legend handles.""" + if text and len(handles) > 1: + # Create a blank handle that's not visible, the + # invisibillity will be used to discern which are subtitles + # or not: + blank_handle = func([], [], label=text) + blank_handle.set_visible(False) + + # Subtitles are shown first: + handles = [blank_handle] + handles + labels = [text] + labels + + return handles, labels + + +def _adjust_legend_subtitles(legend): + """Make invisible-handle "subtitles" entries look more like titles.""" + plt = import_matplotlib_pyplot() + + # Legend title not in rcParams until 3.0 + font_size = plt.rcParams.get("legend.title_fontsize", None) + hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children() + for hpack in hpackers: + draw_area, text_area = hpack.get_children() + handles = draw_area.get_children() + + # Assume that all artists that are not visible are + # subtitles: + if not all(artist.get_visible() for artist in handles): + # Remove the dummy marker which will bring the text + # more to the center: + draw_area.set_width(0) + for text in text_area.get_children(): + if font_size is not None: + # The sutbtitles should have the same font size + # as normal legend titles: + text.set_size(font_size) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index e833654138a..c7f363bbab2 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2808,3 +2808,41 @@ def test_maybe_gca(): assert existing_axes == ax # kwargs are ignored when reusing axes assert ax.get_aspect() == "auto" + + +@requires_matplotlib +@pytest.mark.parametrize( + "x, y, z, hue, markersize, row, col, add_legend, add_colorbar", + [ + ("A", "B", None, None, None, None, None, None, None), + ("B", "A", None, "w", None, None, None, True, None), + ("A", "B", None, "y", "x", None, None, True, True), + ("A", "B", "z", None, None, None, None, None, None), + ("B", "A", "z", "w", None, None, None, True, None), + ("A", "B", "z", "y", "x", None, None, True, True), + ("A", "B", "z", "y", "x", "w", None, True, True), + ], +) +def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_colorbar): + """Test datarray scatter. Merge with TestPlot1D eventually.""" + ds = xr.tutorial.scatter_example_dataset() + + extra_coords = [v for v in [x, hue, markersize] if v is not None] + + # Base coords: + coords = dict(ds.coords) + + # Add extra coords to the DataArray: + coords.update({v: ds[v] for v in extra_coords}) + + darray = xr.DataArray(ds[y], coords=coords) + + with figure_context(): + darray.plot._scatter( + x=x, + z=z, + hue=hue, + markersize=markersize, + add_legend=add_legend, + add_colorbar=add_colorbar, + )