@@ -662,27 +662,39 @@ class VideoDatasetTestCase(DatasetTestCase):
662662 FEATURE_TYPES = (torch .Tensor , torch .Tensor , int )
663663 REQUIRED_PACKAGES = ("av" ,)
664664
665- DEFAULT_FRAMES_PER_CLIP = 1
665+ FRAMES_PER_CLIP = 1
666666
667667 def __init__ (self , * args , ** kwargs ):
668668 super ().__init__ (* args , ** kwargs )
669669 self .dataset_args = self ._set_default_frames_per_clip (self .dataset_args )
670670
671- def _set_default_frames_per_clip (self , inject_fake_data ):
671+ def _set_default_frames_per_clip (self , dataset_args ):
672672 argspec = inspect .getfullargspec (self .DATASET_CLASS .__init__ )
673673 args_without_default = argspec .args [1 : (- len (argspec .defaults ) if argspec .defaults else None )]
674674 frames_per_clip_last = args_without_default [- 1 ] == "frames_per_clip"
675675
676- @functools .wraps (inject_fake_data )
676+ @functools .wraps (dataset_args )
677677 def wrapper (tmpdir , config ):
678- args = inject_fake_data (tmpdir , config )
678+ args = dataset_args (tmpdir , config )
679679 if frames_per_clip_last and len (args ) == len (args_without_default ) - 1 :
680- args = (* args , self .DEFAULT_FRAMES_PER_CLIP )
680+ args = (* args , self .FRAMES_PER_CLIP )
681681
682682 return args
683683
684684 return wrapper
685685
686+ def test_output_format (self ):
687+ for output_format in ["TCHW" , "THWC" ]:
688+ with self .create_dataset (output_format = output_format ) as (dataset , _ ):
689+ for video , * _ in dataset :
690+ if output_format == "TCHW" :
691+ num_frames , num_channels , * _ = video .shape
692+ else : # output_format == "THWC":
693+ num_frames , * _ , num_channels = video .shape
694+
695+ assert num_frames == self .FRAMES_PER_CLIP
696+ assert num_channels == 3
697+
686698 @test_all_configs
687699 def test_transforms_v2_wrapper (self , config ):
688700 # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
0 commit comments