diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs index d5ac568e9788..d22c24eea6d4 100644 --- a/arrow-flight/src/error.rs +++ b/arrow-flight/src/error.rs @@ -78,6 +78,12 @@ impl From for FlightError { } } +impl From for FlightError { + fn from(error: prost::DecodeError) -> Self { + Self::DecodeError(error.to_string()) + } +} + impl From for FlightError { fn from(value: ArrowError) -> Self { Self::Arrow(value) diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 5009ae5ea50a..5defa9e7ae18 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -27,6 +27,7 @@ use tonic::metadata::AsciiMetadataKey; use crate::decode::FlightRecordBatchStream; use crate::encode::FlightDataEncoderBuilder; use crate::error::FlightError; +use crate::error::Result; use crate::flight_service_client::FlightServiceClient; use crate::sql::r#gen::action_end_transaction_request::EndTransaction; use crate::sql::server::{ @@ -49,11 +50,7 @@ use crate::{ IpcMessage, PutResult, Ticket, }; use arrow_array::RecordBatch; -use arrow_buffer::Buffer; -use arrow_ipc::convert::fb_to_schema; -use arrow_ipc::reader::read_record_batch; -use arrow_ipc::{MessageHeader, root_as_message}; -use arrow_schema::{ArrowError, Schema, SchemaRef}; +use arrow_schema::{ArrowError, Schema}; use futures::{Stream, TryStreamExt, stream}; use prost::Message; use tonic::transport::Channel; @@ -126,15 +123,10 @@ impl FlightSqlServiceClient { async fn get_flight_info_for_command( &mut self, cmd: M, - ) -> Result { + ) -> Result { let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); let req = self.set_request_headers(descriptor.into_request())?; - let fi = self - .flight_client - .get_flight_info(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); + let fi = self.flight_client.get_flight_info(req).await?.into_inner(); Ok(fi) } @@ -143,7 +135,7 @@ impl FlightSqlServiceClient { &mut self, query: String, transaction_id: Option, - ) -> Result { + ) -> Result { let cmd = CommandStatementQuery { query, transaction_id, @@ -156,7 +148,7 @@ impl FlightSqlServiceClient { /// If the server returns an "authorization" header, it is automatically parsed and set as /// a token for future requests. Any other data returned by the server in the handshake /// response is returned as a binary blob. - pub async fn handshake(&mut self, username: &str, password: &str) -> Result { + pub async fn handshake(&mut self, username: &str, password: &str) -> Result { let cmd = HandshakeRequest { protocol_version: 0, payload: Default::default(), @@ -168,18 +160,14 @@ impl FlightSqlServiceClient { .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?; req.metadata_mut().insert("authorization", val); let req = self.set_request_headers(req)?; - let resp = self - .flight_client - .handshake(req) - .await - .map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?; + let resp = self.flight_client.handshake(req).await?; if let Some(auth) = resp.metadata().get("authorization") { let auth = auth .to_str() .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?; let bearer = "Bearer "; if !auth.starts_with(bearer) { - Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; + return Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; } let auth = auth[bearer.len()..].to_string(); self.token = Some(auth); @@ -204,7 +192,7 @@ impl FlightSqlServiceClient { &mut self, query: String, transaction_id: Option, - ) -> Result { + ) -> Result { let cmd = CommandStatementUpdate { query, transaction_id, @@ -217,19 +205,9 @@ impl FlightSqlServiceClient { }]) .into_request(), )?; - let mut result = self - .flight_client - .do_put(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let result: DoPutUpdateResult = - Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let mut result = self.flight_client.do_put(req).await?.into_inner(); + let result = result.message().await?.unwrap(); + let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?; Ok(result.record_count) } @@ -238,7 +216,7 @@ impl FlightSqlServiceClient { &mut self, command: CommandStatementIngest, stream: S, - ) -> Result + ) -> Result where S: Stream> + Send + 'static, { @@ -255,41 +233,28 @@ impl FlightSqlServiceClient { FallibleRequestStream::new(sender, flight_data); let req = self.set_request_headers(flight_data.into_streaming_request())?; - let mut result = self - .flight_client - .do_put(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); + let mut result = self.flight_client.do_put(req).await?.into_inner(); // check if the there were any errors in the input stream provided note // if receiver.await fails, it means the sender was dropped and there is // no message to return. if let Ok(msg) = receiver.await { - return Err(ArrowError::ExternalError(Box::new(msg))); + return Err(FlightError::ExternalError(Box::new(msg))); } - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let result: DoPutUpdateResult = - Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result = result.message().await?.unwrap(); + let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?; Ok(result.record_count) } /// Request a list of catalogs as tabular FlightInfo results - pub async fn get_catalogs(&mut self) -> Result { + pub async fn get_catalogs(&mut self) -> Result { self.get_flight_info_for_command(CommandGetCatalogs {}) .await } /// Request a list of database schemas as tabular FlightInfo results - pub async fn get_db_schemas( - &mut self, - request: CommandGetDbSchemas, - ) -> Result { + pub async fn get_db_schemas(&mut self, request: CommandGetDbSchemas) -> Result { self.get_flight_info_for_command(request).await } @@ -297,15 +262,10 @@ impl FlightSqlServiceClient { pub async fn do_get( &mut self, ticket: impl IntoRequest, - ) -> Result { + ) -> Result { let req = self.set_request_headers(ticket.into_request())?; - let (md, response_stream, _ext) = self - .flight_client - .do_get(req) - .await - .map_err(status_to_arrow_error)? - .into_parts(); + let (md, response_stream, _ext) = self.flight_client.do_get(req).await?.into_parts(); let (response_stream, trailers) = extract_lazy_trailers(response_stream); Ok(FlightRecordBatchStream::new_from_flight_data( @@ -319,43 +279,27 @@ impl FlightSqlServiceClient { pub async fn do_put( &mut self, request: impl tonic::IntoStreamingRequest, - ) -> Result, ArrowError> { + ) -> Result> { let req = self.set_request_headers(request.into_streaming_request())?; - Ok(self - .flight_client - .do_put(req) - .await - .map_err(status_to_arrow_error)? - .into_inner()) + Ok(self.flight_client.do_put(req).await?.into_inner()) } /// DoAction allows a flight client to do a specific action against a flight service pub async fn do_action( &mut self, request: impl IntoRequest, - ) -> Result, ArrowError> { + ) -> Result> { let req = self.set_request_headers(request.into_request())?; - Ok(self - .flight_client - .do_action(req) - .await - .map_err(status_to_arrow_error)? - .into_inner()) + Ok(self.flight_client.do_action(req).await?.into_inner()) } /// Request a list of tables. - pub async fn get_tables( - &mut self, - request: CommandGetTables, - ) -> Result { + pub async fn get_tables(&mut self, request: CommandGetTables) -> Result { self.get_flight_info_for_command(request).await } /// Request the primary keys for a table. - pub async fn get_primary_keys( - &mut self, - request: CommandGetPrimaryKeys, - ) -> Result { + pub async fn get_primary_keys(&mut self, request: CommandGetPrimaryKeys) -> Result { self.get_flight_info_for_command(request).await } @@ -364,7 +308,7 @@ impl FlightSqlServiceClient { pub async fn get_exported_keys( &mut self, request: CommandGetExportedKeys, - ) -> Result { + ) -> Result { self.get_flight_info_for_command(request).await } @@ -372,7 +316,7 @@ impl FlightSqlServiceClient { pub async fn get_imported_keys( &mut self, request: CommandGetImportedKeys, - ) -> Result { + ) -> Result { self.get_flight_info_for_command(request).await } @@ -382,21 +326,18 @@ impl FlightSqlServiceClient { pub async fn get_cross_reference( &mut self, request: CommandGetCrossReference, - ) -> Result { + ) -> Result { self.get_flight_info_for_command(request).await } /// Request a list of table types. - pub async fn get_table_types(&mut self) -> Result { + pub async fn get_table_types(&mut self) -> Result { self.get_flight_info_for_command(CommandGetTableTypes {}) .await } /// Request a list of SQL information. - pub async fn get_sql_info( - &mut self, - sql_infos: Vec, - ) -> Result { + pub async fn get_sql_info(&mut self, sql_infos: Vec) -> Result { let request = CommandGetSqlInfo { info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(), }; @@ -407,7 +348,7 @@ impl FlightSqlServiceClient { pub async fn get_xdbc_type_info( &mut self, request: CommandGetXdbcTypeInfo, - ) -> Result { + ) -> Result { self.get_flight_info_for_command(request).await } @@ -416,7 +357,7 @@ impl FlightSqlServiceClient { &mut self, query: String, transaction_id: Option, - ) -> Result, ArrowError> { + ) -> Result> { let cmd = ActionCreatePreparedStatementRequest { query, transaction_id, @@ -426,18 +367,9 @@ impl FlightSqlServiceClient { body: cmd.as_any().encode_to_vec().into(), }; let req = self.set_request_headers(action.into_request())?; - let mut result = self - .flight_client - .do_action(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?; + let mut result = self.flight_client.do_action(req).await?.into_inner(); + let result = result.message().await?.unwrap(); + let any = Any::decode(&*result.body)?; let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap(); let dataset_schema = match prepared_result.dataset_schema.len() { 0 => Schema::empty(), @@ -456,25 +388,16 @@ impl FlightSqlServiceClient { } /// Request to begin a transaction. - pub async fn begin_transaction(&mut self) -> Result { + pub async fn begin_transaction(&mut self) -> Result { let cmd = ActionBeginTransactionRequest {}; let action = Action { r#type: BEGIN_TRANSACTION.to_string(), body: cmd.as_any().encode_to_vec().into(), }; let req = self.set_request_headers(action.into_request())?; - let mut result = self - .flight_client - .do_action(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?; + let mut result = self.flight_client.do_action(req).await?.into_inner(); + let result = result.message().await?.unwrap(); + let any = Any::decode(&*result.body)?; let begin_result: ActionBeginTransactionResult = any.unpack()?.unwrap(); Ok(begin_result.transaction_id) } @@ -484,7 +407,7 @@ impl FlightSqlServiceClient { &mut self, transaction_id: Bytes, action: EndTransaction, - ) -> Result<(), ArrowError> { + ) -> Result<()> { let cmd = ActionEndTransactionRequest { transaction_id, action: action as i32, @@ -494,25 +417,17 @@ impl FlightSqlServiceClient { body: cmd.as_any().encode_to_vec().into(), }; let req = self.set_request_headers(action.into_request())?; - let _ = self - .flight_client - .do_action(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); + let _ = self.flight_client.do_action(req).await?.into_inner(); Ok(()) } /// Explicitly shut down and clean up the client. - pub async fn close(&mut self) -> Result<(), ArrowError> { + pub async fn close(&mut self) -> Result<()> { // TODO: consume self instead of &mut self to explicitly prevent reuse? Ok(()) } - fn set_request_headers( - &self, - mut req: tonic::Request, - ) -> Result, ArrowError> { + fn set_request_headers(&self, mut req: tonic::Request) -> Result> { for (k, v) in &self.headers { let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| { ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}")) @@ -559,7 +474,7 @@ impl PreparedStatement { } /// Executes the prepared statement query on the server. - pub async fn execute(&mut self) -> Result { + pub async fn execute(&mut self) -> Result { self.write_bind_params().await?; let cmd = CommandPreparedStatementQuery { @@ -574,7 +489,7 @@ impl PreparedStatement { } /// Executes the prepared statement update query on the server. - pub async fn execute_update(&mut self) -> Result { + pub async fn execute_update(&mut self) -> Result { self.write_bind_params().await?; let cmd = CommandPreparedStatementUpdate { @@ -588,35 +503,30 @@ impl PreparedStatement { ..Default::default() }])) .await?; - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let result: DoPutUpdateResult = - Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result = result.message().await?.unwrap(); + let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?; Ok(result.record_count) } /// Retrieve the parameter schema from the query. - pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> { + pub fn parameter_schema(&self) -> Result<&Schema> { Ok(&self.parameter_schema) } /// Retrieve the ResultSet schema from the query. - pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> { + pub fn dataset_schema(&self) -> Result<&Schema> { Ok(&self.dataset_schema) } /// Set a RecordBatch that contains the parameters that will be bind. - pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<(), ArrowError> { + pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<()> { self.parameter_binding = Some(parameter_binding); Ok(()) } /// Submit parameters to the server, if any have been set on this prepared statement instance /// Updates our stored prepared statement handle with the handle given by the server response. - async fn write_bind_params(&mut self) -> Result<(), ArrowError> { + async fn write_bind_params(&mut self) -> Result<()> { if let Some(ref params_batch) = self.parameter_binding { let cmd = CommandPreparedStatementQuery { prepared_statement_handle: self.handle.clone(), @@ -631,8 +541,7 @@ impl PreparedStatement { self.parameter_binding.clone().map(Ok), )) .try_collect::>() - .await - .map_err(flight_error_to_arrow_error)?; + .await?; // Attempt to update the stored handle with any updated handle in the DoPut result. // Older servers do not respond with a result for DoPut, so skip this step when @@ -642,8 +551,7 @@ impl PreparedStatement { .do_put(stream::iter(flight_data)) .await? .message() - .await - .map_err(status_to_arrow_error)? + .await? { if let Some(handle) = self.unpack_prepared_statement_handle(&result)? { self.handle = handle; @@ -656,18 +564,14 @@ impl PreparedStatement { /// Decodes the app_metadata stored in a [`PutResult`] as a /// [`DoPutPreparedStatementResult`] and then returns /// the inner prepared statement handle as [`Bytes`] - fn unpack_prepared_statement_handle( - &self, - put_result: &PutResult, - ) -> Result, ArrowError> { - let result: DoPutPreparedStatementResult = - Message::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?; + fn unpack_prepared_statement_handle(&self, put_result: &PutResult) -> Result> { + let result: DoPutPreparedStatementResult = Message::decode(&*put_result.app_metadata)?; Ok(result.prepared_statement_handle) } /// Close the prepared statement, so that this PreparedStatement can not used /// anymore and server can free up any resources. - pub async fn close(mut self) -> Result<(), ArrowError> { + pub async fn close(mut self) -> Result<()> { let cmd = ActionClosePreparedStatementRequest { prepared_statement_handle: self.handle.clone(), }; @@ -680,21 +584,6 @@ impl PreparedStatement { } } -fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError { - ArrowError::IpcError(err.to_string()) -} - -fn status_to_arrow_error(status: tonic::Status) -> ArrowError { - ArrowError::IpcError(format!("{status:?}")) -} - -fn flight_error_to_arrow_error(err: FlightError) -> ArrowError { - match err { - FlightError::Arrow(e) => e, - e => ArrowError::ExternalError(Box::new(e)), - } -} - /// A polymorphic structure to natively represent different types of data contained in `FlightData` pub enum ArrowFlightData { /// A record batch @@ -702,77 +591,3 @@ pub enum ArrowFlightData { /// A schema Schema(Schema), } - -/// Extract `Schema` or `RecordBatch`es from the `FlightData` wire representation -pub fn arrow_data_from_flight_data( - flight_data: FlightData, - arrow_schema_ref: &SchemaRef, -) -> Result { - let ipc_message = root_as_message(&flight_data.data_header[..]) - .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?; - - match ipc_message.header_type() { - MessageHeader::RecordBatch => { - let ipc_record_batch = ipc_message.header_as_record_batch().ok_or_else(|| { - ArrowError::ComputeError( - "Unable to convert flight data header to a record batch".to_string(), - ) - })?; - - let dictionaries_by_field = HashMap::new(); - let record_batch = read_record_batch( - &Buffer::from(flight_data.data_body), - ipc_record_batch, - arrow_schema_ref.clone(), - &dictionaries_by_field, - None, - &ipc_message.version(), - )?; - Ok(ArrowFlightData::RecordBatch(record_batch)) - } - MessageHeader::Schema => { - let ipc_schema = ipc_message.header_as_schema().ok_or_else(|| { - ArrowError::ComputeError( - "Unable to convert flight data header to a schema".to_string(), - ) - })?; - - let arrow_schema = fb_to_schema(ipc_schema); - Ok(ArrowFlightData::Schema(arrow_schema)) - } - MessageHeader::DictionaryBatch => { - let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| { - ArrowError::ComputeError( - "Unable to convert flight data header to a dictionary batch".to_string(), - ) - })?; - Err(ArrowError::NotYetImplemented( - "no idea on how to convert an ipc dictionary batch to an arrow type".to_string(), - )) - } - MessageHeader::Tensor => { - let _ = ipc_message.header_as_tensor().ok_or_else(|| { - ArrowError::ComputeError( - "Unable to convert flight data header to a tensor".to_string(), - ) - })?; - Err(ArrowError::NotYetImplemented( - "no idea on how to convert an ipc tensor to an arrow type".to_string(), - )) - } - MessageHeader::SparseTensor => { - let _ = ipc_message.header_as_sparse_tensor().ok_or_else(|| { - ArrowError::ComputeError( - "Unable to convert flight data header to a sparse tensor".to_string(), - ) - })?; - Err(ArrowError::NotYetImplemented( - "no idea on how to convert an ipc sparse tensor to an arrow type".to_string(), - )) - } - _ => Err(ArrowError::ComputeError(format!( - "Unable to convert message with header_type: '{:?}' to arrow data", - ipc_message.header_type() - ))), - } -}