diff --git a/datafusion/core/tests/parquet/expr_adapter.rs b/datafusion/core/tests/parquet/expr_adapter.rs
index aee37fda1670d..bb1aa272d4091 100644
--- a/datafusion/core/tests/parquet/expr_adapter.rs
+++ b/datafusion/core/tests/parquet/expr_adapter.rs
@@ -17,8 +17,11 @@
 
 use std::sync::Arc;
 
-use arrow::array::{RecordBatch, record_batch};
-use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use arrow::array::{
+    Array, ArrayRef, BooleanArray, Int32Array, Int64Array, RecordBatch, StringArray,
+    StructArray, record_batch,
+};
+use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
 use bytes::{BufMut, BytesMut};
 use datafusion::assert_batches_eq;
 use datafusion::common::Result;
@@ -320,6 +323,145 @@ async fn test_physical_expr_adapter_with_non_null_defaults() {
     assert_batches_eq!(expected, &batches);
 }
 
+#[tokio::test]
+async fn test_struct_schema_evolution_projection_and_filter() -> Result<()> {
+    use std::collections::HashMap;
+
+    // Physical struct: {id: Int32, name: Utf8}
+    let physical_struct_fields: Fields = vec![
+        Arc::new(Field::new("id", DataType::Int32, false)),
+        Arc::new(Field::new("name", DataType::Utf8, true)),
+    ]
+    .into();
+
+    let struct_array = StructArray::new(
+        physical_struct_fields.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
+            Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef,
+        ],
+        None,
+    );
+
+    let physical_schema = Arc::new(Schema::new(vec![Field::new(
+        "s",
+        DataType::Struct(physical_struct_fields),
+        true,
+    )]));
+
+    let batch =
+        RecordBatch::try_new(Arc::clone(&physical_schema), vec![Arc::new(struct_array)])?;
+
+    let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
+    let store_url = ObjectStoreUrl::parse("memory://").unwrap();
+    write_parquet(batch, store.clone(), "struct_evolution.parquet").await;
+
+    // Logical struct: {id: Int64?, name: Utf8?, extra: Boolean?} + metadata
+    let logical_struct_fields: Fields = vec![
+        Arc::new(Field::new("id", DataType::Int64, true)),
+        Arc::new(Field::new("name", DataType::Utf8, true)),
+        Arc::new(Field::new("extra", DataType::Boolean, true).with_metadata(
+            HashMap::from([("nested_meta".to_string(), "1".to_string())]),
+        )),
+    ]
+    .into();
+
+    let table_schema = Arc::new(Schema::new(vec![
+        Field::new("s", DataType::Struct(logical_struct_fields), false)
+            .with_metadata(HashMap::from([("top_meta".to_string(), "1".to_string())])),
+    ]));
+
+    let mut cfg = SessionConfig::new()
+        .with_collect_statistics(false)
+        .with_parquet_pruning(false)
+        .with_parquet_page_index_pruning(false);
+    cfg.options_mut().execution.parquet.pushdown_filters = true;
+
+    let ctx = SessionContext::new_with_config(cfg);
+    ctx.register_object_store(store_url.as_ref(), Arc::clone(&store));
+
+    let listing_table_config =
+        ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
+            .infer_options(&ctx.state())
+            .await
+            .unwrap()
+            .with_schema(table_schema.clone())
+            .with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory));
+
+    let table = ListingTable::try_new(listing_table_config).unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+
+    let batches = ctx
+        .sql("SELECT s FROM t")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+    assert_eq!(batches.len(), 1);
+
+    // Verify top-level metadata propagation
+    let output_schema = batches[0].schema();
+    let s_field = output_schema.field_with_name("s").unwrap();
+    assert_eq!(
+        s_field.metadata().get("top_meta").map(String::as_str),
+        Some("1")
+    );
+
+    // Verify nested struct type/field propagation + values
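+    // The file's struct has only {id: Int32, name: Utf8}, so the adapter must
+    // widen `id` to Int64 and surface the missing `extra` child as all nulls.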
+    let s_array = batches[0]
+        .column(0)
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .expect("expected struct array");
+
+    let id_array = s_array
+        .column_by_name("id")
+        .expect("id column")
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .expect("id should be cast to Int64");
+    assert_eq!(id_array.values(), &[1, 2, 3]);
+
+    let extra_array = s_array.column_by_name("extra").expect("extra column");
+    assert_eq!(extra_array.null_count(), 3);
+
+    // Verify nested field metadata propagation
+    let extra_field = match s_field.data_type() {
+        DataType::Struct(fields) => fields
+            .iter()
+            .find(|f| f.name() == "extra")
+            .expect("extra field"),
+        other => panic!("expected struct type for s, got {other:?}"),
+    };
+    assert_eq!(
+        extra_field
+            .metadata()
+            .get("nested_meta")
+            .map(String::as_str),
+        Some("1")
+    );
+
+    // Smoke test: filtering on a missing nested field evaluates correctly
+    let filtered = ctx
+        .sql("SELECT get_field(s, 'extra') AS extra FROM t WHERE get_field(s, 'extra') IS NULL")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+    assert_eq!(filtered.len(), 1);
+    assert_eq!(filtered[0].num_rows(), 3);
+    let extra = filtered[0]
+        .column(0)
+        .as_any()
+        .downcast_ref::<BooleanArray>()
+        .expect("extra should be a boolean array");
+    assert_eq!(extra.null_count(), 3);
+
+    Ok(())
+}
+
 /// Test demonstrating that a single PhysicalExprAdapterFactory instance can be
 /// reused across multiple ListingTable instances.
 ///
diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs
index 5a9ee8502eaa9..2cf7336a1768d 100644
--- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs
+++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs
@@ -29,6 +29,7 @@ use arrow::compute::can_cast_types;
 use arrow::datatypes::{DataType, Field, SchemaRef};
 use datafusion_common::{
     Result, ScalarValue, exec_err,
+    metadata::FieldMetadata,
     nested_struct::validate_struct_compatibility,
     tree_node::{Transformed, TransformedResult, TreeNode},
 };
@@ -368,7 +369,10 @@ impl DefaultPhysicalExprAdapterRewriter {
         };
 
         let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
-        Ok(Some(expressions::lit(null_value)))
+        Ok(Some(Arc::new(expressions::Literal::new_with_metadata(
+            null_value,
+            Some(FieldMetadata::from(logical_struct_field.as_ref())),
+        ))))
     }
 
     fn rewrite_column(
@@ -416,24 +420,33 @@ impl DefaultPhysicalExprAdapterRewriter {
                 // If the column is missing from the physical schema fill it in with nulls.
                 // For a different behavior, provide a custom `PhysicalExprAdapter` implementation.
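+                // Tagging the null literal with the logical field's metadata lets
+                // `return_field` report the table schema's metadata for columns that
+                // are absent from the file (exercised by the metadata tests below).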
                 let null_value = ScalarValue::Null.cast_to(logical_field.data_type())?;
-                return Ok(Transformed::yes(expressions::lit(null_value)));
+                return Ok(Transformed::yes(Arc::new(
+                    expressions::Literal::new_with_metadata(
+                        null_value,
+                        Some(FieldMetadata::from(logical_field)),
+                    ),
+                )));
             }
         };
 
         let physical_field = self.physical_file_schema.field(physical_column_index);
 
-        if column.index() == physical_column_index
-            && logical_field.data_type() == physical_field.data_type()
-        {
+        if column.index() == physical_column_index && logical_field == physical_field {
             return Ok(Transformed::no(expr));
         }
 
         let column = self.resolve_column(column, physical_column_index)?;
-        if logical_field.data_type() == physical_field.data_type() {
-            // If the data types match, we can use the column as is
+        if logical_field == physical_field {
+            // If the fields match (including metadata/nullability), we can use the column as is
             return Ok(Transformed::yes(Arc::new(column)));
         }
 
+        if logical_field.data_type() == physical_field.data_type() {
+            // The data type matches, but the field metadata / nullability differs.
+            // Emit a CastColumnExpr so downstream schema construction uses the logical field.
+            return self.create_cast_column_expr(column, logical_field);
+        }
+
         // We need to cast the column to the logical data type
         // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
         // since that's much cheaper to evaluate.
@@ -690,6 +703,43 @@ mod tests {
         assert!(result.as_any().downcast_ref::().is_some());
     }
 
+    #[test]
+    fn test_rewrite_column_with_metadata_or_nullability_mismatch() -> Result<()> {
+        use std::collections::HashMap;
+
+        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
+        let logical_schema =
+            Schema::new(vec![Field::new("a", DataType::Int64, false).with_metadata(
+                HashMap::from([("logical_meta".to_string(), "1".to_string())]),
+            )]);
+
+        let factory = DefaultPhysicalExprAdapterFactory;
+        let adapter = factory
+            .create(Arc::new(logical_schema), Arc::new(physical_schema.clone()))
+            .unwrap();
+
+        let result = adapter.rewrite(Arc::new(Column::new("a", 0)))?;
+        let cast = result
+            .as_any()
+            .downcast_ref::<CastColumnExpr>()
+            .expect("Expected CastColumnExpr");
+
+        assert_eq!(cast.target_field().data_type(), &DataType::Int64);
+        assert!(!cast.target_field().is_nullable());
+        assert_eq!(
+            cast.target_field()
+                .metadata()
+                .get("logical_meta")
+                .map(String::as_str),
+            Some("1")
+        );
+
+        // Ensure the expression reports the logical nullability regardless of input schema
+        assert!(!result.nullable(&physical_schema)?);
+
+        Ok(())
+    }
+
     #[test]
     fn test_rewrite_multi_column_expr_with_type_cast() {
         let (physical_schema, logical_schema) = create_test_schema();
@@ -862,6 +912,41 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_rewrite_missing_column_propagates_metadata() -> Result<()> {
+        use std::collections::HashMap;
+
+        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let logical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, true).with_metadata(HashMap::from([(
+                "logical_meta".to_string(),
+                "1".to_string(),
+            )])),
+        ]);
+
+        let factory = DefaultPhysicalExprAdapterFactory;
+        let adapter = factory
+            .create(Arc::new(logical_schema), Arc::new(physical_schema.clone()))
+            .unwrap();
+
+        let result = adapter.rewrite(Arc::new(Column::new("b", 1)))?;
+        let literal = result
+            .as_any()
+            .downcast_ref::<Literal>()
+            .expect("Expected literal expression");
+
+        assert_eq!(
+            literal
+                .return_field(&physical_schema)?
+                .metadata()
+                .get("logical_meta")
+                .map(String::as_str),
+            Some("1")
+        );
+        Ok(())
+    }
+
     #[test]
     fn test_rewrite_missing_column_non_nullable_error() {
         let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs
index 996bc4b08fcd2..ecbfc0f623981 100644
--- a/datafusion/physical-expr/src/equivalence/properties/mod.rs
+++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs
@@ -33,7 +33,7 @@ use self::dependency::{
 use crate::equivalence::{
     AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping,
 };
-use crate::expressions::{CastExpr, Column, Literal, with_new_schema};
+use crate::expressions::{CastColumnExpr, CastExpr, Column, Literal, with_new_schema};
 use crate::{
     ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr,
     PhysicalSortRequirement,
@@ -853,6 +853,18 @@ impl EquivalenceProperties {
                             sort_expr.options,
                         ));
                     }
+                } else if let Some(cast_expr) =
+                    r_expr.as_any().downcast_ref::<CastColumnExpr>()
+                {
+                    let cast_type = cast_expr.target_field().data_type();
+                    if cast_expr.expr().eq(&sort_expr.expr)
+                        && CastExpr::check_bigger_cast(cast_type, &expr_type)
+                    {
+                        result.push(PhysicalSortExpr::new(
+                            r_expr,
+                            sort_expr.options,
+                        ));
+                    }
                 }
             }
             result.push(sort_expr);
diff --git a/datafusion/physical-expr/src/expressions/cast_column.rs b/datafusion/physical-expr/src/expressions/cast_column.rs
index d80b6f4a588a4..f6c4d080fc7ed 100644
--- a/datafusion/physical-expr/src/expressions/cast_column.rs
+++ b/datafusion/physical-expr/src/expressions/cast_column.rs
@@ -27,6 +27,8 @@ use datafusion_common::{
     Result, ScalarValue, format::DEFAULT_CAST_OPTIONS, nested_struct::cast_column,
 };
 use datafusion_expr_common::columnar_value::ColumnarValue;
+use datafusion_expr_common::interval_arithmetic::Interval;
+use datafusion_expr_common::sort_properties::ExprProperties;
 use std::{
     any::Any,
     fmt::{self, Display},
@@ -177,6 +179,24 @@ impl PhysicalExpr for CastColumnExpr {
         )))
     }
 
+    /// A [`CastColumnExpr`] preserves the ordering of its child when the cast
+    /// stays within the same datatype family.
+    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
+        let source_datatype = children[0].range.data_type();
+        let target_type = self.target_field.data_type();
+
+        let unbounded = Interval::make_unbounded(target_type)?;
+        if (source_datatype.is_numeric() || source_datatype == DataType::Boolean)
+            && target_type.is_numeric()
+            || source_datatype.is_temporal() && target_type.is_temporal()
+            || source_datatype.eq(target_type)
+        {
+            Ok(children[0].clone().with_range(unbounded))
+        } else {
+            Ok(ExprProperties::new_unknown().with_range(unbounded))
+        }
+    }
+
     fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         Display::fmt(self, f)
     }