Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,10 @@ def __init__(
backend="pil",
):
trans = []

backend = backend.lower()
if backend == "tensor":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably this should change to

Suggested change
if backend == "tensor":
if backend == "tensor" or backend == "pil":

trans.append(transforms.PILToTensor())
else:
elif backend != "pil":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably still be buggy when backend="pil"

As it will not apply PILtoTensor() transform.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we apply PILtoTensor when backend="pil"? I was just guided by the way it is done in the ClassificationPresetTrain class

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably do this in Line 83. I missed this. But it's confusing why would one apply PILToTensor() when backend type is Tensor(). Maybe I'm missing something again.

Copy link
Contributor Author

@AetelFinch AetelFinch Jun 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I see, we apply PILtoTensor at the beginning, when backend=="tensor", to do transformations over tensors, not over PIL images.

trans += [
    transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
    transforms.CenterCrop(crop_size),
]

As written in the documentation for transforms.Resize:
"The output image might be different depending on its type: when downsampling, the interpolation of PIL images and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences in the performance of a network."

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it's confusing why would one apply PILToTensor() when backend type is Tensor().

Yeah it's not super clean nor obvious just reading the code. The key piece of information is that those presets make the hard assumption that whatever you pass as input is a PIL image, no matter the backend!

So when we pass backend="tensor" we actually need to first convert the input (a PIL image) to a tensor.

A more complete solution would be to do all those checks "at runtime" in forward() instead of here.

raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

trans += [
Expand Down