235 changes: 229 additions & 6 deletions datafusion/core/src/datasource/file_format/csv.rs
@@ -23,7 +23,10 @@ use std::fmt::{self, Debug};
use std::sync::Arc;

use super::write::orchestration::stateless_multipart_put;
use super::{
Decoder, DecoderDeserializer, FileFormat, FileFormatFactory,
DEFAULT_SCHEMA_INFER_MAX_RECORD,
};
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::file_format::write::BatchSerializer;
use crate::datasource::physical_plan::{
@@ -38,8 +41,8 @@ use crate::physical_plan::{

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow_schema::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
@@ -293,6 +296,45 @@ impl CsvFormat {
}
}

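/// Adapter that lets the arrow CSV [`arrow::csv::reader::Decoder`] be driven
/// through DataFusion's [`Decoder`] trait for incremental, push-based decoding.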
#[derive(Debug)]
pub(crate) struct CsvDecoder {
inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
pub(crate) fn new(decoder: arrow::csv::reader::Decoder) -> Self {
Self { inner: decoder }
}
}

impl Decoder for CsvDecoder {
fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
self.inner.decode(buf)
}

fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
self.inner.flush()
}

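// `capacity()` reports how many more rows fit in the current batch; once it
// reaches zero a full batch is buffered, so it can be flushed before any
// further input arrives.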
fn can_flush_early(&self) -> bool {
self.inner.capacity() == 0
}
}

impl Debug for CsvSerializer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CsvSerializer")
.field("header", &self.header)
.finish()
}
}

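/// Convenience conversion so an arrow CSV decoder can be wrapped directly in a
/// [`DecoderDeserializer`]; callers push raw bytes in with `digest` and poll
/// `next()` for decoded batches, as the tests below demonstrate.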
impl From<arrow::csv::reader::Decoder> for DecoderDeserializer<CsvDecoder> {
fn from(decoder: arrow::csv::reader::Decoder) -> Self {
DecoderDeserializer::new(CsvDecoder::new(decoder))
}
}

#[async_trait]
impl FileFormat for CsvFormat {
fn as_any(&self) -> &dyn Any {
@@ -692,23 +734,28 @@ impl DataSink for CsvSink {
mod tests {
use super::super::test_util::scan_format;
use super::*;
use crate::assert_batches_eq;
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::file_format::test_util::VariableStream;
use crate::datasource::file_format::{
BatchDeserializer, DecoderDeserializer, DeserializerOutput,
};
use crate::datasource::listing::ListingOptions;
use crate::execution::session_state::SessionStateBuilder;
use crate::physical_plan::collect;
use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext};
use crate::test_util::arrow_test_data;

use arrow::compute::concat_batches;
use arrow::csv::ReaderBuilder;
use arrow::util::pretty::pretty_format_batches;
use arrow_array::{BooleanArray, Float64Array, Int32Array, StringArray};
use datafusion_common::cast::as_string_array;
use datafusion_common::internal_err;
use datafusion_common::stats::Precision;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_expr::{col, lit};

use chrono::DateTime;
use object_store::local::LocalFileSystem;
use object_store::path::Path;
@@ -1097,7 +1144,7 @@ mod tests {
) -> Result<usize> {
let df = ctx.sql(&format!("EXPLAIN {sql}")).await?;
let result = df.collect().await?;
let plan = format!("{}", &pretty::pretty_format_batches(&result)?);
let plan = format!("{}", &pretty_format_batches(&result)?);

let re = Regex::new(r"CsvExec: file_groups=\{(\d+) group").unwrap();

@@ -1464,4 +1511,180 @@

Ok(())
}

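// `finish()` signals end of input, so the deserializer also flushes the final
// partially filled batch: `line_count.div_ceil(batch_size)` batches come out,
// after which the stream reports `InputExhausted`.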
#[rstest]
fn test_csv_deserializer_with_finish(
#[values(1, 5, 17)] batch_size: usize,
#[values(0, 5, 93)] line_count: usize,
) -> Result<()> {
let schema = csv_schema();
let generator = CsvBatchGenerator::new(batch_size, line_count);
let mut deserializer = csv_deserializer(batch_size, &schema);

for data in generator {
deserializer.digest(data);
}
deserializer.finish();

let batch_count = line_count.div_ceil(batch_size);

let mut all_batches = RecordBatch::new_empty(schema.clone());
for _ in 0..batch_count {
let output = deserializer.next()?;
let DeserializerOutput::RecordBatch(batch) = output else {
panic!("Expected RecordBatch, got {:?}", output);
};
all_batches = concat_batches(&schema, &[all_batches, batch])?;
}
assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted);

let expected = csv_expected_batch(schema, line_count)?;

assert_eq!(
expected.clone(),
all_batches.clone(),
"Expected:\n{}\nActual:\n{}",
pretty_format_batches(&[expected])?,
pretty_format_batches(&[all_batches])?,
);

Ok(())
}

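// Without `finish()` only completely filled batches are emitted; a trailing
// partial batch stays buffered and the deserializer keeps reporting
// `RequiresMoreData`.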
#[rstest]
fn test_csv_deserializer_without_finish(
#[values(1, 5, 17)] batch_size: usize,
#[values(0, 5, 93)] line_count: usize,
) -> Result<()> {
let schema = csv_schema();
let generator = CsvBatchGenerator::new(batch_size, line_count);
let mut deserializer = csv_deserializer(batch_size, &schema);

for data in generator {
deserializer.digest(data);
}

let batch_count = line_count / batch_size;

let mut all_batches = RecordBatch::new_empty(schema.clone());
for _ in 0..batch_count {
let output = deserializer.next()?;
let DeserializerOutput::RecordBatch(batch) = output else {
panic!("Expected RecordBatch, got {:?}", output);
};
all_batches = concat_batches(&schema, &[all_batches, batch])?;
}
assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData);

let expected = csv_expected_batch(schema, batch_count * batch_size)?;

assert_eq!(
expected.clone(),
all_batches.clone(),
"Expected:\n{}\nActual:\n{}",
pretty_format_batches(&[expected])?,
pretty_format_batches(&[all_batches])?,
);

Ok(())
}

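/// Iterator that yields `Bytes` chunks of up to `batch_size` CSV lines until
/// `line_count` lines have been produced in total.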
struct CsvBatchGenerator {
batch_size: usize,
line_count: usize,
offset: usize,
}

impl CsvBatchGenerator {
fn new(batch_size: usize, line_count: usize) -> Self {
Self {
batch_size,
line_count,
offset: 0,
}
}
}

impl Iterator for CsvBatchGenerator {
type Item = Bytes;

fn next(&mut self) -> Option<Self::Item> {
// Return `batch_size` rows per batch:
let mut buffer = Vec::new();
for _ in 0..self.batch_size {
if self.offset >= self.line_count {
break;
}
buffer.extend_from_slice(&csv_line(self.offset));
self.offset += 1;
}

(!buffer.is_empty()).then(|| buffer.into())
}
}

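/// Builds the `RecordBatch` expected after decoding the first `line_count`
/// generated CSV lines.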
fn csv_expected_batch(
schema: SchemaRef,
line_count: usize,
) -> Result<RecordBatch, DataFusionError> {
let mut c1 = Vec::with_capacity(line_count);
let mut c2 = Vec::with_capacity(line_count);
let mut c3 = Vec::with_capacity(line_count);
let mut c4 = Vec::with_capacity(line_count);

for i in 0..line_count {
let (int_value, float_value, bool_value, char_value) = csv_values(i);
c1.push(int_value);
c2.push(float_value);
c3.push(bool_value);
c4.push(char_value);
}

let expected = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(c1)),
Arc::new(Float64Array::from(c2)),
Arc::new(BooleanArray::from(c3)),
Arc::new(StringArray::from(c4)),
],
)?;
Ok(expected)
}

fn csv_line(line_number: usize) -> Bytes {
let (int_value, float_value, bool_value, char_value) = csv_values(line_number);
format!(
"{},{},{},{}\n",
int_value, float_value, bool_value, char_value
)
.into()
}

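/// Deterministic values for a given line: the line number as `i32`/`f64`, an
/// alternating boolean, and a `"{line_number}-string"` value.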
fn csv_values(line_number: usize) -> (i32, f64, bool, String) {
let int_value = line_number as i32;
let float_value = line_number as f64;
let bool_value = line_number % 2 == 0;
let char_value = format!("{}-string", line_number);
(int_value, float_value, bool_value, char_value)
}

fn csv_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Float64, true),
Field::new("c3", DataType::Boolean, true),
Field::new("c4", DataType::Utf8, true),
]))
}

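/// Helper constructing the deserializer under test from an arrow CSV decoder
/// with the given batch size.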
fn csv_deserializer(
batch_size: usize,
schema: &Arc<Schema>,
) -> impl BatchDeserializer<Bytes> {
let decoder = ReaderBuilder::new(schema.clone())
.with_batch_size(batch_size)
.build_decoder();
DecoderDeserializer::new(CsvDecoder::new(decoder))
}
}