146 changes: 144 additions & 2 deletions datafusion/core/tests/parquet/expr_adapter.rs
@@ -17,8 +17,11 @@

use std::sync::Arc;

use arrow::array::{RecordBatch, record_batch};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use arrow::array::{
Array, ArrayRef, BooleanArray, Int32Array, Int64Array, RecordBatch, StringArray,
StructArray, record_batch,
};
use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
use bytes::{BufMut, BytesMut};
use datafusion::assert_batches_eq;
use datafusion::common::Result;
@@ -320,6 +323,145 @@ async fn test_physical_expr_adapter_with_non_null_defaults() {
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_struct_schema_evolution_projection_and_filter() -> Result<()> {
use std::collections::HashMap;

// Physical struct: {id: Int32, name: Utf8}
let physical_struct_fields: Fields = vec![
Arc::new(Field::new("id", DataType::Int32, false)),
Arc::new(Field::new("name", DataType::Utf8, true)),
]
.into();

let struct_array = StructArray::new(
physical_struct_fields.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef,
],
None,
);

let physical_schema = Arc::new(Schema::new(vec![Field::new(
"s",
DataType::Struct(physical_struct_fields),
true,
)]));

let batch =
RecordBatch::try_new(Arc::clone(&physical_schema), vec![Arc::new(struct_array)])?;

let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
let store_url = ObjectStoreUrl::parse("memory://").unwrap();
write_parquet(batch, store.clone(), "struct_evolution.parquet").await;

// Logical struct: {id: Int64?, name: Utf8?, extra: Boolean?} + metadata
let logical_struct_fields: Fields = vec![
Arc::new(Field::new("id", DataType::Int64, true)),
Arc::new(Field::new("name", DataType::Utf8, true)),
Arc::new(Field::new("extra", DataType::Boolean, true).with_metadata(
HashMap::from([("nested_meta".to_string(), "1".to_string())]),
)),
]
.into();

let table_schema = Arc::new(Schema::new(vec![
Field::new("s", DataType::Struct(logical_struct_fields), false)
.with_metadata(HashMap::from([("top_meta".to_string(), "1".to_string())])),
]));

let mut cfg = SessionConfig::new()
.with_collect_statistics(false)
.with_parquet_pruning(false)
.with_parquet_page_index_pruning(false);
cfg.options_mut().execution.parquet.pushdown_filters = true;

let ctx = SessionContext::new_with_config(cfg);
ctx.register_object_store(store_url.as_ref(), Arc::clone(&store));

let listing_table_config =
ListingTableConfig::new(ListingTableUrl::parse("memory:///").unwrap())
.infer_options(&ctx.state())
.await
.unwrap()
.with_schema(table_schema.clone())
.with_expr_adapter_factory(Arc::new(DefaultPhysicalExprAdapterFactory));

let table = ListingTable::try_new(listing_table_config).unwrap();
ctx.register_table("t", Arc::new(table)).unwrap();

let batches = ctx
.sql("SELECT s FROM t")
.await
.unwrap()
.collect()
.await
.unwrap();
assert_eq!(batches.len(), 1);

// Verify top-level metadata propagation
let output_schema = batches[0].schema();
let s_field = output_schema.field_with_name("s").unwrap();
assert_eq!(
s_field.metadata().get("top_meta").map(String::as_str),
Some("1")
);

// Verify nested struct type/field propagation + values
let s_array = batches[0]
.column(0)
.as_any()
.downcast_ref::<StructArray>()
.expect("expected struct array");

let id_array = s_array
.column_by_name("id")
.expect("id column")
.as_any()
.downcast_ref::<Int64Array>()
.expect("id should be cast to Int64");
assert_eq!(id_array.values(), &[1, 2, 3]);

let extra_array = s_array.column_by_name("extra").expect("extra column");
assert_eq!(extra_array.null_count(), 3);

// Verify nested field metadata propagation
let extra_field = match s_field.data_type() {
DataType::Struct(fields) => fields
.iter()
.find(|f| f.name() == "extra")
.expect("extra field"),
other => panic!("expected struct type for s, got {other:?}"),
};
assert_eq!(
extra_field
.metadata()
.get("nested_meta")
.map(String::as_str),
Some("1")
);

// Smoke test: filtering on a missing nested field evaluates correctly
let filtered = ctx
.sql("SELECT get_field(s, 'extra') AS extra FROM t WHERE get_field(s, 'extra') IS NULL")
.await
.unwrap()
.collect()
.await
.unwrap();
assert_eq!(filtered.len(), 1);
assert_eq!(filtered[0].num_rows(), 3);
let extra = filtered[0]
.column(0)
.as_any()
.downcast_ref::<BooleanArray>()
.expect("extra should be a boolean array");
assert_eq!(extra.null_count(), 3);

Ok(())
}

/// Test demonstrating that a single PhysicalExprAdapterFactory instance can be
/// reused across multiple ListingTable instances.
///
99 changes: 92 additions & 7 deletions datafusion/physical-expr-adapter/src/schema_rewriter.rs
@@ -29,6 +29,7 @@ use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field, SchemaRef};
use datafusion_common::{
Result, ScalarValue, exec_err,
metadata::FieldMetadata,
nested_struct::validate_struct_compatibility,
tree_node::{Transformed, TransformedResult, TreeNode},
};
@@ -368,7 +369,10 @@ impl DefaultPhysicalExprAdapterRewriter {
};

let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
Ok(Some(expressions::lit(null_value)))
Ok(Some(Arc::new(expressions::Literal::new_with_metadata(
null_value,
Some(FieldMetadata::from(logical_struct_field.as_ref())),
))))
}

fn rewrite_column(
@@ -416,24 +420,33 @@
// If the column is missing from the physical schema fill it in with nulls.
// For a different behavior, provide a custom `PhysicalExprAdapter` implementation.
let null_value = ScalarValue::Null.cast_to(logical_field.data_type())?;
return Ok(Transformed::yes(expressions::lit(null_value)));
return Ok(Transformed::yes(Arc::new(
expressions::Literal::new_with_metadata(
null_value,
Some(FieldMetadata::from(logical_field)),
),
)));
}
};
let physical_field = self.physical_file_schema.field(physical_column_index);

if column.index() == physical_column_index
&& logical_field.data_type() == physical_field.data_type()
{
if column.index() == physical_column_index && logical_field == physical_field {
return Ok(Transformed::no(expr));
}

let column = self.resolve_column(column, physical_column_index)?;

if logical_field.data_type() == physical_field.data_type() {
// If the data types match, we can use the column as is
if logical_field == physical_field {
// If the fields match (including metadata/nullability), we can use the column as is
return Ok(Transformed::yes(Arc::new(column)));
}

if logical_field.data_type() == physical_field.data_type() {
// The data type matches, but the field metadata / nullability differs.
// Emit a CastColumnExpr so downstream schema construction uses the logical field.
return self.create_cast_column_expr(column, logical_field);
}

// We need to cast the column to the logical data type
// TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
// since that's much cheaper to evaluate.
@@ -690,6 +703,43 @@
assert!(result.as_any().downcast_ref::<CastColumnExpr>().is_some());
}

#[test]
fn test_rewrite_column_with_metadata_or_nullability_mismatch() -> Result<()> {
use std::collections::HashMap;

let physical_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
let logical_schema =
Schema::new(vec![Field::new("a", DataType::Int64, false).with_metadata(
HashMap::from([("logical_meta".to_string(), "1".to_string())]),
)]);

let factory = DefaultPhysicalExprAdapterFactory;
let adapter = factory
.create(Arc::new(logical_schema), Arc::new(physical_schema.clone()))
.unwrap();

let result = adapter.rewrite(Arc::new(Column::new("a", 0)))?;
let cast = result
.as_any()
.downcast_ref::<CastColumnExpr>()
.expect("Expected CastColumnExpr");

assert_eq!(cast.target_field().data_type(), &DataType::Int64);
assert!(!cast.target_field().is_nullable());
assert_eq!(
cast.target_field()
.metadata()
.get("logical_meta")
.map(String::as_str),
Some("1")
);

// Ensure the expression reports the logical nullability regardless of input schema
assert!(!result.nullable(physical_schema.as_ref())?);

Ok(())
}

#[test]
fn test_rewrite_multi_column_expr_with_type_cast() {
let (physical_schema, logical_schema) = create_test_schema();
@@ -862,6 +912,41 @@
Ok(())
}

#[test]
fn test_rewrite_missing_column_propagates_metadata() -> Result<()> {
use std::collections::HashMap;

let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let logical_schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, true).with_metadata(HashMap::from([(
"logical_meta".to_string(),
"1".to_string(),
)])),
]);

let factory = DefaultPhysicalExprAdapterFactory;
let adapter = factory
.create(Arc::new(logical_schema), Arc::new(physical_schema.clone()))
.unwrap();

let result = adapter.rewrite(Arc::new(Column::new("b", 1)))?;
let literal = result
.as_any()
.downcast_ref::<expressions::Literal>()
.expect("Expected literal expression");

assert_eq!(
literal
.return_field(physical_schema.as_ref())?
.metadata()
.get("logical_meta")
.map(String::as_str),
Some("1")
);
Ok(())
}

#[test]
fn test_rewrite_missing_column_non_nullable_error() {
let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
14 changes: 13 additions & 1 deletion datafusion/physical-expr/src/equivalence/properties/mod.rs
@@ -33,7 +33,7 @@ use self::dependency::{
use crate::equivalence::{
AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping,
};
use crate::expressions::{CastExpr, Column, Literal, with_new_schema};
use crate::expressions::{CastColumnExpr, CastExpr, Column, Literal, with_new_schema};
use crate::{
ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr,
PhysicalSortRequirement,
@@ -853,6 +853,18 @@ impl EquivalenceProperties {
sort_expr.options,
));
}
} else if let Some(cast_expr) =
r_expr.as_any().downcast_ref::<CastColumnExpr>()
{
let cast_type = cast_expr.target_field().data_type();
if cast_expr.expr().eq(&sort_expr.expr)
&& CastExpr::check_bigger_cast(cast_type, &expr_type)
{
result.push(PhysicalSortExpr::new(
r_expr,
sort_expr.options,
));
}
}
}
result.push(sort_expr);
20 changes: 20 additions & 0 deletions datafusion/physical-expr/src/expressions/cast_column.rs
@@ -27,6 +27,8 @@ use datafusion_common::{
Result, ScalarValue, format::DEFAULT_CAST_OPTIONS, nested_struct::cast_column,
};
use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_expr_common::interval_arithmetic::Interval;
use datafusion_expr_common::sort_properties::ExprProperties;
use std::{
any::Any,
fmt::{self, Display},
@@ -177,6 +179,24 @@
)))
}

/// A [`CastColumnExpr`] preserves the ordering of its child if the cast is done
/// under the same datatype family.
fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
let source_datatype = children[0].range.data_type();
let target_type = self.target_field.data_type();

let unbounded = Interval::make_unbounded(target_type)?;
if (source_datatype.is_numeric() || source_datatype == DataType::Boolean)
&& target_type.is_numeric()
|| source_datatype.is_temporal() && target_type.is_temporal()
|| source_datatype.eq(target_type)
{
Ok(children[0].clone().with_range(unbounded))
} else {
Ok(ExprProperties::new_unknown().with_range(unbounded))
}
}

fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Display::fmt(self, f)
}