11import os
2- from typing import Any , Callable , List , Optional , Tuple
2+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
33
44from PIL import Image
55
@@ -38,7 +38,7 @@ def __init__(
3838 transform : Optional [Callable ] = None ,
3939 target_transform : Optional [Callable ] = None ,
4040 download : bool = False ,
41- ):
41+ ) -> None :
4242 super ().__init__ (os .path .join (root , self .base_folder ), transform = transform , target_transform = target_transform )
4343
4444 self .image_set = verify_str_arg (image_set .lower (), "image_set" , self .file_dict .keys ())
@@ -62,7 +62,7 @@ def _loader(self, path: str) -> Image.Image:
6262 img = Image .open (f )
6363 return img .convert ("RGB" )
6464
65- def _check_integrity (self ):
65+ def _check_integrity (self ) -> bool :
6666 st1 = check_integrity (os .path .join (self .root , self .filename ), self .md5 )
6767 st2 = check_integrity (os .path .join (self .root , self .labels_file ), self .checksums [self .labels_file ])
6868 if not st1 or not st2 :
@@ -71,7 +71,7 @@ def _check_integrity(self):
7171 return check_integrity (os .path .join (self .root , self .names ), self .checksums [self .names ])
7272 return True
7373
74- def download (self ):
74+ def download (self ) -> None :
7575 if self ._check_integrity ():
7676 print ("Files already downloaded and verified" )
7777 return
@@ -81,13 +81,13 @@ def download(self):
8181 if self .view == "people" :
8282 download_url (f"{ self .download_url_prefix } { self .names } " , self .root )
8383
84- def _get_path (self , identity , no ) :
84+ def _get_path (self , identity : str , no : Union [ int , str ]) -> str :
8585 return os .path .join (self .images_dir , identity , f"{ identity } _{ int (no ):04d} .jpg" )
8686
8787 def extra_repr (self ) -> str :
8888 return f"Alignment: { self .image_set } \n Split: { self .split } "
8989
90- def __len__ (self ):
90+ def __len__ (self ) -> int :
9191 return len (self .data )
9292
9393
@@ -119,13 +119,13 @@ def __init__(
119119 transform : Optional [Callable ] = None ,
120120 target_transform : Optional [Callable ] = None ,
121121 download : bool = False ,
122- ):
122+ ) -> None :
123123 super ().__init__ (root , split , image_set , "people" , transform , target_transform , download )
124124
125125 self .class_to_idx = self ._get_classes ()
126126 self .data , self .targets = self ._get_people ()
127127
128- def _get_people (self ):
128+ def _get_people (self ) -> Tuple [ List [ str ], List [ int ]] :
129129 data , targets = [], []
130130 with open (os .path .join (self .root , self .labels_file )) as f :
131131 lines = f .readlines ()
@@ -143,7 +143,7 @@ def _get_people(self):
143143
144144 return data , targets
145145
146- def _get_classes (self ):
146+ def _get_classes (self ) -> Dict [ str , int ] :
147147 with open (os .path .join (self .root , self .names )) as f :
148148 lines = f .readlines ()
149149 names = [line .strip ().split ()[0 ] for line in lines ]
@@ -201,12 +201,12 @@ def __init__(
201201 transform : Optional [Callable ] = None ,
202202 target_transform : Optional [Callable ] = None ,
203203 download : bool = False ,
204- ):
204+ ) -> None :
205205 super ().__init__ (root , split , image_set , "pairs" , transform , target_transform , download )
206206
207207 self .pair_names , self .data , self .targets = self ._get_pairs (self .images_dir )
208208
209- def _get_pairs (self , images_dir ) :
209+ def _get_pairs (self , images_dir : str ) -> Tuple [ List [ Tuple [ str , str ]], List [ Tuple [ str , str ]], List [ int ]] :
210210 pair_names , data , targets = [], [], []
211211 with open (os .path .join (self .root , self .labels_file )) as f :
212212 lines = f .readlines ()
0 commit comments