diff --git a/references/classification/presets.py b/references/classification/presets.py index a710f92ae88..0f2c914be7e 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -69,11 +69,10 @@ def __init__( backend="pil", ): trans = [] - backend = backend.lower() if backend == "tensor": trans.append(transforms.PILToTensor()) - else: + elif backend != "pil": raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") trans += [