diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs
index 9232acc2aefb..fd5f4dc9e366 100644
--- a/arrow-flight/src/utils.rs
+++ b/arrow-flight/src/utils.rs
@@ -69,6 +69,7 @@ pub fn flight_data_to_arrow_batch(
                 batch,
                 schema,
                 dictionaries_by_field,
+                None,
             )
         })?
 }
diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index 27633dfdb850..6a70768a2902 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -465,6 +465,7 @@ pub fn read_record_batch(
     batch: ipc::RecordBatch,
     schema: SchemaRef,
     dictionaries: &[Option<ArrayRef>],
+    projection: Option<&[usize]>,
 ) -> Result<RecordBatch> {
     let buffers = batch.buffers().ok_or_else(|| {
         ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string())
@@ -477,23 +478,43 @@ pub fn read_record_batch(
     let mut node_index = 0;
     let mut arrays = vec![];
 
-    // keep track of index as lists require more than one node
-    for field in schema.fields() {
-        let triple = create_array(
-            field_nodes,
-            field.data_type(),
-            buf,
-            buffers,
-            dictionaries,
-            node_index,
-            buffer_index,
-        )?;
-        node_index = triple.1;
-        buffer_index = triple.2;
-        arrays.push(triple.0);
-    }
+    if let Some(projection) = projection {
+        let fields = schema.fields();
+        for &index in projection {
+            let field = &fields[index];
+            let triple = create_array(
+                field_nodes,
+                field.data_type(),
+                buf,
+                buffers,
+                dictionaries,
+                node_index,
+                buffer_index,
+            )?;
+            node_index = triple.1;
+            buffer_index = triple.2;
+            arrays.push(triple.0);
+        }
 
-    RecordBatch::try_new(schema, arrays)
+        RecordBatch::try_new(Arc::new(schema.project(projection)?), arrays)
+    } else {
+        // keep track of index as lists require more than one node
+        for field in schema.fields() {
+            let triple = create_array(
+                field_nodes,
+                field.data_type(),
+                buf,
+                buffers,
+                dictionaries,
+                node_index,
+                buffer_index,
+            )?;
+            node_index = triple.1;
+            buffer_index = triple.2;
+            arrays.push(triple.0);
+        }
+        RecordBatch::try_new(schema, arrays)
+    }
 }
 
 /// Read the dictionary from the buffer and provided metadata,
@@ -532,6 +553,7 @@ pub fn read_dictionary(
             batch.data().unwrap(),
             Arc::new(schema),
             dictionaries_by_field,
+            None,
         )?;
         Some(record_batch.column(0).clone())
     }
@@ -581,6 +603,9 @@ pub struct FileReader<R: Read + Seek> {
 
     /// Metadata version
     metadata_version: ipc::MetadataVersion,
+
+    /// Optional projection and projected_schema
+    projection: Option<(Vec<usize>, Schema)>,
 }
 
 impl<R: Read + Seek> FileReader<R> {
     /// Try to create a new file reader
     ///
     /// Returns errors if the file does not meet the Arrow Format header and footer
     /// requirements
-    pub fn try_new(reader: R) -> Result<Self> {
+    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self> {
         let mut reader = BufReader::new(reader);
         // check if header and footer contain correct magic bytes
         let mut magic_buffer: [u8; 6] = [0; 6];
@@ -671,6 +696,13 @@ impl<R: Read + Seek> FileReader<R> {
                 }
             };
         }
+        let projection = match projection {
+            Some(projection_indices) => {
+                let schema = schema.project(&projection_indices)?;
+                Some((projection_indices, schema))
+            }
+            _ => None,
+        };
 
         Ok(Self {
             reader,
@@ -680,6 +712,7 @@ impl<R: Read + Seek> FileReader<R> {
             total_blocks,
             dictionaries_by_field,
             metadata_version: footer.version(),
+            projection,
         })
     }
 
@@ -760,6 +793,8 @@ impl<R: Read + Seek> FileReader<R> {
                     batch,
                     self.schema(),
                     &self.dictionaries_by_field,
+                    self.projection.as_ref().map(|x| x.0.as_ref()),
+
                 ).map(Some)
             }
             ipc::MessageHeader::NONE => {
@@ -808,6 +843,9 @@ pub struct StreamReader<R: Read> {
     ///
     /// This value is set to `true` the first time the reader's `next()` returns `None`.
     finished: bool,
+
+    /// Optional projection
+    projection: Option<(Vec<usize>, Schema)>,
 }
 
 impl<R: Read> StreamReader<R> {
     /// Try to create a new stream reader
     ///
     /// The first message in the stream is the schema, the reader will fail if it does not
     /// encounter a schema.
     /// To check if the reader is done, use `is_finished(self)`
-    pub fn try_new(reader: R) -> Result<Self> {
+    pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self> {
         let mut reader = BufReader::new(reader);
         // determine metadata length
         let mut meta_size: [u8; 4] = [0; 4];
@@ -845,11 +883,19 @@ impl<R: Read> StreamReader<R> {
         // Create an array of optional dictionary value arrays, one per field.
         let dictionaries_by_field = vec![None; schema.fields().len()];
 
+        let projection = match projection {
+            Some(projection_indices) => {
+                let schema = schema.project(&projection_indices)?;
+                Some((projection_indices, schema))
+            }
+            _ => None,
+        };
         Ok(Self {
             reader,
             schema: Arc::new(schema),
             finished: false,
             dictionaries_by_field,
+            projection,
         })
     }
 
@@ -922,7 +968,7 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = vec![0; message.bodyLength() as usize];
                 self.reader.read_exact(&mut buf)?;
 
-                read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field).map(Some)
+                read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field, self.projection.as_ref().map(|x| x.0.as_ref())).map(Some)
             }
             ipc::MessageHeader::DictionaryBatch => {
                 let batch = message.header_as_dictionary_batch().ok_or_else(|| {
@@ -998,7 +1044,7 @@ mod tests {
         ))
         .unwrap();
 
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1015,7 +1061,7 @@ mod tests {
             testdata
         ))
        .unwrap();
-        FileReader::try_new(file).unwrap();
+        FileReader::try_new(file, None).unwrap();
     }
 
     #[test]
@@ -1031,7 +1077,7 @@ mod tests {
             testdata
         ))
         .unwrap();
-        FileReader::try_new(file).unwrap();
+        FileReader::try_new(file, None).unwrap();
     }
 
     #[test]
@@ -1056,7 +1102,39 @@ mod tests {
             ))
             .unwrap();
 
-            FileReader::try_new(file).unwrap();
+            FileReader::try_new(file, None).unwrap();
+        });
+    }
+
+    #[test]
+    fn projection_should_work() {
+        // complementary to the previous test
+        let testdata = crate::util::test_util::arrow_test_data();
+        let paths = vec![
+            "generated_interval",
+            "generated_datetime",
+            // "generated_map", Err: Last offset 872415232 of Utf8 is larger than values length 52 (https://github.com/apache/arrow-rs/issues/859)
+            "generated_nested",
+            "generated_null_trivial",
+            "generated_null",
+            "generated_primitive_no_batches",
+            "generated_primitive_zerolength",
+            "generated_primitive",
+        ];
+        paths.iter().for_each(|path| {
+            let file = File::open(format!(
+                "{}/arrow-ipc-stream/integration/1.0.0-bigendian/{}.arrow_file",
+                testdata, path
+            ))
+            .unwrap();
+
+            let reader = FileReader::try_new(file, Some(vec![0])).unwrap();
+            let datatype_0 = reader.schema().fields()[0].data_type().clone();
+            reader.for_each(|batch| {
+                let batch = batch.unwrap();
+                assert_eq!(batch.columns().len(), 1);
+                assert_eq!(datatype_0, batch.schema().fields()[0].data_type().clone());
+            });
         });
     }
 
@@ -1083,7 +1161,7 @@ mod tests {
         ))
         .unwrap();
 
-        let mut reader = StreamReader::try_new(file).unwrap();
+        let mut reader = StreamReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1120,7 +1198,7 @@ mod tests {
         ))
         .unwrap();
 
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1153,7 +1231,7 @@ mod tests {
         ))
         .unwrap();
 
-        let mut reader = StreamReader::try_new(file).unwrap();
+        let mut reader = StreamReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1189,7 +1267,7 @@ mod tests {
         // read stream back
        let file = File::open("target/debug/testdata/float.stream").unwrap();
-        let reader = StreamReader::try_new(file).unwrap();
+        let reader = StreamReader::try_new(file, None).unwrap();
 
         reader.for_each(|batch| {
             let batch = batch.unwrap();
@@ -1211,7 +1289,19 @@ mod tests {
                 .value(0)
                     != 0.0
             );
-        })
+        });
+
+        let file = File::open("target/debug/testdata/float.stream").unwrap();
+
+        // Read with projection
+        let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap();
+
+        reader.for_each(|batch| {
+            let batch = batch.unwrap();
+            assert_eq!(batch.schema().fields().len(), 2);
+            assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32);
+            assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32);
+        });
     }
 
     fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
@@ -1223,7 +1313,7 @@ mod tests {
         drop(writer);
 
         let mut reader =
-            ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap();
+            ipc::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap();
         reader.next().unwrap().unwrap()
     }
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index 7dda6a64c623..d76a496229c2 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -881,7 +881,7 @@ mod tests {
         let file =
             File::open(format!("target/debug/testdata/{}.arrow_file", "arrow")).unwrap();
 
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
         while let Some(Ok(read_batch)) = reader.next() {
             read_batch
                 .columns()
@@ -929,7 +929,7 @@ mod tests {
         {
             let file = File::open(&file_name).unwrap();
-            let reader = FileReader::try_new(file).unwrap();
+            let reader = FileReader::try_new(file, None).unwrap();
             reader.for_each(|maybe_batch| {
                 maybe_batch
                     .unwrap()
@@ -999,7 +999,7 @@ mod tests {
         ))
         .unwrap();
 
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
 
         // read and rewrite the file to a temp location
         {
@@ -1020,7 +1020,7 @@ mod tests {
             version, path
         ))
         .unwrap();
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1051,7 +1051,7 @@ mod tests {
         ))
         .unwrap();
 
-        let reader = StreamReader::try_new(file).unwrap();
+        let reader = StreamReader::try_new(file, None).unwrap();
 
         // read and rewrite the stream to a temp location
         {
@@ -1070,7 +1070,7 @@ mod tests {
         let file =
             File::open(format!("target/debug/testdata/{}-{}.stream", version, path))
                 .unwrap();
-        let mut reader = StreamReader::try_new(file).unwrap();
+        let mut reader = StreamReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1108,7 +1108,7 @@ mod tests {
         ))
         .unwrap();
 
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
 
         // read and rewrite the file to a temp location
         {
@@ -1134,7 +1134,7 @@ mod tests {
             version, path
         ))
         .unwrap();
-        let mut reader = FileReader::try_new(file).unwrap();
+        let mut reader = FileReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
@@ -1172,7 +1172,7 @@ mod tests {
         ))
         .unwrap();
 
-        let reader = StreamReader::try_new(file).unwrap();
+        let reader = StreamReader::try_new(file, None).unwrap();
 
         // read and rewrite the stream to a temp location
         {
@@ -1195,7 +1195,7 @@ mod tests {
         let file =
             File::open(format!("target/debug/testdata/{}-{}.stream", version, path))
                 .unwrap();
-        let mut reader = StreamReader::try_new(file).unwrap();
+        let mut reader = StreamReader::try_new(file, None).unwrap();
 
         // read expected JSON output
         let arrow_json = read_gzip_json(version, path);
diff --git a/integration-testing/src/bin/arrow-file-to-stream.rs b/integration-testing/src/bin/arrow-file-to-stream.rs
index fbc217aab858..e939fe4f0bf7 100644
--- a/integration-testing/src/bin/arrow-file-to-stream.rs
+++ b/integration-testing/src/bin/arrow-file-to-stream.rs
@@ -32,7 +32,7 @@ fn main() -> Result<()> {
     let args = Args::parse();
     let f = File::open(&args.file_name)?;
     let reader = BufReader::new(f);
-    let mut reader = FileReader::try_new(reader)?;
+    let mut reader = FileReader::try_new(reader, None)?;
     let schema = reader.schema();
 
     let mut writer = StreamWriter::try_new(io::stdout(), &schema)?;
diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs
index 81200b99f982..17d2528e07ff 100644
--- a/integration-testing/src/bin/arrow-json-integration-test.rs
+++ b/integration-testing/src/bin/arrow-json-integration-test.rs
@@ -83,7 +83,7 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()>
     }
 
     let arrow_file = File::open(arrow_name)?;
-    let reader = FileReader::try_new(arrow_file)?;
+    let reader = FileReader::try_new(arrow_file, None)?;
 
     let mut fields: Vec<Field> = vec![];
     for f in reader.schema().fields() {
@@ -117,7 +117,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
     // open Arrow file
     let arrow_file = File::open(arrow_name)?;
-    let mut arrow_reader = FileReader::try_new(arrow_file)?;
+    let mut arrow_reader = FileReader::try_new(arrow_file, None)?;
     let arrow_schema = arrow_reader.schema().as_ref().to_owned();
 
     // compare schemas
diff --git a/integration-testing/src/bin/arrow-stream-to-file.rs b/integration-testing/src/bin/arrow-stream-to-file.rs
index f81d42e6eda2..07ac5c7ddd42 100644
--- a/integration-testing/src/bin/arrow-stream-to-file.rs
+++ b/integration-testing/src/bin/arrow-stream-to-file.rs
@@ -22,7 +22,7 @@ use arrow::ipc::reader::StreamReader;
 use arrow::ipc::writer::FileWriter;
 
 fn main() -> Result<()> {
-    let mut arrow_stream_reader = StreamReader::try_new(io::stdin())?;
+    let mut arrow_stream_reader = StreamReader::try_new(io::stdin(), None)?;
     let schema = arrow_stream_reader.schema();
 
     let mut writer = FileWriter::try_new(io::stdout(), &schema)?;
diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs
index 9c9ebd8548f7..56d399923ade 100644
--- a/integration-testing/src/flight_server_scenarios/integration_test.rs
+++ b/integration-testing/src/flight_server_scenarios/integration_test.rs
@@ -295,6 +295,7 @@ async fn record_batch_from_message(
         ipc_batch,
         schema_ref,
         dictionaries_by_field,
+        None,
     );
 
     arrow_batch_result.map_err(|e| {
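
For reference, a minimal sketch of how the reader API reads after this change. The file name is a placeholder and the snippet assumes an IPC file with at least one column; pass `None` as the second argument to keep the old behaviour of decoding every column.

```rust
use std::fs::File;

use arrow::error::Result;
use arrow::ipc::reader::FileReader;

fn main() -> Result<()> {
    // Placeholder path: any Arrow IPC file works here.
    let file = File::open("data.arrow")?;

    // Ask the reader to decode only column 0 of each record batch,
    // mirroring the `projection_should_work` test above.
    let reader = FileReader::try_new(file, Some(vec![0]))?;

    for batch in reader {
        let batch = batch?;
        // Each batch carries the projected schema and only the selected column.
        assert_eq!(batch.num_columns(), 1);
        println!("{:?}", batch.schema());
    }
    Ok(())
}
```

`StreamReader::try_new` takes the same `Option<Vec<usize>>` argument, so the sketch applies unchanged when reading a stream, e.g. from stdin.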