torchvision/prototype/datapoints/_dataset_wrapper.py (28 changes: 17 additions & 11 deletions)

@@ -74,7 +74,7 @@ def __getitem__(self, idx):
         # of this class
         sample = self._dataset[idx]
 
-        sample = self._wrapper(sample)
+        sample = self._wrapper(idx, sample)
 
         # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
         # or joint (`transforms`), we can access the full functionality through `transforms`
@@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
 
 
 def classification_wrapper_factory(dataset):
-    return identity
+    def wrapper(idx, sample):
+        return sample
+
+    return wrapper
 
 
 for dataset_cls in [
@@ -143,7 +146,7 @@ def classification_wrapper_factory(dataset):
 
 
 def segmentation_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)
 
@@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
             f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
         )
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         video, audio, label = sample
 
         video = datapoints.Video(video)
@@ -201,9 +204,12 @@ def segmentation_to_mask(segmentation, *, spatial_size):
         )
         return torch.from_numpy(mask.decode(segmentation))
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
+        if not target:
+            return image, dict(image_id=dataset.ids[idx])
+
         batched_target = list_of_dicts_to_dict_of_lists(target)
 
         image_ids = batched_target.pop("image_id")
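
Note on the hunk above: `CocoDetection` samples whose annotation list is empty now short-circuit to a minimal target that still carries the `image_id`, recovered through the new `idx` argument via `dataset.ids[idx]`. A sketch of the resulting behavior, assuming `ds` is a `CocoDetection` dataset wrapped by this factory and `i` indexes an image with no annotations:

    # Sketch only: `ds` is a wrapped CocoDetection dataset (assumption),
    # `i` is the index of an image that has no annotations.
    image, target = ds[i]
    # Early return above: only the image id, no boxes/masks/labels keys.
    assert target == {"image_id": ds.ids[i]}
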
@@ -259,7 +265,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
 def voc_detection_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
@@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         target = wrap_target_by_type(
@@ -318,7 +324,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.Kitti)
 def kitti_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
@@ -336,7 +342,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
 def oxford_iiit_pet_wrapper_factor(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
@@ -371,7 +377,7 @@ def instance_segmentation_wrapper(mask):
             labels.append(label)
         return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
 
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         target = wrap_target_by_type(
@@ -390,7 +396,7 @@ def wrapper(sample):
 
 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
 def widerface_wrapper(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample
 
         if target is not None:
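
For reference, a minimal usage sketch of the updated protocol: the wrapper class now calls `self._wrapper(idx, sample)`, so every factory-produced wrapper receives the sample index alongside the sample. The sketch below assumes the prototype entry point `wrap_dataset_for_transforms_v2` is exported from `torchvision.prototype.datapoints` (it is defined in this module); the dataset paths are placeholders:

    # Minimal sketch, not the canonical API surface: assumes
    # wrap_dataset_for_transforms_v2 is importable from torchvision.prototype.datapoints.
    from torchvision import datasets
    from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2

    coco = datasets.CocoDetection("path/to/images", "path/to/instances.json")
    wrapped = wrap_dataset_for_transforms_v2(coco)

    image, target = wrapped[0]  # internally: wrapper(0, sample)
    # An image without annotations now yields only its id instead of failing
    # on the empty annotation list:
    # target == {"image_id": coco.ids[0]}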