diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index b7101e2bbf40..65bb40810f18 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -1034,7 +1034,7 @@ impl SchemaExt for Schema {
                 .iter()
                 .zip(other.fields().iter())
                 .try_for_each(|(f1, f2)| {
-                    if f1.name() != f2.name() || !DFSchema::datatype_is_logically_equal(f1.data_type(), f2.data_type()) {
+                    if f1.name() != f2.name() || (!DFSchema::datatype_is_logically_equal(f1.data_type(), f2.data_type()) && !can_cast_types(f2.data_type(), f1.data_type())) {
                         _plan_err!(
                             "Inserting query schema mismatch: Expected table field '{}' with type {:?}, \
                             but got '{}' with type {:?}.",
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 43428d6846a1..a902cf8ae65b 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -5673,3 +5673,75 @@ async fn test_fill_null_all_columns() -> Result<()> {
     assert_batches_sorted_eq!(expected, &results);
     Ok(())
 }
+
+#[tokio::test]
+async fn test_insert_into_casting_support() -> Result<()> {
+    // Testing case1:
+    // Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8.
+    // And the cast is not supported from Utf8 to Float16.
+
+    // Create a new schema with one field called "a" of type Float16, and setting nullable to false
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float16, false)]));
+
+    let session_ctx = SessionContext::new();
+
+    // Create and register the initial table with the provided schema and data
+    let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?);
+    session_ctx.register_table("t", initial_table.clone())?;
+
+    let mut write_df = session_ctx.sql("values ('a123'), ('b456')").await.unwrap();
+
+    write_df = write_df
+        .clone()
+        .with_column_renamed("column1", "a")
+        .unwrap();
+
+    let e = write_df
+        .write_table("t", DataFrameWriteOptions::new())
+        .await
+        .unwrap_err();
+
+    assert_contains!(e.to_string(), "Inserting query schema mismatch: Expected table field 'a' with type Float16, but got 'a' with type Utf8.");
+
+    // Testing case2:
+    // Inserting query schema mismatch: Expected table field 'a' with type Utf8View, but got 'a' with type Utf8.
+    // And the cast is supported from Utf8 to Utf8View.
+
+    // Create a new schema with one field called "a" of type Utf8View, and setting nullable to false
+    let schema = Arc::new(Schema::new(vec![Field::new(
+        "a",
+        DataType::Utf8View,
+        false,
+    )]));
+
+    let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?);
+
+    session_ctx.register_table("t2", initial_table.clone())?;
+
+    let mut write_df = session_ctx.sql("values ('a123'), ('b456')").await.unwrap();
+
+    write_df = write_df
+        .clone()
+        .with_column_renamed("column1", "a")
+        .unwrap();
+
+    write_df
+        .write_table("t2", DataFrameWriteOptions::new())
+        .await?;
+
+    let res = session_ctx
+        .sql("select * from t2")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+
+    // The result should be the same as the input which is ['a123', 'b456']
+    let expected = [
+        "+------+", "| a    |", "+------+", "| a123 |", "| b456 |", "+------+",
+    ];
+
+    assert_batches_eq!(expected, &res);
+    Ok(())
+}