Merged
Changes from 9 commits
1 change: 1 addition & 0 deletions arrow-flight/src/utils.rs
@@ -69,6 +69,7 @@ pub fn flight_data_to_arrow_batch(
batch,
schema,
dictionaries_by_field,
None,
)
})?
}
157 changes: 128 additions & 29 deletions arrow/src/ipc/reader.rs
@@ -465,6 +465,7 @@ pub fn read_record_batch(
batch: ipc::RecordBatch,
schema: SchemaRef,
dictionaries: &[Option<ArrayRef>],
projection: Option<&[usize]>,
) -> Result<RecordBatch> {
let buffers = batch.buffers().ok_or_else(|| {
ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
@@ -477,23 +478,43 @@
let mut node_index = 0;
let mut arrays = vec![];

// keep track of index as lists require more than one node
for field in schema.fields() {
let triple = create_array(
field_nodes,
field.data_type(),
buf,
buffers,
dictionaries,
node_index,
buffer_index,
)?;
node_index = triple.1;
buffer_index = triple.2;
arrays.push(triple.0);
}
if let Some(projection) = projection {
let fields = schema.fields();
for &index in projection {
let field = &fields[index];
let triple = create_array(
field_nodes,
field.data_type(),
buf,
buffers,
dictionaries,
node_index,
buffer_index,
)?;
node_index = triple.1;
buffer_index = triple.2;
arrays.push(triple.0);
}

RecordBatch::try_new(schema, arrays)
RecordBatch::try_new(Arc::new(schema.project(projection)?), arrays)
} else {
// keep track of index as lists require more than one node
for field in schema.fields() {
let triple = create_array(
field_nodes,
field.data_type(),
buf,
buffers,
dictionaries,
node_index,
buffer_index,
)?;
node_index = triple.1;
buffer_index = triple.2;
arrays.push(triple.0);
}
RecordBatch::try_new(schema, arrays)
}
}

/// Read the dictionary from the buffer and provided metadata,
@@ -532,6 +553,7 @@ pub fn read_dictionary(
batch.data().unwrap(),
Arc::new(schema),
dictionaries_by_field,
None,
)?;
Some(record_batch.column(0).clone())
}
@@ -581,14 +603,17 @@ pub struct FileReader<R: Read + Seek> {

/// Metadata version
metadata_version: ipc::MetadataVersion,

/// Optional projection and projected_schema
projection: Option<(Vec<usize>, Schema)>,
}

impl<R: Read + Seek> FileReader<R> {
/// Try to create a new file reader
///
/// Returns errors if the file does not meet the Arrow Format header and footer
/// requirements
pub fn try_new(reader: R) -> Result<Self> {
pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self> {
let mut reader = BufReader::new(reader);
// check if header and footer contain correct magic bytes
let mut magic_buffer: [u8; 6] = [0; 6];
@@ -672,6 +697,18 @@ impl<R: Read + Seek> FileReader<R> {
};
}

let projection = projection.map(|projection| {
let fields = projection
.iter()
.map(|x| schema.fields[*x].clone())
Contributor: Is it fine if this panics if x > fields.len()?

Contributor Author: Good point, I think some extra checks on the projection values are needed here to avoid panics.

Contributor: Perhaps we could reuse Schema::project: https://docs.rs/arrow/9.1.0/arrow/datatypes/struct.Schema.html#method.project (which also handles metadata correctly)

.collect();
let schema = Schema {
fields,
metadata: schema.metadata.clone(),
};
(projection, schema)
});
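
For context, a minimal sketch of the Schema::project approach suggested in the thread above, assuming the arrow 9.x signature Schema::project(&self, indices: &[usize]) -> Result<Schema> (an illustration, not the code merged in this PR):

    // Sketch only: Schema::project validates the indices (returning an error
    // rather than panicking on out-of-bounds values) and carries the schema
    // metadata over, so no manual field cloning is needed.
    let projection = projection
        .map(|projection| {
            let schema = schema.project(&projection)?;
            Ok((projection, schema))
        })
        .transpose()?;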

Ok(Self {
reader,
schema: Arc::new(schema),
Expand All @@ -680,6 +717,7 @@ impl<R: Read + Seek> FileReader<R> {
total_blocks,
dictionaries_by_field,
metadata_version: footer.version(),
projection,
})
}
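
To illustrate the changed constructor, a hypothetical caller (the file name here is invented for the example) now passes the projection up front and gets a reader whose schema and batches carry only the selected columns:

    // Hypothetical usage of the new signature: read only columns 0 and 2.
    let file = File::open("data.arrow")?;
    let reader = FileReader::try_new(file, Some(vec![0, 2]))?;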

@@ -760,6 +798,8 @@ impl<R: Read + Seek> FileReader<R> {
batch,
self.schema(),
&self.dictionaries_by_field,
self.projection.as_ref().map(|x| x.0.as_ref()),
).map(Some)
}
ipc::MessageHeader::NONE => {
@@ -808,6 +848,9 @@ pub struct StreamReader<R: Read> {
///
/// This value is set to `true` the first time the reader's `next()` returns `None`.
finished: bool,

/// Optional projection
Contributor suggested change:
- /// Optional projection
+ /// Optional projection and projected schema
projection: Option<(Vec<usize>, Schema)>,
}

impl<R: Read> StreamReader<R> {
@@ -816,7 +859,7 @@ impl<R: Read> StreamReader<R> {
/// The first message in the stream is the schema, the reader will fail if it does not
/// encounter a schema.
/// To check if the reader is done, use `is_finished(self)`
pub fn try_new(reader: R) -> Result<Self> {
pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self> {
let mut reader = BufReader::new(reader);
// determine metadata length
let mut meta_size: [u8; 4] = [0; 4];
@@ -845,11 +888,23 @@ impl<R: Read> StreamReader<R> {
// Create an array of optional dictionary value arrays, one per field.
let dictionaries_by_field = vec![None; schema.fields().len()];

let projection = projection.map(|projection| {
Contributor: same here -- Schema::project might make this code easier to read
let fields = projection
.iter()
.map(|x| schema.fields[*x].clone())
.collect();
let schema = Schema {
fields,
metadata: schema.metadata.clone(),
};
(projection, schema)
});
Ok(Self {
reader,
schema: Arc::new(schema),
finished: false,
dictionaries_by_field,
projection,
})
}

@@ -922,7 +977,7 @@ impl<R: Read> StreamReader<R> {
let mut buf = vec![0; message.bodyLength() as usize];
self.reader.read_exact(&mut buf)?;

read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field).map(Some)
read_record_batch(
    &buf,
    batch,
    self.schema(),
    &self.dictionaries_by_field,
    self.projection.as_ref().map(|x| x.0.as_ref()),
)
.map(Some)
Contributor: maybe it is worth adding something like

impl<R: Read> StreamReader<R> {
    ...
    /// get projected schema, if any
    pub fn projected_schema(&self) -> Option<&Schema> {
        ...
    }
}
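
Given the projection: Option<(Vec<usize>, Schema)> field added above, one plausible body for the suggested accessor (a sketch of the reviewer's idea, not code from this PR):

    /// Get the projected schema, if a projection was passed to `try_new`.
    pub fn projected_schema(&self) -> Option<&Schema> {
        // The projected Schema is stored as the second element of the tuple.
        self.projection.as_ref().map(|(_, schema)| schema)
    }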
ipc::MessageHeader::DictionaryBatch => {
let batch = message.header_as_dictionary_batch().ok_or_else(|| {
@@ -998,7 +1053,7 @@ mod tests {
))
.unwrap();

let mut reader = FileReader::try_new(file).unwrap();
let mut reader = FileReader::try_new(file, None).unwrap();

// read expected JSON output
let arrow_json = read_gzip_json(version, path);
@@ -1015,7 +1070,7 @@
testdata
))
.unwrap();
FileReader::try_new(file).unwrap();
FileReader::try_new(file, None).unwrap();
}

#[test]
@@ -1031,7 +1086,7 @@
testdata
))
.unwrap();
FileReader::try_new(file).unwrap();
FileReader::try_new(file, None).unwrap();
}

#[test]
@@ -1056,7 +1111,39 @@
))
.unwrap();

FileReader::try_new(file).unwrap();
FileReader::try_new(file, None).unwrap();
});
}

#[test]
fn projection_should_work() {
// complementary to the previous test
let testdata = crate::util::test_util::arrow_test_data();
let paths = vec![
"generated_interval",
"generated_datetime",
// "generated_map", Err: Last offset 872415232 of Utf8 is larger than values length 52 (https://github.com/apache/arrow-rs/issues/859)
"generated_nested",
"generated_null_trivial",
"generated_null",
"generated_primitive_no_batches",
"generated_primitive_zerolength",
"generated_primitive",
];
paths.iter().for_each(|path| {
let file = File::open(format!(
"{}/arrow-ipc-stream/integration/1.0.0-bigendian/{}.arrow_file",
testdata, path
))
.unwrap();

let reader = FileReader::try_new(file, Some(vec![0])).unwrap();
let datatype_0 = reader.schema().fields()[0].data_type().clone();
reader.for_each(|batch| {
let batch = batch.unwrap();
assert_eq!(batch.columns().len(), 1);
assert_eq!(datatype_0, batch.schema().fields()[0].data_type().clone());
});
});
}

@@ -1083,7 +1170,7 @@
))
.unwrap();

let mut reader = StreamReader::try_new(file).unwrap();
let mut reader = StreamReader::try_new(file, None).unwrap();

// read expected JSON output
let arrow_json = read_gzip_json(version, path);
@@ -1120,7 +1207,7 @@
))
.unwrap();

let mut reader = FileReader::try_new(file).unwrap();
let mut reader = FileReader::try_new(file, None).unwrap();

// read expected JSON output
let arrow_json = read_gzip_json(version, path);
@@ -1153,7 +1240,7 @@
))
.unwrap();

let mut reader = StreamReader::try_new(file).unwrap();
let mut reader = StreamReader::try_new(file, None).unwrap();

// read expected JSON output
let arrow_json = read_gzip_json(version, path);
@@ -1189,7 +1276,7 @@

// read stream back
let file = File::open("target/debug/testdata/float.stream").unwrap();
let reader = StreamReader::try_new(file).unwrap();
let reader = StreamReader::try_new(file, None).unwrap();

reader.for_each(|batch| {
let batch = batch.unwrap();
Expand All @@ -1211,7 +1298,19 @@ mod tests {
.value(0)
!= 0.0
);
})
});

let file = File::open("target/debug/testdata/float.stream").unwrap();

// Read with projection
let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();

reader.for_each(|batch| {
let batch = batch.unwrap();
assert_eq!(batch.schema().fields().len(), 2);
assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
});
}

fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
@@ -1223,7 +1322,7 @@
drop(writer);

let mut reader =
ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap();
ipc::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
reader.next().unwrap().unwrap()
}
