Skip to content

Commit 2cd1706

Browse files
authored
Avoid panics on error while encoding/decoding ListValue::Array as protobuf (#7837)
1 parent e00932c commit 2cd1706

File tree

2 files changed

+37
-36
lines changed

2 files changed

+37
-36
lines changed

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use crate::protobuf::{
2626
OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
2727
};
2828
use arrow::{
29-
buffer::{Buffer, MutableBuffer},
29+
buffer::Buffer,
3030
datatypes::{
3131
i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit,
3232
UnionFields, UnionMode,
@@ -645,6 +645,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
645645
Value::Float32Value(v) => Self::Float32(Some(*v)),
646646
Value::Float64Value(v) => Self::Float64(Some(*v)),
647647
Value::Date32Value(v) => Self::Date32(Some(*v)),
648+
// ScalarValue::List is serialized using arrow IPC format
648649
Value::ListValue(scalar_list) => {
649650
let protobuf::ScalarListValue {
650651
ipc_message,
@@ -655,29 +656,36 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
655656
let schema: Schema = if let Some(schema_ref) = schema {
656657
schema_ref.try_into()?
657658
} else {
658-
return Err(Error::General("Unexpected schema".to_string()));
659+
return Err(Error::General(
660+
"Invalid schema while deserializing ScalarValue::List"
661+
.to_string(),
662+
));
659663
};
660664

661-
let message = root_as_message(ipc_message.as_slice()).unwrap();
665+
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
666+
Error::General(format!(
667+
"Error IPC message while deserializing ScalarValue::List: {e}"
668+
))
669+
})?;
670+
let buffer = Buffer::from(arrow_data);
662671

663-
// TODO: Add comment to why adding 0 before arrow_data.
664-
// This code is from https://github.com/apache/arrow-rs/blob/4320a753beaee0a1a6870c59ef46b59e88c9c323/arrow-ipc/src/reader.rs#L1670-L1674C45
665-
// Construct an unaligned buffer
666-
let mut buffer = MutableBuffer::with_capacity(arrow_data.len() + 1);
667-
buffer.push(0_u8);
668-
buffer.extend_from_slice(arrow_data.as_slice());
669-
let b = Buffer::from(buffer).slice(1);
672+
let ipc_batch = message.header_as_record_batch().ok_or_else(|| {
673+
Error::General(
674+
"Unexpected message type deserializing ScalarValue::List"
675+
.to_string(),
676+
)
677+
})?;
670678

671-
let ipc_batch = message.header_as_record_batch().unwrap();
672679
let record_batch = read_record_batch(
673-
&b,
680+
&buffer,
674681
ipc_batch,
675682
Arc::new(schema),
676683
&Default::default(),
677684
None,
678685
&message.version(),
679686
)
680-
.unwrap();
687+
.map_err(DataFusionError::ArrowError)
688+
.map_err(|e| e.context("Decoding ScalarValue::List Value"))?;
681689
let arr = record_batch.column(0);
682690
Self::List(arr.to_owned())
683691
}

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,6 @@ use datafusion_expr::{
5656
pub enum Error {
5757
General(String),
5858

59-
InconsistentListTyping(DataType, DataType),
60-
61-
InconsistentListDesignated {
62-
value: ScalarValue,
63-
designated: DataType,
64-
},
65-
6659
InvalidScalarValue(ScalarValue),
6760

6861
InvalidScalarType(DataType),
@@ -80,18 +73,6 @@ impl std::fmt::Display for Error {
8073
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
8174
match self {
8275
Self::General(desc) => write!(f, "General error: {desc}"),
83-
Self::InconsistentListTyping(type1, type2) => {
84-
write!(
85-
f,
86-
"Lists with inconsistent typing; {type1:?} and {type2:?} found within list",
87-
)
88-
}
89-
Self::InconsistentListDesignated { value, designated } => {
90-
write!(
91-
f,
92-
"Value {value:?} was inconsistent with designated type {designated:?}"
93-
)
94-
}
9576
Self::InvalidScalarValue(value) => {
9677
write!(f, "{value:?} is invalid as a DataFusion scalar value")
9778
}
@@ -1145,15 +1126,27 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
11451126
"Proto serialization error: ScalarValue::Fixedsizelist not supported"
11461127
.to_string(),
11471128
)),
1129+
// ScalarValue::List is serialized using Arrow IPC messages,
1130+
// as a single-column RecordBatch
11481131
ScalarValue::List(arr) => {
1149-
let batch =
1150-
RecordBatch::try_from_iter(vec![("field_name", arr.to_owned())])
1151-
.unwrap();
1132+
// Wrap in a "field_name" column
1133+
let batch = RecordBatch::try_from_iter(vec![(
1134+
"field_name",
1135+
arr.to_owned(),
1136+
)])
1137+
.map_err(|e| {
1138+
Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}"))
1139+
})?;
1140+
11521141
let gen = IpcDataGenerator {};
11531142
let mut dict_tracker = DictionaryTracker::new(false);
11541143
let (_, encoded_message) = gen
11551144
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
1156-
.unwrap();
1145+
.map_err(|e| {
1146+
Error::General(format!(
1147+
"Error encoding ScalarValue::List as IPC: {e}"
1148+
))
1149+
})?;
11571150

11581151
let schema: protobuf::Schema = batch.schema().try_into()?;
11591152

0 commit comments

Comments
 (0)