Skip to content

Commit 0b10ad8

Browse files
Reuse zstd compression context when writing IPC (#8405)
# Which issue does this PR close? - Closes #8386 # Rationale for this change Reusing the zstd context between subsequent calls to compress_zstd in the Arrow IPC writer for performance improvement. Benchmark results: ``` arrow_ipc_stream_writer/StreamWriter/write_10/zstd time: [4.0972 ms 4.1038 ms 4.1110 ms] change: [-53.848% -53.586% -53.335%] (p = 0.00 < 0.05) Performance has improved. ``` # What changes are included in this PR? Adds a `CompressionContext` struct, which when the zstd feature is enabled contains a zstd::bulk::Compressor object. This context object is owned by the ipc `StreamWriter`/`FileWriter` objects and is passed by mutable reference through the `IpcDataGenerator` to the `CompressionCodec` where it is used when compressing the ipc bytes. # Are these changes tested? Yes the existing unit tests cover the changed code paths # Are there any user-facing changes? Yes, the method `IpcDataGenerator::encoded_batch` now takes `&mut CompressionContext` as an argument.
1 parent 48686c8 commit 0b10ad8

File tree

7 files changed

+206
-36
lines changed

7 files changed

+206
-36
lines changed

arrow-flight/src/encode.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
2020
use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc};
2121

2222
use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
23-
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
23+
use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
2424

2525
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
2626
use bytes::Bytes;
@@ -647,6 +647,7 @@ struct FlightIpcEncoder {
647647
options: IpcWriteOptions,
648648
data_gen: IpcDataGenerator,
649649
dictionary_tracker: DictionaryTracker,
650+
compression_context: CompressionContext,
650651
}
651652

652653
impl FlightIpcEncoder {
@@ -655,6 +656,7 @@ impl FlightIpcEncoder {
655656
options,
656657
data_gen: IpcDataGenerator::default(),
657658
dictionary_tracker: DictionaryTracker::new(error_on_replacement),
659+
compression_context: CompressionContext::default(),
658660
}
659661
}
660662

@@ -666,9 +668,12 @@ impl FlightIpcEncoder {
666668
/// Convert a `RecordBatch` to a Vec of `FlightData` representing
667669
/// dictionaries and a `FlightData` representing the batch
668670
fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec<FlightData>, FlightData)> {
669-
let (encoded_dictionaries, encoded_batch) =
670-
self.data_gen
671-
.encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?;
671+
let (encoded_dictionaries, encoded_batch) = self.data_gen.encode(
672+
batch,
673+
&mut self.dictionary_tracker,
674+
&self.options,
675+
&mut self.compression_context,
676+
)?;
672677

673678
let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
674679
let flight_batch = encoded_batch.into();
@@ -1596,9 +1601,15 @@ mod tests {
15961601
) -> (Vec<FlightData>, FlightData) {
15971602
let data_gen = IpcDataGenerator::default();
15981603
let mut dictionary_tracker = DictionaryTracker::new(false);
1604+
let mut compression_context = CompressionContext::default();
15991605

16001606
let (encoded_dictionaries, encoded_batch) = data_gen
1601-
.encoded_batch(batch, &mut dictionary_tracker, options)
1607+
.encode(
1608+
batch,
1609+
&mut dictionary_tracker,
1610+
options,
1611+
&mut compression_context,
1612+
)
16021613
.expect("DictionaryTracker configured above to not error on replacement");
16031614

16041615
let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();

arrow-flight/src/utils.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::sync::Arc;
2424
use arrow_array::{ArrayRef, RecordBatch};
2525
use arrow_buffer::Buffer;
2626
use arrow_ipc::convert::fb_to_schema;
27+
use arrow_ipc::writer::CompressionContext;
2728
use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions};
2829
use arrow_schema::{ArrowError, Schema, SchemaRef};
2930

@@ -91,10 +92,15 @@ pub fn batches_to_flight_data(
9192

9293
let data_gen = writer::IpcDataGenerator::default();
9394
let mut dictionary_tracker = writer::DictionaryTracker::new(false);
95+
let mut compression_context = CompressionContext::default();
9496

9597
for batch in batches.iter() {
96-
let (encoded_dictionaries, encoded_batch) =
97-
data_gen.encoded_batch(batch, &mut dictionary_tracker, &options)?;
98+
let (encoded_dictionaries, encoded_batch) = data_gen.encode(
99+
batch,
100+
&mut dictionary_tracker,
101+
&options,
102+
&mut compression_context,
103+
)?;
98104

99105
dictionaries.extend(encoded_dictionaries.into_iter().map(Into::into));
100106
flight_data.push(encoded_batch.into());

arrow-integration-testing/src/flight_client_scenarios/integration_test.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ use arrow::{
2424
array::ArrayRef,
2525
buffer::Buffer,
2626
datatypes::SchemaRef,
27-
ipc::{self, reader, writer},
27+
ipc::{
28+
self, reader,
29+
writer::{self, CompressionContext},
30+
},
2831
record_batch::RecordBatch,
2932
};
3033
use arrow_flight::{
@@ -92,6 +95,8 @@ async fn upload_data(
9295

9396
let mut original_data_iter = original_data.iter().enumerate();
9497

98+
let mut compression_context = CompressionContext::default();
99+
95100
if let Some((counter, first_batch)) = original_data_iter.next() {
96101
let metadata = counter.to_string().into_bytes();
97102
// Preload the first batch into the channel before starting the request
@@ -101,6 +106,7 @@ async fn upload_data(
101106
first_batch,
102107
&options,
103108
&mut dict_tracker,
109+
&mut compression_context,
104110
)
105111
.await?;
106112

@@ -123,6 +129,7 @@ async fn upload_data(
123129
batch,
124130
&options,
125131
&mut dict_tracker,
132+
&mut compression_context,
126133
)
127134
.await?;
128135

@@ -152,11 +159,12 @@ async fn send_batch(
152159
batch: &RecordBatch,
153160
options: &writer::IpcWriteOptions,
154161
dictionary_tracker: &mut writer::DictionaryTracker,
162+
compression_context: &mut CompressionContext,
155163
) -> Result {
156164
let data_gen = writer::IpcDataGenerator::default();
157165

158166
let (encoded_dictionaries, encoded_batch) = data_gen
159-
.encoded_batch(batch, dictionary_tracker, options)
167+
.encode(batch, dictionary_tracker, options, compression_context)
160168
.expect("DictionaryTracker configured above to not error on replacement");
161169

162170
let dictionary_flight_data: Vec<FlightData> =

arrow-integration-testing/src/flight_server_scenarios/integration_test.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,12 @@ impl FlightService for FlightServiceImpl {
144144
.enumerate()
145145
.flat_map(|(counter, batch)| {
146146
let (encoded_dictionaries, encoded_batch) = data_gen
147-
.encoded_batch(batch, &mut dictionary_tracker, &options)
147+
.encode(
148+
batch,
149+
&mut dictionary_tracker,
150+
&options,
151+
&mut Default::default(),
152+
)
148153
.expect("DictionaryTracker configured above to not error on replacement");
149154

150155
let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into);

arrow-ipc/src/compression.rs

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,41 @@ use arrow_schema::ArrowError;
2222
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
2323
const LENGTH_OF_PREFIX_DATA: i64 = 8;
2424

25+
/// Additional context that may be needed for compression.
26+
///
27+
/// In the case of zstd, this will contain the zstd context, which can be reused between subsequent
28+
/// compression calls to avoid the performance overhead of initialising a new context for every
29+
/// compression.
30+
pub struct CompressionContext {
31+
#[cfg(feature = "zstd")]
32+
compressor: zstd::bulk::Compressor<'static>,
33+
}
34+
35+
// the reason we allow derivable_impls here is because when zstd feature is not enabled, this
36+
// becomes derivable. However, with the zstd feature we want to be explicit about the compression level.
37+
#[allow(clippy::derivable_impls)]
38+
impl Default for CompressionContext {
39+
fn default() -> Self {
40+
CompressionContext {
41+
// safety: `new` will only return an error here if using an invalid compression level
42+
#[cfg(feature = "zstd")]
43+
compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL)
44+
.expect("can use default compression level"),
45+
}
46+
}
47+
}
48+
49+
impl std::fmt::Debug for CompressionContext {
50+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51+
let mut ds = f.debug_struct("CompressionContext");
52+
53+
#[cfg(feature = "zstd")]
54+
ds.field("compressor", &"zstd::bulk::Compressor");
55+
56+
ds.finish()
57+
}
58+
}
59+
2560
/// Represents compressing a ipc stream using a particular compression algorithm
2661
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2762
pub enum CompressionCodec {
@@ -58,6 +93,7 @@ impl CompressionCodec {
5893
&self,
5994
input: &[u8],
6095
output: &mut Vec<u8>,
96+
context: &mut CompressionContext,
6197
) -> Result<usize, ArrowError> {
6298
let uncompressed_data_len = input.len();
6399
let original_output_len = output.len();
@@ -67,7 +103,7 @@ impl CompressionCodec {
67103
} else {
68104
// write compressed data directly into the output buffer
69105
output.extend_from_slice(&uncompressed_data_len.to_le_bytes());
70-
self.compress(input, output)?;
106+
self.compress(input, output, context)?;
71107

72108
let compression_len = output.len() - original_output_len;
73109
if compression_len > uncompressed_data_len {
@@ -115,10 +151,15 @@ impl CompressionCodec {
115151

116152
/// Compress the data in input buffer and write to output buffer
117153
/// using the specified compression
118-
fn compress(&self, input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
154+
fn compress(
155+
&self,
156+
input: &[u8],
157+
output: &mut Vec<u8>,
158+
context: &mut CompressionContext,
159+
) -> Result<(), ArrowError> {
119160
match self {
120161
CompressionCodec::Lz4Frame => compress_lz4(input, output),
121-
CompressionCodec::Zstd => compress_zstd(input, output),
162+
CompressionCodec::Zstd => compress_zstd(input, output, context),
122163
}
123164
}
124165

@@ -175,17 +216,23 @@ fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result<Vec<u8>, A
175216
}
176217

177218
#[cfg(feature = "zstd")]
178-
fn compress_zstd(input: &[u8], output: &mut Vec<u8>) -> Result<(), ArrowError> {
179-
use std::io::Write;
180-
let mut encoder = zstd::Encoder::new(output, 0)?;
181-
encoder.write_all(input)?;
182-
encoder.finish()?;
219+
fn compress_zstd(
220+
input: &[u8],
221+
output: &mut Vec<u8>,
222+
context: &mut CompressionContext,
223+
) -> Result<(), ArrowError> {
224+
let result = context.compressor.compress(input)?;
225+
output.extend_from_slice(&result);
183226
Ok(())
184227
}
185228

186229
#[cfg(not(feature = "zstd"))]
187230
#[allow(clippy::ptr_arg)]
188-
fn compress_zstd(_input: &[u8], _output: &mut Vec<u8>) -> Result<(), ArrowError> {
231+
fn compress_zstd(
232+
_input: &[u8],
233+
_output: &mut Vec<u8>,
234+
_context: &mut CompressionContext,
235+
) -> Result<(), ArrowError> {
189236
Err(ArrowError::InvalidArgumentError(
190237
"zstd IPC compression requires the zstd feature".to_string(),
191238
))
@@ -227,7 +274,9 @@ mod tests {
227274
let input_bytes = b"hello lz4";
228275
let codec = super::CompressionCodec::Lz4Frame;
229276
let mut output_bytes: Vec<u8> = Vec::new();
230-
codec.compress(input_bytes, &mut output_bytes).unwrap();
277+
codec
278+
.compress(input_bytes, &mut output_bytes, &mut Default::default())
279+
.unwrap();
231280
let result = codec
232281
.decompress(output_bytes.as_slice(), input_bytes.len())
233282
.unwrap();
@@ -240,7 +289,9 @@ mod tests {
240289
let input_bytes = b"hello zstd";
241290
let codec = super::CompressionCodec::Zstd;
242291
let mut output_bytes: Vec<u8> = Vec::new();
243-
codec.compress(input_bytes, &mut output_bytes).unwrap();
292+
codec
293+
.compress(input_bytes, &mut output_bytes, &mut Default::default())
294+
.unwrap();
244295
let result = codec
245296
.decompress(output_bytes.as_slice(), input_bytes.len())
246297
.unwrap();

arrow-ipc/src/reader.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,12 @@ mod tests {
27022702
let gen = IpcDataGenerator {};
27032703
let mut dict_tracker = DictionaryTracker::new(false);
27042704
let (_, encoded) = gen
2705-
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
2705+
.encode(
2706+
&batch,
2707+
&mut dict_tracker,
2708+
&Default::default(),
2709+
&mut Default::default(),
2710+
)
27062711
.unwrap();
27072712

27082713
let message = root_as_message(&encoded.ipc_message).unwrap();
@@ -2740,7 +2745,12 @@ mod tests {
27402745
let gen = IpcDataGenerator {};
27412746
let mut dict_tracker = DictionaryTracker::new(false);
27422747
let (_, encoded) = gen
2743-
.encoded_batch(&batch, &mut dict_tracker, &Default::default())
2748+
.encode(
2749+
&batch,
2750+
&mut dict_tracker,
2751+
&Default::default(),
2752+
&mut Default::default(),
2753+
)
27442754
.unwrap();
27452755

27462756
let message = root_as_message(&encoded.ipc_message).unwrap();

0 commit comments

Comments
 (0)