diff --git a/geoutils/raster/raster.py b/geoutils/raster/raster.py index e255a8ac..ef0ce767 100644 --- a/geoutils/raster/raster.py +++ b/geoutils/raster/raster.py @@ -2161,8 +2161,15 @@ def plot( # Create colorbar # Use rcParam default - if cmap is None: + if cmap is None and isinstance(bands, int): + # ONLY set a cmap arg for single band images cmap = plt.get_cmap(plt.rcParams["image.cmap"]) + elif cmap is None and isinstance(bands, tuple): + # Leave cmap as None for multi-band image, because if a cmap + # is passed then imshow treats this as an instruction to apply scalar + # mapping, which is not a desirable behaviour (it can result in color-casted + # RGB images for example). + pass elif isinstance(cmap, str): cmap = plt.get_cmap(cmap) elif isinstance(cmap, matplotlib.colors.Colormap):