@@ -299,24 +299,17 @@ def _get_output_signature(
299299 f"Unrecognized array dtype { np_arrays [0 ].dtype } . \n "
300300 "Nested types and image/audio types are not supported yet."
301301 )
302- if (
303- column in dataset
304- and isinstance (dataset .features [column ], Sequence )
305- and dataset .features [column ].length != - 1
306- ):
307- static_shape = [batch_size , dataset .features [column ].length ]
308- else :
309- shapes = [array .shape for array in np_arrays ]
310- static_shape = []
311- for dim in range (len (shapes [0 ])):
312- sizes = set ([shape [dim ] for shape in shapes ])
313- if dim == 0 :
314- static_shape .append (batch_size )
315- continue
316- if len (sizes ) == 1 : # This dimension looks constant
317- static_shape .append (sizes .pop ())
318- else : # Use None for variable dimensions
319- static_shape .append (None )
302+ shapes = [array .shape for array in np_arrays ]
303+ static_shape = []
304+ for dim in range (len (shapes [0 ])):
305+ sizes = set ([shape [dim ] for shape in shapes ])
306+ if dim == 0 :
307+ static_shape .append (batch_size )
308+ continue
309+ if len (sizes ) == 1 : # This dimension looks constant
310+ static_shape .append (sizes .pop ())
311+ else : # Use None for variable dimensions
312+ static_shape .append (None )
320313 tf_columns_to_signatures [column ] = tf .TensorSpec (shape = static_shape , dtype = tf_dtype )
321314 np_columns_to_dtypes [column ] = np_dtype
322315
0 commit comments