@@ -38,7 +38,7 @@
 from .info import DatasetInfo
 from .search import IndexableMixin
 from .splits import NamedSplit
-from .utils import map_all_sequences_to_lists, map_nested
+from .utils import map_nested


 logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ class DatasetInfoMixin(object):
     at the base level of the Dataset for easy access.
     """

-    def __init__(self, info: Optional[DatasetInfo], split: Optional[NamedSplit]):
+    def __init__(self, info: DatasetInfo, split: Optional[NamedSplit]):
         self._info = info
         self._split = split

@@ -92,7 +92,7 @@ def download_size(self) -> Optional[int]:
         return self._info.download_size

     @property
-    def features(self):
+    def features(self) -> Features:
         return self._info.features

     @property
@@ -131,6 +131,7 @@ def __init__(
         info: Optional[DatasetInfo] = None,
         split: Optional[NamedSplit] = None,
     ):
+        info = info.copy() if info is not None else DatasetInfo()
         DatasetInfoMixin.__init__(self, info=info, split=split)
         IndexableMixin.__init__(self)
         self._data: pa.Table = arrow_table
@@ -139,6 +140,15 @@ def __init__(
         self._format_kwargs: dict = {}
         self._format_columns: Optional[list] = None
         self._output_all_columns: bool = False
+        inferred_features = Features.from_arrow_schema(arrow_table.schema)
+        if self.info.features is not None:
+            if self.info.features.type != inferred_features.type:
+                self.info.features = inferred_features
+            else:
+                pass  # keep the original features
+        else:
+            self.info.features = inferred_features
+        assert self.features is not None, "Features can't be None in a Dataset object"

     @classmethod
     def from_file(
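For context, a minimal sketch of what this constructor change buys (the `nlp.Dataset` top-level export is assumed from this era of the library): features are now always populated, inferred from the Arrow schema whenever no `DatasetInfo` supplies them.

```python
import pyarrow as pa
from nlp import Dataset  # top-level export assumed at this commit

# With the new __init__, a Dataset built from a bare Arrow table still
# carries features: they are inferred from the table's schema.
table = pa.Table.from_pydict({"text": ["a", "b"], "label": [0, 1]})
dataset = Dataset(table)
print(dataset.features)  # never None after this change
```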
@@ -177,7 +187,7 @@ def from_pandas(

         Be aware that Series of the `object` dtype don't carry enough information to always lead to a meaningful Arrow type. In the case that
         we cannot infer a type, e.g. because the DataFrame is of length 0 or the Series only contains None/nan objects, the type is set to
-        null. This behavior can be avoided by constructing an explicit schema and passing it to this function.
+        null. This behavior can be avoided by constructing explicit features and passing them to this function.

         Args:
             df (:obj:``pandas.DataFrame``): the dataframe that contains the dataset.
@@ -186,21 +196,16 @@ def from_pandas(
             description, citation, etc.
             split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split.
         """
-        if info is None:
-            info = DatasetInfo()
-        if info.features is None:
-            info.features = features
-        elif info.features != features and features is not None:
+        if info is not None and features is not None and info.features != features:
             raise ValueError(
                 "Features specified in `features` and `info.features` can't be different:\n{}\n{}".format(
                     features, info.features
                 )
             )
+        features = features if features is not None else info.features if info is not None else None
         pa_table: pa.Table = pa.Table.from_pandas(
-            df=df, schema=pa.schema(info.features.type) if info.features is not None else None
+            df=df, schema=pa.schema(features.type) if features is not None else None
         )
-        if info.features is None:
-            info.features = Features.from_arrow_schema(pa_table.schema)
         return cls(pa_table, info=info, split=split)

     @classmethod
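A hedged usage sketch of the reworked `from_pandas` (assuming `Features` and `Value` are exported at the top level, as other names in this diff suggest): passing explicit features sidesteps the null-type inference problem the docstring warns about.

```python
import pandas as pd
from nlp import Dataset, Features, Value  # assumed top-level exports

df = pd.DataFrame({"text": ["hello", "world"], "label": [0, 1]})

# Explicit features avoid null-typed columns when pandas cannot infer a type.
features = Features({"text": Value("string"), "label": Value("int64")})
dataset = Dataset.from_pandas(df, features=features)
```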
@@ -221,21 +226,16 @@ def from_dict(
             description, citation, etc.
             split (:obj:``nlp.NamedSplit``, `optional`, defaults to :obj:``None``): If specified, the name of the dataset split.
         """
-        if info is None:
-            info = DatasetInfo()
-        if info.features is None:
-            info.features = features
-        elif info.features != features and features is not None:
+        if info is not None and features is not None and info.features != features:
             raise ValueError(
                 "Features specified in `features` and `info.features` can't be different:\n{}\n{}".format(
                     features, info.features
                 )
             )
+        features = features if features is not None else info.features if info is not None else None
         pa_table: pa.Table = pa.Table.from_pydict(
-            mapping=mapping, schema=pa.schema(info.features.type) if info.features is not None else None
+            mapping=mapping, schema=pa.schema(features.type) if features is not None else None
         )
-        if info.features is None:
-            info.features = Features.from_arrow_schema(pa_table.schema)
         return cls(pa_table, info=info, split=split)

     @property
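`from_dict` now follows the same precedence: an explicit `features` argument wins, then `info.features`, then inference from the built table in `__init__`. A sketch under the same assumed exports:

```python
from nlp import Dataset, Features, Value  # assumed top-level exports

mapping = {"question": ["why?", "how?"], "score": [0.5, 0.9]}
features = Features({"question": Value("string"), "score": Value("float32")})
dataset = Dataset.from_dict(mapping, features=features)
```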
@@ -277,14 +277,6 @@ def column_names(self) -> List[str]:
         """Names of the columns in the dataset."""
         return self._data.column_names

-    @property
-    def schema(self) -> pa.Schema:
-        """The Arrow schema of the Apache Arrow table backing the dataset.
-        You probably don't need to access directly this and can rather use
-        :func:`nlp.Dataset.features` to inspect the dataset features.
-        """
-        return self._data.schema
-
     @property
     def shape(self):
         """Shape of the dataset (number of columns, number of rows)."""
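Since the public `schema` property is removed, callers migrate to `features`; `Features.type` (used throughout this diff) recovers the equivalent Arrow type when it is still needed. A sketch:

```python
import pyarrow as pa

# Before this commit: dataset.schema  (property removed here)
# After: inspect the features instead.
print(dataset.features)

# If a pyarrow schema is still required, rebuild it from the features,
# mirroring the pa.schema(features.type) calls in this diff:
schema = pa.schema(dataset.features.type)
```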
@@ -340,6 +332,7 @@ def dictionary_encode_column(self, column: str):
         casted_field = pa.field(field.name, pa.dictionary(pa.int32(), field.type), nullable=False)
         casted_schema.set(field_index, casted_field)
         self._data = self._data.cast(casted_schema)
+        self.info.features = Features.from_arrow_schema(self._data.schema)

     def flatten(self, max_depth=16):
         """ Flatten the Table.
@@ -352,7 +345,7 @@ def flatten(self, max_depth=16):
             else:
                 break
         if self.info is not None:
-            self.info.features = Features.from_arrow_schema(self.schema)
+            self.info.features = Features.from_arrow_schema(self._data.schema)
         logger.info(
             "Flattened dataset from depth {} to depth {}.".format(depth, 1 if depth + 1 < max_depth else "unknown")
         )
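The pattern in both hunks above: any operation that mutates the Arrow schema immediately resynchronizes `info.features` from it, so the two cannot drift apart. Illustrated:

```python
dataset.flatten()        # rewrites the underlying Arrow schema in place
print(dataset.features)  # resynchronized from the flattened schema
```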
@@ -380,8 +373,7 @@ def __iter__(self):
         )

     def __repr__(self):
-        schema_str = dict((a, str(b)) for a, b in zip(self._data.schema.names, self._data.schema.types))
-        return f"Dataset(schema: {schema_str}, num_rows: {self.num_rows})"
+        return f"Dataset(features: {self.features}, num_rows: {self.num_rows})"

     @property
     def format(self):
@@ -685,7 +677,7 @@ def map(
         load_from_cache_file: bool = True,
         cache_file_name: Optional[str] = None,
         writer_batch_size: Optional[int] = 1000,
-        arrow_schema: Optional[pa.Schema] = None,
+        features: Optional[Features] = None,
         disable_nullable: bool = True,
         verbose: bool = True,
     ):
@@ -712,7 +704,7 @@
                 results of the computation instead of the automatically generated cache file name.
             `writer_batch_size` (`int`, default: `1000`): Number of rows per write operation for the cache file writer.
                 Higher value gives smaller cache files, lower value consumes less temporary memory while running `.map()`.
-            `arrow_schema` (`Optional[pa.Schema]`, default: `None`): Use a specific Apache Arrow Schema to store the cache file
+            `features` (`Optional[nlp.Features]`, default: `None`): Use a specific `nlp.Features` to store the cache file
                 instead of the automatically generated one.
             `disable_nullable` (`bool`, default: `True`): Disallow null values in the table.
             `verbose` (`bool`, default: `True`): Set to `False` to deactivate the tqdm progress bar and informational messages.
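A sketch of the new `map` call. Here `tokenize_function` is a hypothetical user function, the dataset is assumed to have a single `text` column, and `Sequence` is assumed to be exported like the other feature types:

```python
from nlp import Features, Sequence, Value  # assumed top-level exports

def tokenize_function(batch):  # hypothetical user function
    return {"input_ids": [[ord(c) for c in text] for text in batch["text"]]}

new_features = Features(
    {"text": Value("string"), "input_ids": Sequence(Value("int64"))}
)
# `features` replaces the old `arrow_schema` argument.
dataset = dataset.map(tokenize_function, batched=True, features=new_features)
```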
@@ -792,18 +784,6 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
                 inputs.update(processed_inputs)
             return inputs

-        # Find the output schema if none is given
-        test_inputs = self[:2] if batched else self[0]
-        test_indices = [0, 1] if batched else 0
-        test_output = apply_function_on_filtered_inputs(test_inputs, test_indices)
-        if arrow_schema is None and update_data:
-            if not batched:
-                test_output = self._nest(test_output)
-            test_output = map_all_sequences_to_lists(test_output)
-            arrow_schema = pa.Table.from_pydict(test_output).schema
-            if disable_nullable:
-                arrow_schema = pa.schema(pa.field(field.name, field.type, nullable=False) for field in arrow_schema)
-
         # Check if we've already cached this computation (indexed by a hash)
         if self._data_files and update_data:
             if cache_file_name is None:
@@ -817,7 +797,7 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
                 "load_from_cache_file": load_from_cache_file,
                 "cache_file_name": cache_file_name,
                 "writer_batch_size": writer_batch_size,
-                "arrow_schema": arrow_schema,
+                "features": features,
                 "disable_nullable": disable_nullable,
             }
             cache_file_name = self._get_cache_file_path(function, cache_kwargs)
@@ -830,12 +810,12 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
         if update_data:
             if keep_in_memory or not self._data_files:
                 buf_writer = pa.BufferOutputStream()
-                writer = ArrowWriter(schema=arrow_schema, stream=buf_writer, writer_batch_size=writer_batch_size)
+                writer = ArrowWriter(features=features, stream=buf_writer, writer_batch_size=writer_batch_size)
             else:
                 buf_writer = None
                 if verbose:
                     logger.info("Caching processed dataset at %s", cache_file_name)
-                writer = ArrowWriter(schema=arrow_schema, path=cache_file_name, writer_batch_size=writer_batch_size)
+                writer = ArrowWriter(features=features, path=cache_file_name, writer_batch_size=writer_batch_size)

         # Loop over single examples or batches and write to buffer/file if examples are to be updated
         if not batched:
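The writer-side counterpart: `ArrowWriter` now takes `features` rather than a prebuilt `pa.Schema`, which is why the up-front schema probing removed above is no longer needed; when `features` is `None`, the writer presumably derives a schema from the data it is given. A sketch (module path assumed from this era of the library):

```python
from nlp.arrow_writer import ArrowWriter  # module path assumed

# The Arrow schema is derived from the features at write time instead of
# being precomputed inside Dataset.map.
writer = ArrowWriter(features=new_features, path="cache.arrow", writer_batch_size=1000)
```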
@@ -928,15 +908,8 @@ def map_function(batch, *args):

             return result

-        # to avoid errors with the arrow_schema we define it here
-        test_inputs = self[:2]
-        if "remove_columns" in kwargs:
-            test_inputs = {key: test_inputs[key] for key in (test_inputs.keys() - kwargs["remove_columns"])}
-        test_inputs = map_all_sequences_to_lists(test_inputs)
-        arrow_schema = pa.Table.from_pydict(test_inputs).schema
-
         # return map function
         return self.map(map_function, batched=True, with_indices=with_indices, features=self.features, **kwargs)

     def select(
         self,
@@ -991,12 +964,12 @@ def select(
         # Prepare output buffer and batched writer in memory or on file if we update the table
         if keep_in_memory or not self._data_files:
             buf_writer = pa.BufferOutputStream()
-            writer = ArrowWriter(schema=self.schema, stream=buf_writer, writer_batch_size=writer_batch_size)
+            writer = ArrowWriter(features=self.features, stream=buf_writer, writer_batch_size=writer_batch_size)
         else:
             buf_writer = None
             if verbose:
                 logger.info("Caching processed dataset at %s", cache_file_name)
-            writer = ArrowWriter(schema=self.schema, path=cache_file_name, writer_batch_size=writer_batch_size)
+            writer = ArrowWriter(features=self.features, path=cache_file_name, writer_batch_size=writer_batch_size)

         # Loop over batches and write to buffer/file if examples are to be updated
         for i in tqdm(range(0, len(indices), reader_batch_size), disable=not verbose):