@@ -23,7 +23,10 @@ use std::fmt::{self, Debug};
2323use std:: sync:: Arc ;
2424
2525use super :: write:: orchestration:: stateless_multipart_put;
26- use super :: { FileFormat , FileFormatFactory , DEFAULT_SCHEMA_INFER_MAX_RECORD } ;
26+ use super :: {
27+ Decoder , DecoderDeserializer , FileFormat , FileFormatFactory ,
28+ DEFAULT_SCHEMA_INFER_MAX_RECORD ,
29+ } ;
2730use crate :: datasource:: file_format:: file_compression_type:: FileCompressionType ;
2831use crate :: datasource:: file_format:: write:: BatchSerializer ;
2932use crate :: datasource:: physical_plan:: {
@@ -38,8 +41,8 @@ use crate::physical_plan::{
3841
3942use arrow:: array:: RecordBatch ;
4043use arrow:: csv:: WriterBuilder ;
41- use arrow:: datatypes:: SchemaRef ;
42- use arrow :: datatypes :: { DataType , Field , Fields , Schema } ;
44+ use arrow:: datatypes:: { DataType , Field , Fields , Schema , SchemaRef } ;
45+ use arrow_schema :: ArrowError ;
4346use datafusion_common:: config:: { ConfigField , ConfigFileType , CsvOptions } ;
4447use datafusion_common:: file_options:: csv_writer:: CsvWriterOptions ;
4548use datafusion_common:: {
@@ -293,6 +296,45 @@ impl CsvFormat {
293296 }
294297}
295298
299+ #[ derive( Debug ) ]
300+ pub ( crate ) struct CsvDecoder {
301+ inner : arrow:: csv:: reader:: Decoder ,
302+ }
303+
304+ impl CsvDecoder {
305+ pub ( crate ) fn new ( decoder : arrow:: csv:: reader:: Decoder ) -> Self {
306+ Self { inner : decoder }
307+ }
308+ }
309+
310+ impl Decoder for CsvDecoder {
311+ fn decode ( & mut self , buf : & [ u8 ] ) -> Result < usize , ArrowError > {
312+ self . inner . decode ( buf)
313+ }
314+
315+ fn flush ( & mut self ) -> Result < Option < RecordBatch > , ArrowError > {
316+ self . inner . flush ( )
317+ }
318+
319+ fn can_flush_early ( & self ) -> bool {
320+ self . inner . capacity ( ) == 0
321+ }
322+ }
323+
324+ impl Debug for CsvSerializer {
325+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
326+ f. debug_struct ( "CsvSerializer" )
327+ . field ( "header" , & self . header )
328+ . finish ( )
329+ }
330+ }
331+
332+ impl From < arrow:: csv:: reader:: Decoder > for DecoderDeserializer < CsvDecoder > {
333+ fn from ( decoder : arrow:: csv:: reader:: Decoder ) -> Self {
334+ DecoderDeserializer :: new ( CsvDecoder :: new ( decoder) )
335+ }
336+ }
337+
296338#[ async_trait]
297339impl FileFormat for CsvFormat {
298340 fn as_any ( & self ) -> & dyn Any {
@@ -692,23 +734,28 @@ impl DataSink for CsvSink {
692734mod tests {
693735 use super :: super :: test_util:: scan_format;
694736 use super :: * ;
695- use crate :: arrow:: util:: pretty;
696737 use crate :: assert_batches_eq;
697738 use crate :: datasource:: file_format:: file_compression_type:: FileCompressionType ;
698739 use crate :: datasource:: file_format:: test_util:: VariableStream ;
740+ use crate :: datasource:: file_format:: {
741+ BatchDeserializer , DecoderDeserializer , DeserializerOutput ,
742+ } ;
699743 use crate :: datasource:: listing:: ListingOptions ;
744+ use crate :: execution:: session_state:: SessionStateBuilder ;
700745 use crate :: physical_plan:: collect;
701746 use crate :: prelude:: { CsvReadOptions , SessionConfig , SessionContext } ;
702747 use crate :: test_util:: arrow_test_data;
703748
704749 use arrow:: compute:: concat_batches;
750+ use arrow:: csv:: ReaderBuilder ;
751+ use arrow:: util:: pretty:: pretty_format_batches;
752+ use arrow_array:: { BooleanArray , Float64Array , Int32Array , StringArray } ;
705753 use datafusion_common:: cast:: as_string_array;
706754 use datafusion_common:: internal_err;
707755 use datafusion_common:: stats:: Precision ;
708756 use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
709757 use datafusion_expr:: { col, lit} ;
710758
711- use crate :: execution:: session_state:: SessionStateBuilder ;
712759 use chrono:: DateTime ;
713760 use object_store:: local:: LocalFileSystem ;
714761 use object_store:: path:: Path ;
@@ -1097,7 +1144,7 @@ mod tests {
10971144 ) -> Result < usize > {
10981145 let df = ctx. sql ( & format ! ( "EXPLAIN {sql}" ) ) . await ?;
10991146 let result = df. collect ( ) . await ?;
1100- let plan = format ! ( "{}" , & pretty :: pretty_format_batches( & result) ?) ;
1147+ let plan = format ! ( "{}" , & pretty_format_batches( & result) ?) ;
11011148
11021149 let re = Regex :: new ( r"CsvExec: file_groups=\{(\d+) group" ) . unwrap ( ) ;
11031150
@@ -1464,4 +1511,180 @@ mod tests {
14641511
14651512 Ok ( ( ) )
14661513 }
1514+
1515+ #[ rstest]
1516+ fn test_csv_deserializer_with_finish (
1517+ #[ values( 1 , 5 , 17 ) ] batch_size : usize ,
1518+ #[ values( 0 , 5 , 93 ) ] line_count : usize ,
1519+ ) -> Result < ( ) > {
1520+ let schema = csv_schema ( ) ;
1521+ let generator = CsvBatchGenerator :: new ( batch_size, line_count) ;
1522+ let mut deserializer = csv_deserializer ( batch_size, & schema) ;
1523+
1524+ for data in generator {
1525+ deserializer. digest ( data) ;
1526+ }
1527+ deserializer. finish ( ) ;
1528+
1529+ let batch_count = line_count. div_ceil ( batch_size) ;
1530+
1531+ let mut all_batches = RecordBatch :: new_empty ( schema. clone ( ) ) ;
1532+ for _ in 0 ..batch_count {
1533+ let output = deserializer. next ( ) ?;
1534+ let DeserializerOutput :: RecordBatch ( batch) = output else {
1535+ panic ! ( "Expected RecordBatch, got {:?}" , output) ;
1536+ } ;
1537+ all_batches = concat_batches ( & schema, & [ all_batches, batch] ) ?;
1538+ }
1539+ assert_eq ! ( deserializer. next( ) ?, DeserializerOutput :: InputExhausted ) ;
1540+
1541+ let expected = csv_expected_batch ( schema, line_count) ?;
1542+
1543+ assert_eq ! (
1544+ expected. clone( ) ,
1545+ all_batches. clone( ) ,
1546+ "Expected:\n {}\n Actual:\n {}" ,
1547+ pretty_format_batches( & [ expected] ) ?,
1548+ pretty_format_batches( & [ all_batches] ) ?,
1549+ ) ;
1550+
1551+ Ok ( ( ) )
1552+ }
1553+
1554+ #[ rstest]
1555+ fn test_csv_deserializer_without_finish (
1556+ #[ values( 1 , 5 , 17 ) ] batch_size : usize ,
1557+ #[ values( 0 , 5 , 93 ) ] line_count : usize ,
1558+ ) -> Result < ( ) > {
1559+ let schema = csv_schema ( ) ;
1560+ let generator = CsvBatchGenerator :: new ( batch_size, line_count) ;
1561+ let mut deserializer = csv_deserializer ( batch_size, & schema) ;
1562+
1563+ for data in generator {
1564+ deserializer. digest ( data) ;
1565+ }
1566+
1567+ let batch_count = line_count / batch_size;
1568+
1569+ let mut all_batches = RecordBatch :: new_empty ( schema. clone ( ) ) ;
1570+ for _ in 0 ..batch_count {
1571+ let output = deserializer. next ( ) ?;
1572+ let DeserializerOutput :: RecordBatch ( batch) = output else {
1573+ panic ! ( "Expected RecordBatch, got {:?}" , output) ;
1574+ } ;
1575+ all_batches = concat_batches ( & schema, & [ all_batches, batch] ) ?;
1576+ }
1577+ assert_eq ! ( deserializer. next( ) ?, DeserializerOutput :: RequiresMoreData ) ;
1578+
1579+ let expected = csv_expected_batch ( schema, batch_count * batch_size) ?;
1580+
1581+ assert_eq ! (
1582+ expected. clone( ) ,
1583+ all_batches. clone( ) ,
1584+ "Expected:\n {}\n Actual:\n {}" ,
1585+ pretty_format_batches( & [ expected] ) ?,
1586+ pretty_format_batches( & [ all_batches] ) ?,
1587+ ) ;
1588+
1589+ Ok ( ( ) )
1590+ }
1591+
1592+ struct CsvBatchGenerator {
1593+ batch_size : usize ,
1594+ line_count : usize ,
1595+ offset : usize ,
1596+ }
1597+
1598+ impl CsvBatchGenerator {
1599+ fn new ( batch_size : usize , line_count : usize ) -> Self {
1600+ Self {
1601+ batch_size,
1602+ line_count,
1603+ offset : 0 ,
1604+ }
1605+ }
1606+ }
1607+
1608+ impl Iterator for CsvBatchGenerator {
1609+ type Item = Bytes ;
1610+
1611+ fn next ( & mut self ) -> Option < Self :: Item > {
1612+ // Return `batch_size` rows per batch:
1613+ let mut buffer = Vec :: new ( ) ;
1614+ for _ in 0 ..self . batch_size {
1615+ if self . offset >= self . line_count {
1616+ break ;
1617+ }
1618+ buffer. extend_from_slice ( & csv_line ( self . offset ) ) ;
1619+ self . offset += 1 ;
1620+ }
1621+
1622+ ( !buffer. is_empty ( ) ) . then ( || buffer. into ( ) )
1623+ }
1624+ }
1625+
1626+ fn csv_expected_batch (
1627+ schema : SchemaRef ,
1628+ line_count : usize ,
1629+ ) -> Result < RecordBatch , DataFusionError > {
1630+ let mut c1 = Vec :: with_capacity ( line_count) ;
1631+ let mut c2 = Vec :: with_capacity ( line_count) ;
1632+ let mut c3 = Vec :: with_capacity ( line_count) ;
1633+ let mut c4 = Vec :: with_capacity ( line_count) ;
1634+
1635+ for i in 0 ..line_count {
1636+ let ( int_value, float_value, bool_value, char_value) = csv_values ( i) ;
1637+ c1. push ( int_value) ;
1638+ c2. push ( float_value) ;
1639+ c3. push ( bool_value) ;
1640+ c4. push ( char_value) ;
1641+ }
1642+
1643+ let expected = RecordBatch :: try_new (
1644+ schema. clone ( ) ,
1645+ vec ! [
1646+ Arc :: new( Int32Array :: from( c1) ) ,
1647+ Arc :: new( Float64Array :: from( c2) ) ,
1648+ Arc :: new( BooleanArray :: from( c3) ) ,
1649+ Arc :: new( StringArray :: from( c4) ) ,
1650+ ] ,
1651+ ) ?;
1652+ Ok ( expected)
1653+ }
1654+
1655+ fn csv_line ( line_number : usize ) -> Bytes {
1656+ let ( int_value, float_value, bool_value, char_value) = csv_values ( line_number) ;
1657+ format ! (
1658+ "{},{},{},{}\n " ,
1659+ int_value, float_value, bool_value, char_value
1660+ )
1661+ . into ( )
1662+ }
1663+
/// Derives the four column values of test row `line_number`:
/// the index as `i32`, the index as `f64`, whether the index is even, and the
/// text `"<index>-string"`.
fn csv_values(line_number: usize) -> (i32, f64, bool, String) {
    (
        // Plain `as` casts on purpose: out-of-range indices wrap instead of
        // panicking.
        line_number as i32,
        line_number as f64,
        line_number % 2 == 0,
        format!("{line_number}-string"),
    )
}
1671+
1672+ fn csv_schema ( ) -> Arc < Schema > {
1673+ Arc :: new ( Schema :: new ( vec ! [
1674+ Field :: new( "c1" , DataType :: Int32 , true ) ,
1675+ Field :: new( "c2" , DataType :: Float64 , true ) ,
1676+ Field :: new( "c3" , DataType :: Boolean , true ) ,
1677+ Field :: new( "c4" , DataType :: Utf8 , true ) ,
1678+ ] ) )
1679+ }
1680+
1681+ fn csv_deserializer (
1682+ batch_size : usize ,
1683+ schema : & Arc < Schema > ,
1684+ ) -> impl BatchDeserializer < Bytes > {
1685+ let decoder = ReaderBuilder :: new ( schema. clone ( ) )
1686+ . with_batch_size ( batch_size)
1687+ . build_decoder ( ) ;
1688+ DecoderDeserializer :: new ( CsvDecoder :: new ( decoder) )
1689+ }
14671690}
0 commit comments