diff --git a/src/event/format/json.rs b/src/event/format/json.rs index c28b701de..903ab2752 100644 --- a/src/event/format/json.rs +++ b/src/event/format/json.rs @@ -23,7 +23,7 @@ use anyhow::anyhow; use arrow_array::RecordBatch; use arrow_json::reader::{infer_json_schema_from_iterator, ReaderBuilder}; use arrow_schema::{DataType, Field, Fields, Schema}; -use chrono::{DateTime, NaiveDateTime, Utc}; +use chrono::{DateTime, NaiveDate, NaiveDateTime, Utc}; use datafusion::arrow::util::bit_util::round_upto_multiple_of_64; use itertools::Itertools; use serde_json::Value; @@ -62,6 +62,7 @@ impl EventFormat for Event { schema: &HashMap>, time_partition: Option<&String>, schema_version: SchemaVersion, + static_schema_flag: bool, ) -> Result<(Self::Data, Vec>, bool), anyhow::Error> { let stream_schema = schema; @@ -111,7 +112,7 @@ impl EventFormat for Event { if value_arr .iter() - .any(|value| fields_mismatch(&schema, value, schema_version)) + .any(|value| fields_mismatch(&schema, value, schema_version, static_schema_flag)) { return Err(anyhow!( "Could not process this event due to mismatch in datatype" @@ -253,7 +254,12 @@ fn collect_keys<'a>(values: impl Iterator) -> Result], body: &Value, schema_version: SchemaVersion) -> bool { +fn fields_mismatch( + schema: &[Arc], + body: &Value, + schema_version: SchemaVersion, + static_schema_flag: bool, +) -> bool { for (name, val) in body.as_object().expect("body is of object variant") { if val.is_null() { continue; @@ -261,65 +267,118 @@ fn fields_mismatch(schema: &[Arc], body: &Value, schema_version: SchemaVe let Some(field) = get_field(schema, name) else { return true; }; - if !valid_type(field.data_type(), val, schema_version) { + if !valid_type(field, val, schema_version, static_schema_flag) { return true; } } false } -fn valid_type(data_type: &DataType, value: &Value, schema_version: SchemaVersion) -> bool { - match data_type { +fn valid_type( + field: &Field, + value: &Value, + schema_version: SchemaVersion, + static_schema_flag: bool, +) -> bool { + match field.data_type() { DataType::Boolean => value.is_boolean(), - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => value.is_i64(), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + validate_int(value, static_schema_flag) + } DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => value.is_u64(), DataType::Float16 | DataType::Float32 => value.is_f64(), - // All numbers can be cast as Float64 from schema version v1 - DataType::Float64 if schema_version == SchemaVersion::V1 => value.is_number(), - DataType::Float64 if schema_version != SchemaVersion::V1 => value.is_f64(), + DataType::Float64 => validate_float(value, schema_version, static_schema_flag), DataType::Utf8 => value.is_string(), - DataType::List(field) => { - let data_type = field.data_type(); - if let Value::Array(arr) = value { - for elem in arr { - if elem.is_null() { - continue; - } - if !valid_type(data_type, elem, schema_version) { - return false; - } - } - } - true - } + DataType::List(field) => validate_list(field, value, schema_version, static_schema_flag), DataType::Struct(fields) => { - if let Value::Object(val) = value { - for (key, value) in val { - let field = (0..fields.len()) - .find(|idx| fields[*idx].name() == key) - .map(|idx| &fields[idx]); - - if let Some(field) = field { - if value.is_null() { - continue; - } - if !valid_type(field.data_type(), value, schema_version) { - return false; - } - } else { - return false; - } - } - true - } else { - false + validate_struct(fields, value, schema_version, static_schema_flag) + } + DataType::Date32 => { + if let Value::String(s) = value { + return NaiveDate::parse_from_str(s, "%Y-%m-%d").is_ok(); } + false } DataType::Timestamp(_, _) => value.is_string() || value.is_number(), _ => { - error!("Unsupported datatype {:?}, value {:?}", data_type, value); - unreachable!() + error!( + "Unsupported datatype {:?}, value {:?}", + field.data_type(), + value + ); + false + } + } +} + +fn validate_int(value: &Value, static_schema_flag: bool) -> bool { + // allow casting string to int for static schema + if static_schema_flag { + if let Value::String(s) = value { + return s.trim().parse::().is_ok(); + } + } + value.is_i64() +} + +fn validate_float(value: &Value, schema_version: SchemaVersion, static_schema_flag: bool) -> bool { + // allow casting string to int for static schema + if static_schema_flag { + if let Value::String(s) = value.clone() { + let trimmed = s.trim(); + return trimmed.parse::().is_ok() || trimmed.parse::().is_ok(); + } + return value.is_number(); + } + match schema_version { + SchemaVersion::V1 => value.is_number(), + _ => value.is_f64(), + } +} + +fn validate_list( + field: &Field, + value: &Value, + schema_version: SchemaVersion, + static_schema_flag: bool, +) -> bool { + if let Value::Array(arr) = value { + for elem in arr { + if elem.is_null() { + continue; + } + if !valid_type(field, elem, schema_version, static_schema_flag) { + return false; + } + } + } + true +} + +fn validate_struct( + fields: &Fields, + value: &Value, + schema_version: SchemaVersion, + static_schema_flag: bool, +) -> bool { + if let Value::Object(val) = value { + for (key, value) in val { + let field = fields.iter().find(|f| f.name() == key); + + if let Some(field) = field { + if value.is_null() { + continue; + } + if !valid_type(field, value, schema_version, static_schema_flag) { + return false; + } + } else { + return false; + } } + true + } else { + false } } diff --git a/src/event/format/mod.rs b/src/event/format/mod.rs index ce90cfc52..40713d5e1 100644 --- a/src/event/format/mod.rs +++ b/src/event/format/mod.rs @@ -102,6 +102,7 @@ pub trait EventFormat: Sized { schema: &HashMap>, time_partition: Option<&String>, schema_version: SchemaVersion, + static_schema_flag: bool, ) -> Result<(Self::Data, EventSchema, bool), AnyError>; fn decode(data: Self::Data, schema: Arc) -> Result; @@ -117,8 +118,12 @@ pub trait EventFormat: Sized { schema_version: SchemaVersion, ) -> Result<(RecordBatch, bool), AnyError> { let p_timestamp = self.get_p_timestamp(); - let (data, mut schema, is_first) = - self.to_data(storage_schema, time_partition, schema_version)?; + let (data, mut schema, is_first) = self.to_data( + storage_schema, + time_partition, + schema_version, + static_schema_flag, + )?; if get_field(&schema, DEFAULT_TIMESTAMP_KEY).is_some() { return Err(anyhow!( diff --git a/src/query/stream_schema_provider.rs b/src/query/stream_schema_provider.rs index ff485d806..e82eb69cb 100644 --- a/src/query/stream_schema_provider.rs +++ b/src/query/stream_schema_provider.rs @@ -967,6 +967,7 @@ fn cast_or_none(scalar: &ScalarValue) -> Option> { ScalarValue::UInt32(val) => val.map(|val| CastRes::Int(val as i64)), ScalarValue::UInt64(val) => val.map(|val| CastRes::Int(val as i64)), ScalarValue::Utf8(val) => val.as_ref().map(|val| CastRes::String(val)), + ScalarValue::Date32(val) => val.map(|val| CastRes::Int(val as i64)), ScalarValue::TimestampMillisecond(val, _) => val.map(CastRes::Int), _ => None, } diff --git a/src/static_schema.rs b/src/static_schema.rs index 286ec65ad..a4a5ff995 100644 --- a/src/static_schema.rs +++ b/src/static_schema.rs @@ -111,6 +111,7 @@ pub fn convert_static_schema_to_arrow_schema( "boolean" => DataType::Boolean, "string" => DataType::Utf8, "datetime" => DataType::Timestamp(TimeUnit::Millisecond, None), + "date" => DataType::Date32, "string_list" => { DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))) }