@@ -93,37 +93,34 @@ def write(self) -> int:
9393 _ = self .to_json_kwargs .pop ("path_or_buf" , None )
9494 orient = self .to_json_kwargs .pop ("orient" , "records" )
9595 lines = self .to_json_kwargs .pop ("lines" , True if orient == "records" else False )
96- index = self .to_json_kwargs .pop ("index" , False if orient in ["split" , "table" ] else True )
96+ if "index" not in self .to_json_kwargs and orient in ["split" , "table" ]:
97+ self .to_json_kwargs ["index" ] = False
9798 compression = self .to_json_kwargs .pop ("compression" , None )
9899
99100 if compression not in [None , "infer" , "gzip" , "bz2" , "xz" ]:
100101 raise NotImplementedError (f"`datasets` currently does not support { compression } compression" )
101102
102103 if isinstance (self .path_or_buf , (str , bytes , os .PathLike )):
103104 with fsspec .open (self .path_or_buf , "wb" , compression = compression ) as buffer :
104- written = self ._write (file_obj = buffer , orient = orient , lines = lines , index = index , ** self .to_json_kwargs )
105+ written = self ._write (file_obj = buffer , orient = orient , lines = lines , ** self .to_json_kwargs )
105106 else :
106107 if compression :
107108 raise NotImplementedError (
108109 f"The compression parameter is not supported when writing to a buffer, but compression={ compression } "
109110 " was passed. Please provide a local path instead."
110111 )
111- written = self ._write (
112- file_obj = self .path_or_buf , orient = orient , lines = lines , index = index , ** self .to_json_kwargs
113- )
112+ written = self ._write (file_obj = self .path_or_buf , orient = orient , lines = lines , ** self .to_json_kwargs )
114113 return written
115114
def _batch_json(self, args):
    """Serialize one batch of the dataset to encoded JSON bytes.

    The arguments arrive packed in a single tuple so the method can be
    handed directly to ``multiprocessing.Pool.imap``, which passes exactly
    one positional argument per work item.

    Args:
        args: A 4-tuple ``(offset, orient, lines, to_json_kwargs)`` where
            ``offset`` is the index of the first row of the batch,
            ``orient`` and ``lines`` are forwarded to
            ``DataFrame.to_json``, and ``to_json_kwargs`` is a dict of any
            further keyword arguments for ``DataFrame.to_json``. ``index``
            is not passed separately; if set, it travels in this dict.

    Returns:
        The JSON text for rows ``offset`` to ``offset + self.batch_size``,
        terminated by a newline and encoded with ``self.encoding``.
    """
    offset, orient, lines, to_json_kwargs = args

    batch = query_table(
        table=self.dataset.data,
        key=slice(offset, offset + self.batch_size),
        indices=self.dataset._indices,
    )
    json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
    # Each batch is written back-to-back into the same file object, so every
    # chunk must end with exactly one newline to keep batches separated.
    if not json_str.endswith("\n"):
        json_str += "\n"
    return json_str.encode(self.encoding)
@@ -133,7 +130,6 @@ def _write(
133130 file_obj : BinaryIO ,
134131 orient ,
135132 lines ,
136- index ,
137133 ** to_json_kwargs ,
138134 ) -> int :
139135 """Writes the pyarrow table as JSON lines to a binary file handle.
@@ -149,15 +145,15 @@ def _write(
149145 disable = not logging .is_progress_bar_enabled (),
150146 desc = "Creating json from Arrow format" ,
151147 ):
152- json_str = self ._batch_json ((offset , orient , lines , index , to_json_kwargs ))
148+ json_str = self ._batch_json ((offset , orient , lines , to_json_kwargs ))
153149 written += file_obj .write (json_str )
154150 else :
155151 num_rows , batch_size = len (self .dataset ), self .batch_size
156152 with multiprocessing .Pool (self .num_proc ) as pool :
157153 for json_str in logging .tqdm (
158154 pool .imap (
159155 self ._batch_json ,
160- [(offset , orient , lines , index , to_json_kwargs ) for offset in range (0 , num_rows , batch_size )],
156+ [(offset , orient , lines , to_json_kwargs ) for offset in range (0 , num_rows , batch_size )],
161157 ),
162158 total = (num_rows // batch_size ) + 1 if num_rows % batch_size else num_rows // batch_size ,
163159 unit = "ba" ,
0 commit comments