Skip to content

Commit ab4d41e

Browse files
authored
Merge branch 'main' into v2-consistency
2 parents 79b1f49 + 47cd5ea commit ab4d41e

File tree

4 files changed

+73
-90
lines changed

4 files changed

+73
-90
lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self, src_dir):
8484

8585
transforms_subsection_order = [
8686
"plot_transforms_getting_started.py",
87+
"plot_transforms_illustrations.py",
8788
"plot_transforms_e2e.py",
8889
"plot_cutmix_mixup.py",
8990
"plot_custom_transforms.py",

gallery/others/plot_scripted_tensor_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def show(imgs):
6262
# --------------------------
6363
# Most transforms natively support tensors on top of PIL images (to visualize
6464
# the effect of the transforms, you may refer to see
65-
# :ref:`sphx_glr_auto_examples_others_plot_transforms.py`).
65+
# :ref:`sphx_glr_auto_examples_transforms_plot_transforms_illustrations.py`).
6666
# Using tensor images, we can run the transforms on GPUs if cuda is available!
6767

6868
import torch.nn as nn

gallery/transforms/helpers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchvision.transforms.v2 import functional as F
66

77

8-
def plot(imgs):
8+
def plot(imgs, row_title=None, **imshow_kwargs):
99
if not isinstance(imgs[0], list):
1010
# Make a 2d grid even if there's just 1 row
1111
imgs = [imgs]
@@ -40,7 +40,11 @@ def plot(imgs):
4040
img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)
4141

4242
ax = axs[row_idx, col_idx]
43-
ax.imshow(img.permute(1, 2, 0).numpy())
43+
ax.imshow(img.permute(1, 2, 0).numpy(), **imshow_kwargs)
4444
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
4545

46+
if row_title is not None:
47+
for row_idx in range(num_rows):
48+
axs[row_idx, 0].set(ylabel=row_title[row_idx])
49+
4650
plt.tight_layout()

0 commit comments

Comments
 (0)