diff --git a/datafusion/expr-common/src/casts.rs b/datafusion/expr-common/src/casts.rs index d18c3d4f043eb..4122f32f75511 100644 --- a/datafusion/expr-common/src/casts.rs +++ b/datafusion/expr-common/src/casts.rs @@ -22,6 +22,10 @@ //! unwrap_cast module to be shared between logical and physical layers. use std::cmp::Ordering; +use std::sync::Arc; + +use crate::interval_arithmetic::Interval; +use crate::operator::Operator; use arrow::datatypes::{ DataType, MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION, @@ -29,7 +33,18 @@ use arrow::datatypes::{ MIN_DECIMAL64_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION, TimeUnit, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; + +/// Source-domain preimage of `CAST(source_expr AS target_type) OP literal`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CastPredicatePreimage { + /// A singleton preimage represented by a literal in the source type. This + /// can keep the original comparison operator. + Exact(ScalarValue), + /// A half-open source-domain interval `[lower, upper)`. The caller must + /// map the comparison operator to range predicates. + Range(Interval), +} /// Convert a literal [`ScalarValue`] to `target_type`, preserving the exact value. /// @@ -74,6 +89,255 @@ pub fn try_cast_literal_to_type( .or_else(|| try_cast_binary(lit_value, target_type)) } +/// Computes a source-domain preimage for `CAST(source AS target_type) OP literal`. +/// +/// This is the shared semantic core for logical and physical cast-predicate +/// rewrites. It returns a singleton [`CastPredicatePreimage::Exact`] for casts +/// where moving the cast to the literal preserves comparison semantics, and a +/// [`CastPredicatePreimage::Range`] for many-to-one casts with known preimages +/// such as timestamp precision narrowing. +pub fn cast_predicate_preimage( + source_type: &DataType, + target_type: &DataType, + op: Operator, + lit_value: &ScalarValue, +) -> Result> { + if let Some(interval) = + timestamp_precision_narrowing_preimage(source_type, target_type, lit_value)? + { + return Ok(Some(CastPredicatePreimage::Range(interval))); + } + + if is_timestamp_precision_narrowing_cast(source_type, target_type) { + return Ok(None); + } + + Ok( + exact_preimage_for_cast_predicate(source_type, target_type, op, lit_value) + .map(CastPredicatePreimage::Exact), + ) +} + +/// Computes a singleton source-domain literal for exact cast-predicate rewrites. +/// +/// This intentionally returns `None` for timestamp precision narrowing: those +/// casts are many-to-one and need range preimages instead. +pub fn cast_predicate_exact_literal( + source_type: &DataType, + target_type: &DataType, + lit_value: &ScalarValue, +) -> Option { + if is_timestamp_precision_narrowing_cast(source_type, target_type) { + return None; + } + + let source_value = try_cast_literal_to_type(lit_value, source_type)?; + if is_timestamp_cast(source_type, target_type) { + let round_tripped = try_cast_literal_to_type(&source_value, target_type)?; + if &round_tripped != lit_value { + return None; + } + } + + Some(source_value) +} + +/// Returns true when casting a timestamp from `source_type` to `target_type` +/// loses timestamp precision. +pub fn is_timestamp_precision_narrowing_cast( + source_type: &DataType, + target_type: &DataType, +) -> bool { + let (DataType::Timestamp(source_unit, _), DataType::Timestamp(target_unit, _)) = + (source_type, target_type) + else { + return false; + }; + + timestamp_unit_scale(source_unit) > timestamp_unit_scale(target_unit) +} + +fn is_timestamp_cast(source_type: &DataType, target_type: &DataType) -> bool { + matches!( + (source_type, target_type), + (DataType::Timestamp(_, _), DataType::Timestamp(_, _)) + ) +} + +fn exact_preimage_for_cast_predicate( + source_type: &DataType, + target_type: &DataType, + op: Operator, + lit_value: &ScalarValue, +) -> Option { + cast_to_string_equality_preimage(source_type, target_type, op, lit_value) + .or_else(|| cast_predicate_exact_literal(source_type, target_type, lit_value)) +} + +/// Computes a singleton preimage for equality predicates over casts whose target +/// value is a string representation of an integer source value. +/// +/// For example, `CAST(int_col AS Utf8) = '123'` can be rewritten to +/// `int_col = 123`, but `CAST(int_col AS Utf8) = '0123'` cannot be rewritten to +/// `int_col = 123` because casting `123` back to a string yields `'123'`, not +/// `'0123'`. +fn cast_to_string_equality_preimage( + source_type: &DataType, + target_type: &DataType, + op: Operator, + lit_value: &ScalarValue, +) -> Option { + if !matches!( + target_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) { + return None; + } + + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), + ) => { + // Only try for integer types (TODO can we do this for other types + // like timestamps)? + use DataType::*; + if matches!( + source_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) { + let casted = lit_value.cast_to(source_type).ok()?; + let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?; + if lit_value != &round_tripped { + return None; + } + Some(casted) + } else { + None + } + } + _ => None, + } +} + +/// Computes the source-domain preimage interval for timestamp precision +/// narrowing casts. +/// +/// The preimage is computed entirely in the source timestamp domain and +/// preserves `source_tz`. The target timezone is intentionally *not* copied +/// into the generated bounds: this models the raw timestamp values after +/// truncation and avoids comparing source columns against target-timezone +/// literals. +fn timestamp_precision_narrowing_preimage( + source_type: &DataType, + target_type: &DataType, + lit_value: &ScalarValue, +) -> Result> { + let ( + DataType::Timestamp(source_unit, source_tz), + DataType::Timestamp(target_unit, _), + ) = (source_type, target_type) + else { + return Ok(None); + }; + + let source_scale = i128::from(timestamp_unit_scale(source_unit)); + let target_scale = i128::from(timestamp_unit_scale(target_unit)); + if source_scale <= target_scale { + return Ok(None); + } + + let Some(target_value) = timestamp_literal_value(lit_value, target_unit) else { + return Ok(None); + }; + + let bucket_width = source_scale / target_scale; + let Some((lower, upper)) = trunc_toward_zero_bucket(target_value, bucket_width) + else { + return Ok(None); + }; + + let Ok(lower) = i64::try_from(lower) else { + return Ok(None); + }; + let Ok(upper) = i64::try_from(upper) else { + return Ok(None); + }; + + Interval::try_new( + timestamp_scalar(source_unit, source_tz.clone(), lower), + timestamp_scalar(source_unit, source_tz.clone(), upper), + ) + .map(Some) +} + +/// Returns the half-open source-domain bucket `[lower, upper)` that truncates +/// toward zero to `value` when divided by `bucket_width`. +/// +/// Timestamp precision narrowing follows integer truncation toward zero rather +/// than mathematical floor. For example, when `bucket_width = 1_000_000`, both +/// `999_999` and `-999_999` truncate to `0`, while `-1_000_000` truncates to +/// `-1`. +/// +/// This makes the inverse bucket depend on the sign of `value`: +/// +/// * `value > 0`: `[value * width, (value + 1) * width)` +/// * `value == 0`: `[1 - width, width)`, spanning small negative and positive +/// values that both truncate to zero +/// * `value < 0`: `[(value - 1) * width + 1, value * width + 1)` +/// +/// The arithmetic uses `checked_*` operations and returns `None` if an +/// intermediate bound cannot be represented as `i128`. +fn trunc_toward_zero_bucket(value: i64, bucket_width: i128) -> Option<(i128, i128)> { + let value = value as i128; + if value > 0 { + let lower = value.checked_mul(bucket_width)?; + let upper = value.checked_add(1)?.checked_mul(bucket_width)?; + Some((lower, upper)) + } else if value == 0 { + Some((1_i128.checked_sub(bucket_width)?, bucket_width)) + } else { + let lower = value + .checked_sub(1)? + .checked_mul(bucket_width)? + .checked_add(1)?; + let upper = value.checked_mul(bucket_width)?.checked_add(1)?; + Some((lower, upper)) + } +} + +fn timestamp_literal_value(lit_value: &ScalarValue, unit: &TimeUnit) -> Option { + match (lit_value, unit) { + (ScalarValue::TimestampSecond(Some(value), _), TimeUnit::Second) + | (ScalarValue::TimestampMillisecond(Some(value), _), TimeUnit::Millisecond) + | (ScalarValue::TimestampMicrosecond(Some(value), _), TimeUnit::Microsecond) + | (ScalarValue::TimestampNanosecond(Some(value), _), TimeUnit::Nanosecond) => { + Some(*value) + } + _ => None, + } +} + +fn timestamp_unit_scale(unit: &TimeUnit) -> i64 { + match unit { + TimeUnit::Second => 1, + TimeUnit::Millisecond => MILLISECONDS, + TimeUnit::Microsecond => MICROSECONDS, + TimeUnit::Nanosecond => NANOSECONDS, + } +} + +fn timestamp_scalar(unit: &TimeUnit, tz: Option>, value: i64) -> ScalarValue { + match unit { + TimeUnit::Second => ScalarValue::TimestampSecond(Some(value), tz), + TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(Some(value), tz), + TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(Some(value), tz), + TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(Some(value), tz), + } +} + /// Returns true if unwrap_cast_in_comparison supports this data type pub fn is_supported_type(data_type: &DataType) -> bool { is_supported_numeric_type(data_type) @@ -399,18 +663,12 @@ fn try_cast_dictionary( fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { let value = value as i64; let from_scale = match from { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + DataType::Timestamp(unit, _) => timestamp_unit_scale(unit), _ => return Some(value), }; let to_scale = match to { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + DataType::Timestamp(unit, _) => timestamp_unit_scale(unit), _ => return Some(value), }; @@ -938,6 +1196,147 @@ mod tests { assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); } + #[test] + fn test_cast_predicate_preimage_exact() { + assert_eq!( + cast_predicate_preimage( + &DataType::Int32, + &DataType::Int64, + Operator::Gt, + &ScalarValue::Int64(Some(10)), + ) + .unwrap(), + Some(CastPredicatePreimage::Exact(ScalarValue::Int32(Some(10)))) + ); + + assert_eq!( + cast_predicate_preimage( + &DataType::Int32, + &DataType::Utf8, + Operator::Eq, + &ScalarValue::Utf8(Some("123".to_string())), + ) + .unwrap(), + Some(CastPredicatePreimage::Exact(ScalarValue::Int32(Some(123)))) + ); + + assert_eq!( + cast_predicate_preimage( + &DataType::Int32, + &DataType::Utf8, + Operator::Eq, + &ScalarValue::Utf8(Some("0123".to_string())), + ) + .unwrap(), + None + ); + } + + #[test] + fn test_cast_predicate_preimage_timestamp_narrowing_range() { + let ts_ns = DataType::Timestamp(TimeUnit::Nanosecond, None); + let ts_ms = DataType::Timestamp(TimeUnit::Millisecond, None); + + assert_eq!( + cast_predicate_preimage( + &ts_ns, + &ts_ms, + Operator::Eq, + &ScalarValue::TimestampMillisecond(Some(1000), None), + ) + .unwrap(), + Some(CastPredicatePreimage::Range( + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1_000_000_000), None), + ScalarValue::TimestampNanosecond(Some(1_001_000_000), None), + ) + .unwrap() + )) + ); + + assert_eq!( + cast_predicate_preimage( + &ts_ns, + &ts_ms, + Operator::Eq, + &ScalarValue::TimestampMillisecond(Some(0), None), + ) + .unwrap(), + Some(CastPredicatePreimage::Range( + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(-999_999), None), + ScalarValue::TimestampNanosecond(Some(1_000_000), None), + ) + .unwrap() + )) + ); + + assert_eq!( + cast_predicate_preimage( + &ts_ns, + &ts_ms, + Operator::Eq, + &ScalarValue::TimestampMillisecond(Some(-1), None), + ) + .unwrap(), + Some(CastPredicatePreimage::Range( + Interval::try_new( + ScalarValue::TimestampNanosecond(Some(-1_999_999), None), + ScalarValue::TimestampNanosecond(Some(-999_999), None), + ) + .unwrap() + )) + ); + } + + #[test] + fn test_cast_predicate_preimage_timestamp_widening_exact_only() { + let ts_ms = DataType::Timestamp(TimeUnit::Millisecond, None); + let ts_ns = DataType::Timestamp(TimeUnit::Nanosecond, None); + + assert_eq!( + cast_predicate_preimage( + &ts_ms, + &ts_ns, + Operator::Eq, + &ScalarValue::TimestampNanosecond(Some(123_000_000), None), + ) + .unwrap(), + Some(CastPredicatePreimage::Exact( + ScalarValue::TimestampMillisecond(Some(123), None) + )) + ); + + assert_eq!( + cast_predicate_preimage( + &ts_ms, + &ts_ns, + Operator::Eq, + &ScalarValue::TimestampNanosecond(Some(123_456_789), None), + ) + .unwrap(), + None + ); + } + + #[test] + fn test_cast_predicate_preimage_timestamp_null_literal_unsupported() { + let ts_ns = DataType::Timestamp(TimeUnit::Nanosecond, None); + let ts_ms = DataType::Timestamp(TimeUnit::Millisecond, None); + let null_ms = ScalarValue::TimestampMillisecond(None, None); + + for op in [ + Operator::Eq, + Operator::IsDistinctFrom, + Operator::IsNotDistinctFrom, + ] { + assert_eq!( + cast_predicate_preimage(&ts_ns, &ts_ms, op, &null_ms).unwrap(), + None + ); + } + } + #[test] fn test_try_cast_to_string_type() { let scalars = vec![ @@ -1373,4 +1772,70 @@ mod tests { ExpectedCast::NoValue, ); } + + #[test] + fn test_cast_predicate_preimage_extreme_literals() { + let ts_ns = DataType::Timestamp(TimeUnit::Nanosecond, None); + let ts_ms = DataType::Timestamp(TimeUnit::Millisecond, None); + + // i64::MAX in milliseconds expands beyond i64 range in nanoseconds, + // so the preimage should return None rather than panicking. + assert_eq!( + cast_predicate_preimage( + &ts_ns, + &ts_ms, + Operator::Eq, + &ScalarValue::TimestampMillisecond(Some(i64::MAX), None), + ) + .unwrap(), + None + ); + + // i64::MIN in milliseconds expands beyond i64 range in nanoseconds. + assert_eq!( + cast_predicate_preimage( + &ts_ns, + &ts_ms, + Operator::Eq, + &ScalarValue::TimestampMillisecond(Some(i64::MIN), None), + ) + .unwrap(), + None + ); + } + + #[test] + fn test_cast_predicate_preimage_timezone_preservation() { + let source_type = + DataType::Timestamp(TimeUnit::Nanosecond, Some("+05:30".into())); + let target_type = DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())); + let lit = ScalarValue::TimestampMillisecond(Some(1000), Some("UTC".into())); + + let result = + cast_predicate_preimage(&source_type, &target_type, Operator::Eq, &lit) + .unwrap(); + + match result { + Some(CastPredicatePreimage::Range(interval)) => { + let (lower, upper) = interval.into_bounds(); + assert_eq!( + lower, + ScalarValue::TimestampNanosecond( + Some(1_000_000_000), + Some("+05:30".into()) + ), + "lower bound should preserve source timezone +05:30" + ); + assert_eq!( + upper, + ScalarValue::TimestampNanosecond( + Some(1_001_000_000), + Some("+05:30".into()) + ), + "upper bound should preserve source timezone +05:30" + ); + } + other => panic!("Expected CastPredicatePreimage::Range but got {other:?}"), + } + } } diff --git a/datafusion/optimizer/src/simplify_expressions/cast_preimage.rs b/datafusion/optimizer/src/simplify_expressions/cast_preimage.rs new file mode 100644 index 0000000000000..bc125c749de83 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/cast_preimage.rs @@ -0,0 +1,431 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Preimage rewrites for cast comparisons. +//! +//! This module computes source-domain predicates for expressions such as +//! `CAST(expr AS target_type) OP literal`. For casts that are many-to-one, a +//! same-operator unwrap is not equivalent; the correct rewrite is a preimage +//! range over the input expression. + +use arrow::datatypes::DataType; +use datafusion_common::{Result, internal_err, tree_node::Transformed}; +use datafusion_expr::expr::InList; +use datafusion_expr::{ + BinaryExpr, Cast, Expr, Operator, TryCast, lit, simplify::SimplifyContext, +}; +use datafusion_expr_common::casts::{ + CastPredicatePreimage, cast_predicate_exact_literal, cast_predicate_preimage, +}; + +use super::udf_preimage::rewrite_with_preimage; + +pub(super) fn rewrite_cast_predicate_for_binary( + info: &SimplifyContext, + cast_expr: Expr, + literal: Expr, + op: Operator, +) -> Result> { + let Some((expr, target_type)) = cast_input_and_type(cast_expr) else { + return internal_err!("Expect cast expr"); + }; + let Expr::Literal(lit_value, _) = literal else { + return internal_err!("Expect literal expr"); + }; + + let source_type = info.get_data_type(&expr)?; + match cast_predicate_preimage(&source_type, &target_type, op, &lit_value)? { + Some(CastPredicatePreimage::Range(interval)) => { + rewrite_with_preimage(interval, op, *expr) + } + Some(CastPredicatePreimage::Exact(value)) => { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: expr, + op, + right: Box::new(lit(value)), + }))) + } + None => internal_err!( + "Can't compute cast predicate preimage for source type {} target type {} literal {:?}", + source_type, + target_type, + lit_value + ), + } +} + +pub(super) fn supports_cast_predicate_for_binary( + info: &SimplifyContext, + expr: &Expr, + op: Operator, + literal: &Expr, +) -> bool { + if !matches!( + op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + ) { + return false; + } + + let Some((inner_expr, target_type)) = cast_input_and_type_ref(expr) else { + return false; + }; + let Expr::Literal(lit_value, _) = literal else { + return false; + }; + let Ok(source_type) = info.get_data_type(inner_expr) else { + return false; + }; + + cast_predicate_preimage(&source_type, target_type, op, lit_value) + .ok() + .flatten() + .is_some() +} + +pub(super) fn supports_cast_predicate_for_inlist( + info: &SimplifyContext, + expr: &Expr, + list: &[Expr], +) -> bool { + let Some((inner_expr, target_type)) = cast_input_and_type_ref(expr) else { + return false; + }; + let Ok(source_type) = info.get_data_type(inner_expr) else { + return false; + }; + + list.iter().all(|right| match right { + Expr::Literal(lit_val, _) => { + cast_predicate_exact_literal(&source_type, target_type, lit_val).is_some() + } + _ => false, + }) +} + +pub(super) fn rewrite_cast_predicate_for_inlist( + info: &SimplifyContext, + expr: Expr, + list: Vec, + negated: bool, +) -> Result> { + let Some((inner_expr, target_type)) = cast_input_and_type(expr) else { + return internal_err!("Expect cast expr"); + }; + let source_type = info.get_data_type(&inner_expr)?; + + let list = list + .into_iter() + .map(|right| match right { + Expr::Literal(lit_value, _) => { + let Some(value) = + cast_predicate_exact_literal(&source_type, &target_type, &lit_value) + else { + return internal_err!( + "Can't cast the list expr {:?} to type {}", + lit_value, + &source_type + ); + }; + Ok(lit(value)) + } + other_expr => internal_err!( + "Only support literal expr to optimize, but the expr is {:?}", + &other_expr + ), + }) + .collect::>>()?; + + Ok(Transformed::yes(Expr::InList(InList { + expr: inner_expr, + list, + negated, + }))) +} + +fn cast_input_and_type(cast_expr: Expr) -> Option<(Box, DataType)> { + match cast_expr { + Expr::TryCast(TryCast { expr, field, .. }) + | Expr::Cast(Cast { expr, field, .. }) => Some((expr, field.data_type().clone())), + _ => None, + } +} + +fn cast_input_and_type_ref(cast_expr: &Expr) -> Option<(&Expr, &DataType)> { + match cast_expr { + Expr::TryCast(TryCast { expr, field, .. }) + | Expr::Cast(Cast { expr, field, .. }) => { + Some((expr.as_ref(), field.data_type())) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + use std::sync::Arc; + + use crate::simplify_expressions::ExprSimplifier; + use arrow::datatypes::{Field, TimeUnit}; + use datafusion_common::{DFSchema, DFSchemaRef, ScalarValue}; + use datafusion_expr::simplify::SimplifyContext; + use datafusion_expr::{binary_expr, cast, col, in_list}; + + #[test] + fn test_cast_predicate_exact_literal_unwrap() { + let schema = expr_test_schema(); + + let expr = cast(col("c1"), DataType::Int64).gt(lit(10_i64)); + let expected = col("c1").gt(lit(10_i32)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = lit(10_i64).lt(cast(col("c1"), DataType::Int64)); + let expected = col("c1").gt(lit(10_i32)); + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_cast_predicate_string_integer_round_trip() { + let schema = expr_test_schema(); + + let expr = cast(col("c1"), DataType::Utf8).eq(lit("123")); + let expected = col("c1").eq(lit(123_i32)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = cast(col("c1"), DataType::Utf8).eq(lit("0123")); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + } + + #[test] + fn test_cast_predicate_inlist_exact_unwrap() { + let schema = expr_test_schema(); + + let expr = in_list( + cast(col("c1"), DataType::Int64), + vec![lit(0_i64), lit(1_i64), lit(2_i64), lit(3_i64), lit(4_i64)], + false, + ); + let expected = in_list( + col("c1"), + vec![lit(0_i32), lit(1_i32), lit(2_i32), lit(3_i32), lit(4_i32)], + false, + ); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_cast_preimage_timestamp_precision_narrowing_eq() { + let schema = expr_test_schema(); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).eq(lit_timestamp_millis(1000)); + let expected = col("ts_nano") + .gt_eq(lit_timestamp_nano(1_000_000_000)) + .and(col("ts_nano").lt(lit_timestamp_nano(1_001_000_000))); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).eq(lit_timestamp_millis(0)); + let expected = col("ts_nano") + .gt_eq(lit_timestamp_nano(-999_999)) + .and(col("ts_nano").lt(lit_timestamp_nano(1_000_000))); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).eq(lit_timestamp_millis(-1)); + let expected = col("ts_nano") + .gt_eq(lit_timestamp_nano(-1_999_999)) + .and(col("ts_nano").lt(lit_timestamp_nano(-999_999))); + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_cast_preimage_timestamp_precision_narrowing_inequality() { + let schema = expr_test_schema(); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).gt(lit_timestamp_millis(1000)); + let expected = col("ts_nano").gt_eq(lit_timestamp_nano(1_001_000_000)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).lt_eq(lit_timestamp_millis(-1)); + let expected = col("ts_nano").lt(lit_timestamp_nano(-999_999)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).lt(lit_timestamp_millis(0)); + let expected = col("ts_nano").lt(lit_timestamp_nano(-999_999)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).lt_eq(lit_timestamp_millis(0)); + let expected = col("ts_nano").lt(lit_timestamp_nano(1_000_000)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).gt(lit_timestamp_millis(0)); + let expected = col("ts_nano").gt_eq(lit_timestamp_nano(1_000_000)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).gt_eq(lit_timestamp_millis(0)); + let expected = col("ts_nano").gt_eq(lit_timestamp_nano(-999_999)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = + cast(col("ts_nano"), timestamp_millis_type()).not_eq(lit_timestamp_millis(0)); + let expected = col("ts_nano") + .lt(lit_timestamp_nano(-999_999)) + .or(col("ts_nano").gt_eq(lit_timestamp_nano(1_000_000))); + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_cast_preimage_timestamp_precision_narrowing_distinctness() { + let schema = expr_test_schema(); + + let expr = binary_expr( + cast(col("ts_nano"), timestamp_millis_type()), + Operator::IsNotDistinctFrom, + lit_timestamp_millis(0), + ); + let expected = col("ts_nano") + .gt_eq(lit_timestamp_nano(-999_999)) + .and(col("ts_nano").lt(lit_timestamp_nano(1_000_000))); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = binary_expr( + cast(col("ts_nano"), timestamp_millis_type()), + Operator::IsDistinctFrom, + lit_timestamp_millis(0), + ); + let expected = col("ts_nano") + .lt(lit_timestamp_nano(-999_999)) + .or(col("ts_nano").gt_eq(lit_timestamp_nano(1_000_000))); + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_cast_preimage_timestamp_widening_requires_exact_literal() { + let schema = expr_test_schema(); + + let expr = cast(col("ts_milli"), timestamp_nano_type()) + .eq(lit_timestamp_nano(123_000_000)); + let expected = col("ts_milli").eq(lit_timestamp_millis(123)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = cast(col("ts_milli"), timestamp_nano_type()) + .eq(lit_timestamp_nano(123_456_789)); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + + let expr = cast(col("ts_milli"), timestamp_nano_type()) + .gt_eq(lit_timestamp_nano(123_000_000)); + let expected = col("ts_milli").gt_eq(lit_timestamp_millis(123)); + assert_eq!(optimize_test(expr, &schema), expected); + + let expr = cast(col("ts_milli"), timestamp_nano_type()) + .gt_eq(lit_timestamp_nano(123_456_789)); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + } + + #[test] + fn test_cast_preimage_timestamp_literal_left_range() { + let schema = expr_test_schema(); + + // lit_timestamp_millis(1000) < cast(col("ts_nano"), timestamp_millis_type()) + // should swap and rewrite to ts_nano >= 1_001_000_000ns + let expr = + lit_timestamp_millis(1000).lt(cast(col("ts_nano"), timestamp_millis_type())); + let expected = col("ts_nano").gt_eq(lit_timestamp_nano(1_001_000_000)); + assert_eq!(optimize_test(expr, &schema), expected); + + // lit_timestamp_millis(1000) <= cast(col("ts_nano"), timestamp_millis_type()) + // should swap and rewrite to ts_nano >= 1_000_000_000ns + let expr = lit_timestamp_millis(1000) + .lt_eq(cast(col("ts_nano"), timestamp_millis_type())); + let expected = col("ts_nano").gt_eq(lit_timestamp_nano(1_000_000_000)); + assert_eq!(optimize_test(expr, &schema), expected); + + // lit_timestamp_millis(1000) > cast(col("ts_nano"), timestamp_millis_type()) + // should swap and rewrite to ts_nano < 1_000_000_000ns + let expr = + lit_timestamp_millis(1000).gt(cast(col("ts_nano"), timestamp_millis_type())); + let expected = col("ts_nano").lt(lit_timestamp_nano(1_000_000_000)); + assert_eq!(optimize_test(expr, &schema), expected); + + // lit_timestamp_millis(1000) = cast(col("ts_nano"), timestamp_millis_type()) + // should swap and rewrite to range preimage + let expr = + lit_timestamp_millis(1000).eq(cast(col("ts_nano"), timestamp_millis_type())); + let expected = col("ts_nano") + .gt_eq(lit_timestamp_nano(1_000_000_000)) + .and(col("ts_nano").lt(lit_timestamp_nano(1_001_000_000))); + assert_eq!(optimize_test(expr, &schema), expected); + } + + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(Arc::clone(schema)) + .build(), + ); + + simplifier.simplify(expr).unwrap() + } + + fn expr_test_schema() -> DFSchemaRef { + Arc::new( + DFSchema::from_unqualified_fields( + vec![ + Field::new("c1", DataType::Int32, false), + Field::new("ts_milli", timestamp_millis_type(), false), + Field::new("ts_nano", timestamp_nano_type(), false), + ] + .into(), + HashMap::new(), + ) + .unwrap(), + ) + } + + fn lit_timestamp_nano(ts: i64) -> Expr { + lit(ScalarValue::TimestampNanosecond(Some(ts), None)) + } + + fn lit_timestamp_millis(ts: i64) -> Expr { + lit(ScalarValue::TimestampMillisecond(Some(ts), None)) + } + + fn timestamp_nano_type() -> DataType { + DataType::Timestamp(TimeUnit::Nanosecond, None) + } + + fn timestamp_millis_type() -> DataType { + DataType::Timestamp(TimeUnit::Millisecond, None) + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 39c8541b51b2f..65a57cf561061 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -55,18 +55,16 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; use crate::simplify_expressions::SimplifyContext; -use crate::simplify_expressions::regex::simplify_regex_expr; -use crate::simplify_expressions::unwrap_cast::{ - is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary, - is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist, - unwrap_cast_in_comparison_for_binary, +use crate::simplify_expressions::cast_preimage::{ + rewrite_cast_predicate_for_binary, rewrite_cast_predicate_for_inlist, + supports_cast_predicate_for_binary, supports_cast_predicate_for_inlist, }; +use crate::simplify_expressions::regex::simplify_regex_expr; use crate::{ analyzer::type_coercion::TypeCoercionRewriter, simplify_expressions::udf_preimage::rewrite_with_preimage, }; use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map; -use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; use regex::Regex; @@ -1926,28 +1924,27 @@ impl TreeNodeRewriter for Simplifier<'_> { } // ======================================= - // unwrap_cast_in_comparison + // cast_predicate_in_comparison // ======================================= // // For case: // try_cast/cast(expr as data_type) op literal Expr::BinaryExpr(BinaryExpr { left, op, right }) - if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info, &left, op, &right, - ) && op.supports_propagation() => + if supports_cast_predicate_for_binary(info, &left, op, &right) + && op.supports_propagation() => { - unwrap_cast_in_comparison_for_binary(info, *left, *right, op)? + rewrite_cast_predicate_for_binary(info, *left, *right, op)? } // literal op try_cast/cast(expr as data_type) // --> // try_cast/cast(expr as data_type) op_swap literal Expr::BinaryExpr(BinaryExpr { left, op, right }) - if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info, &right, op, &left, - ) && op.supports_propagation() - && op.swap().is_some() => + if op.supports_propagation() + && op.swap().is_some_and(|swapped| { + supports_cast_predicate_for_binary(info, &right, swapped, &left) + }) => { - unwrap_cast_in_comparison_for_binary( + rewrite_cast_predicate_for_binary( info, *right, *left, @@ -1957,52 +1954,11 @@ impl TreeNodeRewriter for Simplifier<'_> { // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList(InList { - expr: mut left, + expr, list, negated, - }) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist( - info, &left, &list, - ) => - { - let (Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - })) = left.as_mut() - else { - return internal_err!("Expect cast expr, but got {:?}", left)?; - }; - - let expr_type = info.get_data_type(left_expr)?; - let right_exprs = list - .into_iter() - .map(|right| { - match right { - Expr::Literal(right_lit_value, _) => { - // if the right_lit_value can be casted to the type of internal_left_expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else { - internal_err!( - "Can't cast the list expr {:?} to type {}", - right_lit_value, &expr_type - )? - }; - Ok(lit(value)) - } - other_expr => internal_err!( - "Only support literal expr to optimize, but the expr is {:?}", - &other_expr - ), - } - }) - .collect::>>()?; - - Transformed::yes(Expr::InList(InList { - expr: std::mem::take(left_expr), - list: right_exprs, - negated, - })) + }) if supports_cast_predicate_for_inlist(info, &expr, &list) => { + rewrite_cast_predicate_for_inlist(info, *expr, list, negated)? } // ======================================= diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index e0b53b79d468c..3812f292e14c3 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -18,6 +18,7 @@ //! [`SimplifyExpressions`] simplifies expressions in the logical plan, //! [`ExprSimplifier`] simplifies individual `Expr`s. +mod cast_preimage; pub mod expr_simplifier; mod inlist_simplifier; mod linear_aggregates; @@ -27,7 +28,6 @@ pub mod simplify_exprs; pub mod simplify_literal; mod simplify_predicates; mod udf_preimage; -mod unwrap_cast; mod utils; // backwards compatibility diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs deleted file mode 100644 index a5b65d0d8e7a4..0000000000000 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ /dev/null @@ -1,665 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Unwrap casts in binary comparisons -//! -//! The functions in this module attempt to remove casts from -//! comparisons to literals ([`ScalarValue`]s) by applying the casts -//! to the literals if possible. It is inspired by the optimizer rule -//! `UnwrapCastInBinaryComparison` of Spark. -//! -//! Removing casts often improves performance because: -//! 1. The cast is done once (to the literal) rather than to every value -//! 2. Can enable other optimizations such as predicate pushdown that -//! don't support casting -//! -//! The rule is applied to expressions of the following forms: -//! -//! 1. `cast(left_expr as data_type) comparison_op literal_expr` -//! 2. `literal_expr comparison_op cast(left_expr as data_type)` -//! 3. `cast(literal_expr) IN (expr1, expr2, ...)` -//! 4. `literal_expr IN (cast(expr1) , cast(expr2), ...)` -//! -//! If the expression matches one of the forms above, the rule will -//! ensure the value of `literal` is in range(min, max) of the -//! expr's data_type, and if the scalar is within range, the literal -//! will be casted to the data type of expr on the other side, and the -//! cast will be removed from the other side. -//! -//! # Example -//! -//! If the DataType of c1 is INT32. Given the filter -//! -//! ```text -//! cast(c1 as INT64) > INT64(10)` -//! ``` -//! -//! This rule will remove the cast and rewrite the expression to: -//! -//! ```text -//! c1 > INT32(10) -//! ``` - -use arrow::datatypes::DataType; -use datafusion_common::{Result, ScalarValue}; -use datafusion_common::{internal_err, tree_node::Transformed}; -use datafusion_expr::{BinaryExpr, lit}; -use datafusion_expr::{Cast, Expr, Operator, TryCast, simplify::SimplifyContext}; -use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type}; - -pub(super) fn unwrap_cast_in_comparison_for_binary( - info: &SimplifyContext, - cast_expr: Expr, - literal: Expr, - op: Operator, -) -> Result> { - match (cast_expr, literal) { - ( - Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }), - Expr::Literal(lit_value, _), - ) => { - let Ok(expr_type) = info.get_data_type(&expr) else { - return internal_err!("Can't get the data type of the expr {:?}", &expr); - }; - - if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op) - { - return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { - left: expr, - op, - right: Box::new(lit(value)), - }))); - }; - - // if the lit_value can be casted to the type of internal_left_expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else { - return internal_err!( - "Can't cast the literal expr {:?} to type {}", - &lit_value, - &expr_type - ); - }; - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { - left: expr, - op, - right: Box::new(lit(value)), - }))) - } - _ => internal_err!("Expect cast expr and literal"), - } -} - -pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( - info: &SimplifyContext, - expr: &Expr, - op: Operator, - literal: &Expr, -) -> bool { - match (expr, literal) { - ( - Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - }), - Expr::Literal(lit_val, _), - ) => { - let Ok(expr_type) = info.get_data_type(left_expr) else { - return false; - }; - - let Ok(lit_type) = info.get_data_type(literal) else { - return false; - }; - - if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() { - return true; - } - - try_cast_literal_to_type(lit_val, &expr_type).is_some() - && is_supported_type(&expr_type) - && is_supported_type(&lit_type) - } - _ => false, - } -} - -pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist( - info: &SimplifyContext, - expr: &Expr, - list: &[Expr], -) -> bool { - let (Expr::TryCast(TryCast { - expr: left_expr, .. - }) - | Expr::Cast(Cast { - expr: left_expr, .. - })) = expr - else { - return false; - }; - - let Ok(expr_type) = info.get_data_type(left_expr) else { - return false; - }; - - if !is_supported_type(&expr_type) { - return false; - } - - for right in list { - let Ok(right_type) = info.get_data_type(right) else { - return false; - }; - - if !is_supported_type(&right_type) { - return false; - } - - match right { - Expr::Literal(lit_val, _) - if try_cast_literal_to_type(lit_val, &expr_type).is_some() => {} - _ => return false, - } - } - - true -} - -///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./ -/// -/// Specifically, rewrites -/// ```sql -/// cast(col) -/// ``` -/// -/// To -/// -/// ```sql -/// col cast() -/// col -/// ``` -fn cast_literal_to_type_with_op( - lit_value: &ScalarValue, - target_type: &DataType, - op: Operator, -) -> Option { - match (op, lit_value) { - ( - Operator::Eq | Operator::NotEq, - ScalarValue::Utf8(Some(_)) - | ScalarValue::Utf8View(Some(_)) - | ScalarValue::LargeUtf8(Some(_)), - ) => { - // Only try for integer types (TODO can we do this for other types - // like timestamps)? - use DataType::*; - if matches!( - target_type, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 - ) { - let casted = lit_value.cast_to(target_type).ok()?; - let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?; - if lit_value != &round_tripped { - return None; - } - Some(casted) - } else { - None - } - } - _ => None, - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::HashMap; - use std::sync::Arc; - - use crate::simplify_expressions::ExprSimplifier; - use arrow::datatypes::{Field, TimeUnit}; - use datafusion_common::{DFSchema, DFSchemaRef}; - use datafusion_expr::simplify::SimplifyContext; - use datafusion_expr::{cast, col, in_list, try_cast}; - - #[test] - fn test_not_unwrap_cast_comparison() { - let schema = expr_test_schema(); - // cast(INT32(c1), INT64) > INT64(c2) - let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2")); - assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2); - - // INT32(c1) < INT32(16), the type is same - let expr_lt = col("c1").lt(lit(16i32)); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64)); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // cast(c1, UTF8) < '123', only eq/not_eq should be optimized - let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123")); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not - // be casted - let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("0123")); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // cast(c1, UTF8) = 'not a number', should not be able to cast to column type - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number")); - assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); - - // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will - // not be optimized to integer comparison - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999")); - assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); - } - - #[test] - fn test_unwrap_cast_comparison() { - let schema = expr_test_schema(); - // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16)) - // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)); - let expected = col("c1").lt(lit(16i32)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64)); - let expected = col("c1").lt(lit(16i32)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - - // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16) - let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32)); - let expected = col("c2").eq(lit(16i64)); - assert_eq!(optimize_test(c2_eq_lit, &schema), expected); - - // cast(c1, INT64) < INT64(NULL) => NULL - let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64()); - let expected = null_bool(); - assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); - - // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL) - let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); - let expected = null_bool(); - assert_eq!(optimize_test(lit_lt_lit, &schema), expected); - - // cast(c1, UTF8) = '123' => c1 = 123 - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123")); - let expected = col("c1").eq(lit(123i32)); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(c1, UTF8) != '123' => c1 != 123 - let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123")); - let expected = col("c1").not_eq(lit(123i32)); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(c1, UTF8) = NULL => NULL - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None))); - let expected = null_bool(); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - - #[test] - fn test_unwrap_cast_comparison_unsigned() { - // "cast(c6, UINT64) = 0u64 => c6 = 0u32 - let schema = expr_test_schema(); - let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64)); - let expected = col("c6").eq(lit(0u32)); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(c6, UTF8) = "123" => c6 = 123 - let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123")); - let expected = col("c6").eq(lit(123u32)); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(c6, UTF8) != "123" => c6 != 123 - let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123")); - let expected = col("c6").not_eq(lit(123u32)); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - - #[test] - fn test_unwrap_cast_comparison_string() { - let schema = expr_test_schema(); - let dict = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("value")), - ); - - // cast(str1 as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = Utf8('value1') - let expr_input = cast(col("str1"), dict.data_type()).eq(lit(dict.clone())); - let expected = col("str1").eq(lit("value")); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary') - let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value")); - let expected = col("tag").eq(lit(dict.clone())); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // Verify reversed argument order - // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 - let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); - let expected = col("str1").eq(lit("value")); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - - #[test] - fn test_unwrap_cast_comparison_large_string() { - let schema = expr_test_schema(); - // cast(largestr as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = LargeUtf8('value1') - let dict = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))), - ); - let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict)); - let expected = - col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned())))); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - - #[test] - fn test_not_unwrap_cast_with_decimal_comparison() { - let schema = expr_test_schema(); - // integer to decimal: value is out of the bounds of the decimal - // cast(c3, INT64) = INT64(100000000000000000) - let expr_eq = cast(col("c3"), DataType::Int64).eq(lit(100000000000000000i64)); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); - - // cast(c4, INT64) = INT64(1000) will overflow the i128 - let expr_eq = cast(col("c4"), DataType::Int64).eq(lit(1000i64)); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); - - // decimal to decimal: value will lose the scale when convert to the target data type - // c3 = DECIMAL(12340,20,4) - let expr_eq = - cast(col("c3"), DataType::Decimal128(20, 4)).eq(lit_decimal(12340, 20, 4)); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); - - // decimal to integer - // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type - let expr_eq = - cast(col("c1"), DataType::Decimal128(10, 1)).eq(lit_decimal(123, 10, 1)); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); - - // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type - let expr_eq = - cast(col("c1"), DataType::Decimal128(10, 2)).eq(lit_decimal(1230, 10, 2)); - assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); - } - - #[test] - fn test_unwrap_cast_with_decimal_lit_comparison() { - let schema = expr_test_schema(); - // integer to decimal - // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); - let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64)); - let expected = col("c3").lt(lit_decimal(1600, 18, 2)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - - // c3 < INT64(NULL) - let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64()); - let expected = null_bool(); - assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); - - // decimal to decimal - // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) - let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0)); - let expected = col("c3").lt(lit_decimal(12300, 18, 2)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - - // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) - let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3)); - let expected = col("c3").lt(lit_decimal(123, 18, 2)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - - // decimal to integer - // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) - let expr_lt = - cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2)); - let expected = col("c1").lt(lit(123i32)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - } - - #[test] - fn test_not_unwrap_list_cast_lit_comparison() { - let schema = expr_test_schema(); - // internal left type is not supported - // FLOAT32(C5) in ... - let expr_lt = - cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12)) - let expr_lt = cast(col("c1"), DataType::Float32) - .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // INT32(C1) in (INT64(99999999999), INT64(12)) - let expr_lt = cast(col("c1"), DataType::Int64) - .in_list(vec![lit(12i32), lit(99999999999i64)], false); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - - // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) - let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list( - vec![ - lit_decimal(12, 12, 3), - lit_decimal(12, 12, 3), - lit_decimal(128, 12, 3), - ], - false, - ); - assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - } - - #[test] - fn test_unwrap_list_cast_comparison() { - let schema = expr_test_schema(); - // INT32(C1) IN (INT32(12),INT64(23),INT64(34),INT64(56),INT64(78)) -> - // INT32(C1) IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78)) - let expr_lt = cast(col("c1"), DataType::Int64).in_list( - vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)], - false, - ); - let expected = col("c1").in_list( - vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)], - false, - ); - assert_eq!(optimize_test(expr_lt, &schema), expected); - // INT32(C2) IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78)) -> - // INT32(C2) IN (INT32(NULL),INT32(24),INT32(34),INT32(56),INT32(78)) - let expr_lt = cast(col("c2"), DataType::Int32).in_list( - vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)], - false, - ); - let expected = col("c2").in_list( - vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)], - false, - ); - - assert_eq!(optimize_test(expr_lt, &schema), expected); - - // decimal test case - // c3 is decimal(18,2) - let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list( - vec![ - lit_decimal(12000, 19, 3), - lit_decimal(24000, 19, 3), - lit_decimal(1280, 19, 3), - lit_decimal(1240, 19, 3), - ], - false, - ); - let expected = col("c3").in_list( - vec![ - lit_decimal(1200, 18, 2), - lit_decimal(2400, 18, 2), - lit_decimal(128, 18, 2), - lit_decimal(124, 18, 2), - ], - false, - ); - assert_eq!(optimize_test(expr_lt, &schema), expected); - - // cast(INT32(12), INT64) IN (.....) => - // INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16)) - // => true - let expr_lt = cast(lit(12i32), DataType::Int64).in_list( - vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)], - false, - ); - let expected = lit(true); - assert_eq!(optimize_test(expr_lt, &schema), expected); - } - - #[test] - fn aliased() { - let schema = expr_test_schema(); - // c1 < INT64(16) -> c1 < cast(INT32(16)) - // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x"); - let expected = col("c1").lt(lit(16i32)).alias("x"); - assert_eq!(optimize_test(expr_lt, &schema), expected); - } - - #[test] - fn nested() { - let schema = expr_test_schema(); - // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) - // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast( - col("c1"), - DataType::Int64, - ) - .gt(lit(32i64))); - let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32))); - assert_eq!(optimize_test(expr_lt, &schema), expected); - } - - #[test] - fn test_not_support_data_type() { - // "c6 > 0" will be cast to `cast(c6 as float) > 0 - // but the type of c6 is uint32 - // the rewriter will not throw error and just return the original expr - let schema = expr_test_schema(); - let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64)); - assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); - - // inlist for unsupported data type - let expr_input = in_list( - cast(col("c6"), DataType::Float64), - // need more literals to avoid rewriting to binary expr - vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)], - false, - ); - assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); - } - - #[test] - /// Basic integration test for unwrapping casts with different timezones - fn test_unwrap_cast_with_timestamp_nanos() { - let schema = expr_test_schema(); - // cast(ts_nano as Timestamp(Nanosecond, UTC)) < 1666612093000000000::Timestamp(Nanosecond, Utc)) - let expr_lt = try_cast(col("ts_nano_none"), timestamp_nano_utc_type()) - .lt(lit_timestamp_nano_utc(1666612093000000000)); - let expected = - col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - } - - fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - let simplifier = ExprSimplifier::new( - SimplifyContext::builder() - .with_schema(Arc::clone(schema)) - .build(), - ); - - simplifier.simplify(expr).unwrap() - } - - fn expr_test_schema() -> DFSchemaRef { - Arc::new( - DFSchema::from_unqualified_fields( - vec![ - Field::new("c1", DataType::Int32, false), - Field::new("c2", DataType::Int64, false), - Field::new("c3", DataType::Decimal128(18, 2), false), - Field::new("c4", DataType::Decimal128(38, 37), false), - Field::new("c5", DataType::Float32, false), - Field::new("c6", DataType::UInt32, false), - Field::new("ts_nano_none", timestamp_nano_none_type(), false), - Field::new("ts_nano_utf", timestamp_nano_utc_type(), false), - Field::new("str1", DataType::Utf8, false), - Field::new("largestr", DataType::LargeUtf8, false), - Field::new("tag", dictionary_tag_type(), false), - ] - .into(), - HashMap::new(), - ) - .unwrap(), - ) - } - - fn null_bool() -> Expr { - lit(ScalarValue::Boolean(None)) - } - - fn null_i8() -> Expr { - lit(ScalarValue::Int8(None)) - } - - fn null_i32() -> Expr { - lit(ScalarValue::Int32(None)) - } - - fn null_i64() -> Expr { - lit(ScalarValue::Int64(None)) - } - - fn lit_decimal(value: i128, precision: u8, scale: i8) -> Expr { - lit(ScalarValue::Decimal128(Some(value), precision, scale)) - } - - fn lit_timestamp_nano_none(ts: i64) -> Expr { - lit(ScalarValue::TimestampNanosecond(Some(ts), None)) - } - - fn lit_timestamp_nano_utc(ts: i64) -> Expr { - let utc = Some("+0:00".into()); - lit(ScalarValue::TimestampNanosecond(Some(ts), utc)) - } - - fn timestamp_nano_none_type() -> DataType { - DataType::Timestamp(TimeUnit::Nanosecond, None) - } - - // this is the type that now() returns - fn timestamp_nano_utc_type() -> DataType { - let utc = Some("+0:00".into()); - DataType::Timestamp(TimeUnit::Nanosecond, utc) - } - - // a dictionary type for storing string tags - fn dictionary_tag_type() -> DataType { - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) - } -} diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 6fad39dc33d9f..9528ccd17b03d 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -795,7 +795,7 @@ fn extension_node_does_not_block_projection_pruning() -> Result<()> { Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC")) Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.ts - TableScan: t projection=[a, ts], partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))] + TableScan: t projection=[a, ts], partial_filters=[t.ts >= TimestampNanosecond(1001000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))] "#, ); diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index 4f4dfb2c20a81..f864f2d545bbb 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -36,10 +36,12 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; use datafusion_common::{Result, ScalarValue, tree_node::Transformed}; use datafusion_expr::Operator; -use datafusion_expr_common::casts::try_cast_literal_to_type; +use datafusion_expr_common::casts::{CastPredicatePreimage, cast_predicate_preimage}; use crate::PhysicalExpr; -use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit}; +use crate::expressions::{ + BinaryExpr, CastExpr, Literal, TryCastExpr, is_not_null, is_null, lit, +}; /// Attempts to unwrap casts in comparison expressions. pub(crate) fn unwrap_cast_in_comparison( @@ -60,12 +62,13 @@ fn try_unwrap_cast_binary( schema: &Schema, ) -> Result>> { // Case 1: cast(left_expr) op literal - if let (Some((inner_expr, _cast_type)), Some(literal)) = ( + if let (Some((inner_expr, cast_type)), Some(literal)) = ( extract_cast_info(binary.left()), binary.right().downcast_ref::(), ) && binary.op().supports_propagation() && let Some(unwrapped) = try_unwrap_cast_comparison( Arc::clone(inner_expr), + cast_type, literal.value(), *binary.op(), schema, @@ -75,7 +78,7 @@ fn try_unwrap_cast_binary( } // Case 2: literal op cast(right_expr) - if let (Some(literal), Some((inner_expr, _cast_type))) = ( + if let (Some(literal), Some((inner_expr, cast_type))) = ( binary.left().downcast_ref::(), extract_cast_info(binary.right()), ) { @@ -84,6 +87,7 @@ fn try_unwrap_cast_binary( && binary.op().supports_propagation() && let Some(unwrapped) = try_unwrap_cast_comparison( Arc::clone(inner_expr), + cast_type, literal.value(), swapped_op, schema, @@ -117,6 +121,7 @@ fn extract_cast_info( /// Try to unwrap a cast in comparison by moving the cast to the literal fn try_unwrap_cast_comparison( inner_expr: Arc, + cast_type: &DataType, literal_value: &ScalarValue, op: Operator, schema: &Schema, @@ -124,21 +129,79 @@ fn try_unwrap_cast_comparison( // Get the data type of the inner expression let inner_type = inner_expr.data_type(schema)?; - // Try to cast the literal to the inner expression's type - if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) { - let literal_expr = lit(casted_literal); - let binary_expr = BinaryExpr::new(inner_expr, op, literal_expr); - return Ok(Some(Arc::new(binary_expr))); + match cast_predicate_preimage(&inner_type, cast_type, op, literal_value)? { + Some(CastPredicatePreimage::Exact(casted_literal)) => { + let literal_expr = lit(casted_literal); + let binary_expr = BinaryExpr::new(inner_expr, op, literal_expr); + Ok(Some(Arc::new(binary_expr))) + } + Some(CastPredicatePreimage::Range(interval)) => { + rewrite_with_preimage(interval, op, inner_expr).map(Some) + } + None => Ok(None), } +} - Ok(None) +fn rewrite_with_preimage( + interval: datafusion_expr_common::interval_arithmetic::Interval, + op: Operator, + expr: Arc, +) -> Result> { + let (lower, upper) = interval.into_bounds(); + let (lower, upper) = (lit(lower), lit(upper)); + + let rewritten_expr = match op { + Operator::Lt => binary(Arc::clone(&expr), Operator::Lt, lower), + Operator::GtEq => binary(Arc::clone(&expr), Operator::GtEq, lower), + Operator::Gt => binary(Arc::clone(&expr), Operator::GtEq, upper), + Operator::LtEq => binary(Arc::clone(&expr), Operator::Lt, upper), + Operator::Eq => binary( + binary(Arc::clone(&expr), Operator::GtEq, lower), + Operator::And, + binary(expr, Operator::Lt, upper), + ), + Operator::NotEq => binary( + binary(Arc::clone(&expr), Operator::Lt, lower), + Operator::Or, + binary(expr, Operator::GtEq, upper), + ), + Operator::IsNotDistinctFrom => binary( + binary( + is_not_null(Arc::clone(&expr))?, + Operator::And, + binary(Arc::clone(&expr), Operator::GtEq, lower), + ), + Operator::And, + binary(expr, Operator::Lt, upper), + ), + Operator::IsDistinctFrom => binary( + binary( + binary(Arc::clone(&expr), Operator::Lt, lower), + Operator::Or, + binary(Arc::clone(&expr), Operator::GtEq, upper), + ), + Operator::Or, + is_null(expr)?, + ), + _ => unreachable!("preimage only supports comparison operators"), + }; + + Ok(rewritten_expr) +} + +fn binary( + left: Arc, + op: Operator, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, op, right)) } #[cfg(test)] mod tests { use super::*; use crate::expressions::col; - use arrow::datatypes::Field; + use arrow::datatypes::{Field, TimeUnit}; use datafusion_common::tree_node::TreeNode; /// Check if an expression is a cast expression @@ -548,6 +611,242 @@ mod tests { assert!(!result.transformed); } + #[test] + fn test_timestamp_precision_narrowing_range_preimage_gt() { + let schema = Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )]); + + let column_expr = col("ts", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Timestamp(TimeUnit::Millisecond, None), + None, + )); + let literal_expr = lit(ScalarValue::TimestampMillisecond(Some(1000), None)); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + assert!(result.transformed); + let optimized_binary = result.data.downcast_ref::().unwrap(); + assert_eq!(*optimized_binary.op(), Operator::GtEq); + assert!(!is_cast_expr(optimized_binary.left())); + let right_literal = optimized_binary.right().downcast_ref::().unwrap(); + assert_eq!( + right_literal.value(), + &ScalarValue::TimestampNanosecond(Some(1_001_000_000), None) + ); + } + + #[test] + fn test_timestamp_precision_narrowing_range_preimage_eq() { + let schema = Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )]); + + let column_expr = col("ts", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Timestamp(TimeUnit::Millisecond, None), + None, + )); + let literal_expr = lit(ScalarValue::TimestampMillisecond(Some(-1), None)); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr)); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + assert!(result.transformed); + let and_binary = result.data.downcast_ref::().unwrap(); + assert_eq!(*and_binary.op(), Operator::And); + + let lower_binary = and_binary.left().downcast_ref::().unwrap(); + assert_eq!(*lower_binary.op(), Operator::GtEq); + let lower_literal = lower_binary.right().downcast_ref::().unwrap(); + assert_eq!( + lower_literal.value(), + &ScalarValue::TimestampNanosecond(Some(-1_999_999), None) + ); + + let upper_binary = and_binary.right().downcast_ref::().unwrap(); + assert_eq!(*upper_binary.op(), Operator::Lt); + let upper_literal = upper_binary.right().downcast_ref::().unwrap(); + assert_eq!( + upper_literal.value(), + &ScalarValue::TimestampNanosecond(Some(-999_999), None) + ); + } + + #[test] + fn test_timestamp_widening_exactness() { + let schema = Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + )]); + + let column_expr = col("ts", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + )); + let literal_expr = lit(ScalarValue::TimestampNanosecond(Some(123_000_000), None)); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::GtEq, literal_expr)); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + assert!(result.transformed); + let optimized_binary = result.data.downcast_ref::().unwrap(); + assert_eq!(*optimized_binary.op(), Operator::GtEq); + assert!(!is_cast_expr(optimized_binary.left())); + let right_literal = optimized_binary.right().downcast_ref::().unwrap(); + assert_eq!( + right_literal.value(), + &ScalarValue::TimestampMillisecond(Some(123), None) + ); + + let column_expr = col("ts", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Timestamp(TimeUnit::Nanosecond, None), + None, + )); + let literal_expr = lit(ScalarValue::TimestampNanosecond(Some(123_456_789), None)); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::GtEq, literal_expr)); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + assert!(!result.transformed); + } + + #[test] + fn test_timestamp_precision_narrowing_range_preimage_is_distinct_from() { + let schema = Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )]); + + let column_expr = col("ts", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Timestamp(TimeUnit::Millisecond, None), + None, + )); + let literal_expr = lit(ScalarValue::TimestampMillisecond(Some(1000), None)); + let binary_expr = Arc::new(BinaryExpr::new( + cast_expr, + Operator::IsDistinctFrom, + literal_expr, + )); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + assert!(result.transformed); + + // Expected: OR( OR(expr < lower, expr >= upper), IS NULL(expr) ) + let outer_or = result.data.downcast_ref::().unwrap(); + assert_eq!(*outer_or.op(), Operator::Or); + + // Right side of outer OR → IS NULL + assert!( + outer_or + .right() + .downcast_ref::() + .is_some() + ); + + // Left side of outer OR → OR(expr < lower, expr >= upper) + let inner_or = outer_or.left().downcast_ref::().unwrap(); + assert_eq!(*inner_or.op(), Operator::Or); + + // Left-left: expr < lower + let lt_binary = inner_or.left().downcast_ref::().unwrap(); + assert_eq!(*lt_binary.op(), Operator::Lt); + let lt_literal = lt_binary.right().downcast_ref::().unwrap(); + assert_eq!( + lt_literal.value(), + &ScalarValue::TimestampNanosecond(Some(1_000_000_000), None) + ); + + // Left-right: expr >= upper + let gte_binary = inner_or.right().downcast_ref::().unwrap(); + assert_eq!(*gte_binary.op(), Operator::GtEq); + let gte_literal = gte_binary.right().downcast_ref::().unwrap(); + assert_eq!( + gte_literal.value(), + &ScalarValue::TimestampNanosecond(Some(1_001_000_000), None) + ); + } + + #[test] + fn test_timestamp_precision_narrowing_range_preimage_is_not_distinct_from() { + let schema = Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )]); + + let column_expr = col("ts", &schema).unwrap(); + let cast_expr = Arc::new(CastExpr::new( + column_expr, + DataType::Timestamp(TimeUnit::Millisecond, None), + None, + )); + let literal_expr = lit(ScalarValue::TimestampMillisecond(Some(1000), None)); + let binary_expr = Arc::new(BinaryExpr::new( + cast_expr, + Operator::IsNotDistinctFrom, + literal_expr, + )); + + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + assert!(result.transformed); + + // Expected: AND( AND(IS NOT NULL(expr), expr >= lower), expr < upper ) + let outer_and = result.data.downcast_ref::().unwrap(); + assert_eq!(*outer_and.op(), Operator::And); + + // Right side of outer AND → expr < upper + let upper_binary = outer_and.right().downcast_ref::().unwrap(); + assert_eq!(*upper_binary.op(), Operator::Lt); + let upper_literal = upper_binary.right().downcast_ref::().unwrap(); + assert_eq!( + upper_literal.value(), + &ScalarValue::TimestampNanosecond(Some(1_001_000_000), None) + ); + + // Left side of outer AND → AND(IS NOT NULL(expr), expr >= lower) + let inner_and = outer_and.left().downcast_ref::().unwrap(); + assert_eq!(*inner_and.op(), Operator::And); + + // Left-left: IS NOT NULL + assert!( + inner_and + .left() + .downcast_ref::() + .is_some() + ); + + // Left-right: expr >= lower + let gte_binary = inner_and.right().downcast_ref::().unwrap(); + assert_eq!(*gte_binary.op(), Operator::GtEq); + let gte_literal = gte_binary.right().downcast_ref::().unwrap(); + assert_eq!( + gte_literal.value(), + &ScalarValue::TimestampNanosecond(Some(1_000_000_000), None) + ); + } + #[test] fn test_complex_nested_expression() { let schema = test_schema(); diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index 58ec7a1b262c3..67b2c17f2372f 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -146,3 +146,68 @@ logical_plan physical_plan 01)ProjectionExec: expr=[column1@0 = 1 as opt1, column1@0 = 2 AND column1@0 != 2 as noopt1, column1@0 = 4 as opt2, column1@0 != 5 AND column1@0 = 5 as noopt2] 02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Cast predicate preimage rewrites for timestamp precision casts +statement ok +CREATE TABLE cast_preimage_ts AS +SELECT + arrow_cast(0, 'Timestamp(ns)') AS ts_nano, + arrow_cast(0, 'Timestamp(ms)') AS ts_milli; + +query TT +EXPLAIN SELECT * FROM cast_preimage_ts +WHERE CAST(ts_nano AS TIMESTAMP(3)) > arrow_cast(1000, 'Timestamp(ms)'); +---- +logical_plan +01)Filter: cast_preimage_ts.ts_nano >= TimestampNanosecond(1001000000, None) +02)--TableScan: cast_preimage_ts projection=[ts_nano, ts_milli] +physical_plan +01)FilterExec: ts_nano@0 >= 1001000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT * FROM cast_preimage_ts +WHERE CAST(ts_nano AS TIMESTAMP(3)) IS DISTINCT FROM arrow_cast(0, 'Timestamp(ms)'); +---- +logical_plan +01)Filter: cast_preimage_ts.ts_nano < TimestampNanosecond(-999999, None) OR cast_preimage_ts.ts_nano >= TimestampNanosecond(1000000, None) +02)--TableScan: cast_preimage_ts projection=[ts_nano, ts_milli] +physical_plan +01)FilterExec: ts_nano@0 < -999999 OR ts_nano@0 >= 1000000 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT * FROM cast_preimage_ts +WHERE CAST(ts_milli AS TIMESTAMP(9)) >= arrow_cast(123000000, 'Timestamp(ns)'); +---- +logical_plan +01)Filter: cast_preimage_ts.ts_milli >= TimestampMillisecond(123, None) +02)--TableScan: cast_preimage_ts projection=[ts_nano, ts_milli] +physical_plan +01)FilterExec: ts_milli@1 >= 123 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT * FROM cast_preimage_ts +WHERE CAST(ts_milli AS TIMESTAMP(9)) >= arrow_cast(123456789, 'Timestamp(ns)'); +---- +logical_plan +01)Filter: CAST(cast_preimage_ts.ts_milli AS Timestamp(ns)) >= TimestampNanosecond(123456789, None) +02)--TableScan: cast_preimage_ts projection=[ts_nano, ts_milli] +physical_plan +01)FilterExec: CAST(ts_milli@1 AS Timestamp(ns)) >= 123456789 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT * FROM cast_preimage_ts +WHERE CAST(ts_nano AS TIMESTAMP(3)) = arrow_cast(-1, 'Timestamp(ms)'); +---- +logical_plan +01)Filter: cast_preimage_ts.ts_nano >= TimestampNanosecond(-1999999, None) AND cast_preimage_ts.ts_nano < TimestampNanosecond(-999999, None) +02)--TableScan: cast_preimage_ts projection=[ts_nano, ts_milli] +physical_plan +01)FilterExec: ts_nano@0 >= -1999999 AND ts_nano@0 < -999999 +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +DROP TABLE cast_preimage_ts;