Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions vortex-layout/src/layouts/row_idx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,26 +64,27 @@ impl RowIdxLayoutReader {
}
}

fn partition_expr(&self, expr: &Expression) -> Partitioning {
fn partition_expr(&self, expr: &Expression) -> VortexResult<Partitioning> {
let key = ExactExpr(expr.clone());

// Check cache first with read-only lock.
if let Some(entry) = self.partition_cache.get(&key)
&& let Some(partitioning) = entry.value().get()
{
return partitioning.clone();
return Ok(partitioning.clone());
}

let cell = self
.partition_cache
let result = self.compute_partitioning(expr)?;

self.partition_cache
.entry(key)
.or_insert_with(|| Arc::new(OnceLock::new()))
.clone();
.get_or_init(|| result.clone());

cell.get_or_init(|| self.compute_partitioning(expr)).clone()
Ok(result)
}

fn compute_partitioning(&self, expr: &Expression) -> Partitioning {
fn compute_partitioning(&self, expr: &Expression) -> VortexResult<Partitioning> {
// Partition the expression into row idx and child expressions.
let mut partitioned = partition(expr.clone(), self.dtype(), |expr| {
if expr.is::<RowIdx>() {
Expand All @@ -93,17 +94,16 @@ impl RowIdxLayoutReader {
} else {
vec![]
}
})
.vortex_expect("We should not fail to partition expression over struct fields");
})?;

// If there's only a single partition, we can directly return the expression.
if partitioned.partitions.len() == 1 {
return match &partitioned.partition_annotations[0] {
return Ok(match &partitioned.partition_annotations[0] {
Partition::RowIdx => {
Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root()))
}
Partition::Child => Partitioning::Child(expr.clone()),
};
});
}

// Replace the row_idx expression with the root expression in the row_idx partition.
Expand All @@ -113,7 +113,7 @@ impl RowIdxLayoutReader {
.map(|p| replace(p, &row_idx(), root()))
.collect();

Partitioning::Partitioned(Arc::new(partitioned))
Ok(Partitioning::Partitioned(Arc::new(partitioned)))
}
}

Expand Down Expand Up @@ -182,7 +182,7 @@ impl LayoutReader for RowIdxLayoutReader {
expr: &Expression,
mask: Mask,
) -> VortexResult<MaskFuture> {
Ok(match &self.partition_expr(expr) {
Ok(match &self.partition_expr(expr)? {
Partitioning::RowIdx(expr) => row_idx_mask_future(
self.row_offset,
row_range,
Expand All @@ -201,7 +201,7 @@ impl LayoutReader for RowIdxLayoutReader {
expr: &Expression,
mask: MaskFuture,
) -> VortexResult<MaskFuture> {
match &self.partition_expr(expr) {
match &self.partition_expr(expr)? {
// Since this is run during pruning, we skip re-evaluating the row index expression
// during the filter evaluation.
Partitioning::RowIdx(_) => Ok(mask),
Expand Down Expand Up @@ -239,7 +239,7 @@ impl LayoutReader for RowIdxLayoutReader {
expr: &Expression,
mask: MaskFuture,
) -> VortexResult<BoxFuture<'static, VortexResult<ArrayRef>>> {
match &self.partition_expr(expr) {
match &self.partition_expr(expr)? {
Partitioning::RowIdx(expr) => Ok(row_idx_array_future(
self.row_offset,
row_range,
Expand Down
84 changes: 65 additions & 19 deletions vortex-layout/src/layouts/struct_/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,32 +152,30 @@ impl StructReader {
}

/// Utility for partitioning an expression over the fields of a struct.
fn partition_expr(&self, expr: Expression) -> Partitioned {
fn partition_expr(&self, expr: Expression) -> VortexResult<Partitioned> {
let key = ExactExpr(expr.clone());

if let Some(entry) = self.partitioned_expr_cache.get(&key)
&& let Some(partitioning) = entry.value().get()
{
return partitioning.clone();
return Ok(partitioning.clone());
}

let cell = self
.partitioned_expr_cache
let result = self.compute_partitioned_expr(expr)?;

self.partitioned_expr_cache
.entry(key)
.or_insert_with(|| Arc::new(OnceLock::new()))
.clone();
.get_or_init(|| result.clone());

cell.get_or_init(|| self.compute_partitioned_expr(expr))
.clone()
Ok(result)
}

fn compute_partitioned_expr(&self, expr: Expression) -> Partitioned {
fn compute_partitioned_expr(&self, expr: Expression) -> VortexResult<Partitioned> {
// First, we expand the root scope into the fields of the struct to ensure
// that partitioning works correctly.
let expr = replace(expr, &root(), self.expanded_root_expr.clone());
let expr = expr
.optimize_recursive(self.dtype())
.vortex_expect("We should not fail to simplify expression over struct fields");
let expr = expr.optimize_recursive(self.dtype())?;

// Partition the expression into expressions that can be evaluated over individual fields
let mut partitioned = partition(
Expand All @@ -188,16 +186,15 @@ impl StructReader {
.as_struct_fields_opt()
.vortex_expect("We know it's a struct DType"),
),
)
.vortex_expect("We should not fail to partition expression over struct fields");
)?;

if partitioned.partitions.len() == 1 {
// If there's only one partition, we step into the field scope of the original
// expression by replacing any `$.a` with `$`.
return Partitioned::Single(
return Ok(Partitioned::Single(
partitioned.partition_names[0].clone(),
replace(expr, &col(partitioned.partition_names[0].clone()), root()),
);
));
}

// We now need to process the partitioned expressions to rewrite the root scope
Expand All @@ -210,7 +207,7 @@ impl StructReader {
.map(|(e, name)| replace(e.clone(), &col(name.clone()), root()))
.collect();

Partitioned::Multi(Arc::new(partitioned))
Ok(Partitioned::Multi(Arc::new(partitioned)))
}
}

Expand Down Expand Up @@ -265,7 +262,7 @@ impl LayoutReader for StructReader {
mask: Mask,
) -> VortexResult<MaskFuture> {
// Partition the expression into expressions that can be evaluated over individual fields
match &self.partition_expr(expr.clone()) {
match &self.partition_expr(expr.clone())? {
Partitioned::Single(name, partition) => self
.field_reader(name)?
.pruning_evaluation(row_range, partition, mask)
Expand All @@ -287,7 +284,7 @@ impl LayoutReader for StructReader {
mask: MaskFuture,
) -> VortexResult<MaskFuture> {
// Partition the expression into expressions that can be evaluated over individual fields
match &self.partition_expr(expr.clone()) {
match &self.partition_expr(expr.clone())? {
Partitioned::Single(name, partition) => self
.field_reader(name)?
.filter_evaluation(row_range, partition, mask)
Expand Down Expand Up @@ -329,7 +326,7 @@ impl LayoutReader for StructReader {
.transpose()?;

// Partition the expression into expressions that can be evaluated over individual fields
let (projected, is_pack_merge) = match &self.partition_expr(expr.clone()) {
let (projected, is_pack_merge) = match &self.partition_expr(expr.clone())? {
Partitioned::Single(name, partition) => (
self.field_reader(name)?
.projection_evaluation(row_range, partition, mask_fut)
Expand Down Expand Up @@ -817,4 +814,53 @@ mod tests {

assert_eq!(result.len(), 5);
}

/// Regression test for https://github.com/vortex-data/vortex/issues/7808
///
/// Filtering with an expression whose operand DTypes disagree with the written
/// schema (here: a `u8` column compared against an `i32` literal) has to surface
/// as an `Err` from `filter_evaluation` rather than as a panic.
#[test]
fn test_struct_filter_dtype_mismatch_returns_error() {
    let array_ctx = ArrayContext::empty();
    let segments = Arc::new(TestSegments::default());
    let (seq_ptr, seq_eof) = SequenceId::root().split();
    let table_strategy = TableStrategy::new(
        Arc::new(FlatLayoutStrategy::default()),
        Arc::new(FlatLayoutStrategy::default()),
    );

    // Two u8 columns; the filter below deliberately uses a non-u8 literal.
    let columns = [
        ("age", buffer![7u8, 2, 3].into_array()),
        ("score", buffer![4u8, 5, 6].into_array()),
    ];
    let table = StructArray::from_fields(columns.as_slice())
        .unwrap()
        .into_array();

    // The writer gets its own handle to the segments; the original handle is
    // passed to the reader further down.
    let writer_segments = Arc::<TestSegments>::clone(&segments);
    let written_layout = block_on(|handle| async move {
        let session = SESSION.clone().with_handle(handle);
        let stream = table.to_array_stream().sequenced(seq_ptr);
        table_strategy
            .write_stream(array_ctx, writer_segments, stream, seq_eof, &session)
            .await
    })
    .unwrap();

    let reader = written_layout
        .new_reader("".into(), segments, &SESSION)
        .unwrap();

    // DType mismatch: "age" is u8 but literal is i32
    let mismatched_filter = eq(col("age"), lit(67i32));

    let outcome = reader.filter_evaluation(&(0..3), &mismatched_filter, MaskFuture::new_true(3));
    assert!(outcome.is_err());
    let message = outcome.err().unwrap().to_string();
    assert!(
        message.contains("Cannot compare different DTypes"),
        "{message}"
    );
}
}
Loading