Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
22e580f
feat(bigframes): Add substrait-datafusion engine
TrevorBergeron Apr 3, 2026
27c2a2e
add more ops to substrait compiler
TrevorBergeron Apr 3, 2026
4cf7c93
more ops, types
TrevorBergeron Apr 6, 2026
5d49145
support more ops to substrait
TrevorBergeron May 19, 2026
f3bf064
more work
TrevorBergeron May 20, 2026
82543d6
simplify, add acero engine
TrevorBergeron Jun 5, 2026
9fd9582
ruff
TrevorBergeron Jun 5, 2026
c64d1ce
shift to engine tests, fix issues
TrevorBergeron Jun 6, 2026
943421a
more tests, casting fix
TrevorBergeron Jun 8, 2026
f228079
fix comparison op issues
TrevorBergeron Jun 8, 2026
317be0b
setup deps for new engines
TrevorBergeron Jun 9, 2026
f5f944a
use ctx.from_arrow rather than from_arrow_table
TrevorBergeron Jun 9, 2026
7ffdc36
fix boolean binops
TrevorBergeron Jun 16, 2026
d960f13
ruff
TrevorBergeron Jun 16, 2026
7214cae
add agg/window support
TrevorBergeron Jun 23, 2026
d081bec
ruff
TrevorBergeron Jun 23, 2026
f0da5a5
fixes
TrevorBergeron Jun 24, 2026
7735dc3
update protobuf constraint
TrevorBergeron Jun 24, 2026
74103fe
force newer connection lib
TrevorBergeron Jun 24, 2026
cb7881f
force newer gcf client
TrevorBergeron Jun 24, 2026
cb07111
update more gcp clients min v
TrevorBergeron Jun 24, 2026
6e2fea3
update constraints
TrevorBergeron Jun 24, 2026
01a935e
even newer clients!
TrevorBergeron Jun 24, 2026
e14c13a
update constraints
TrevorBergeron Jun 24, 2026
d40bbe4
update resource manager min version
TrevorBergeron Jun 24, 2026
f7db2e4
cleanup some cast stuff
TrevorBergeron Jun 24, 2026
2350d62
casting fixes
TrevorBergeron Jun 25, 2026
81ade4d
fix dialect passing
TrevorBergeron Jun 26, 2026
bf1f5c6
fix bool numeric cast
TrevorBergeron Jun 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import dataclasses
from typing import cast
from typing import Literal, cast

import numpy as np
import pandas as pd
Expand All @@ -36,14 +36,17 @@
# Lowering rule bound to a single binary op type: `lower` passes both operands
# through `_coerce_comparables` and rebuilds the same op over the results.
# NOTE(review): this text has lost its indentation and appears to interleave
# both sides of a diff -- confirm before treating it as runnable code.
@dataclasses.dataclass
class CoerceArgsRule(op_lowering.OpLoweringRule):
# The op class this rule is registered for (returned by the `op` property).
op_type: type[ops.BinaryOp]
# Forwarded to `_coerce_comparables` as its `dialect` keyword.
dialect: Literal["polars", "substrait"]

@property
def op(self) -> type[ops.ScalarOp]:
return self.op_type

def lower(self, expr: expression.OpExpression) -> expression.Expression:
assert isinstance(expr.op, self.op_type)
# NOTE(review): the call on the next line (no `dialect=`) looks like the
# removed side of the diff; the multi-line call after it rebinds the same
# names and is presumably the only one meant to remain -- confirm.
larg, rarg = _coerce_comparables(expr.children[0], expr.children[1])
larg, rarg = _coerce_comparables(
expr.children[0], expr.children[1], dialect=self.dialect
)
return expr.op.as_expr(larg, rarg)


Expand Down Expand Up @@ -285,14 +288,22 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
return wo_bools


# Lowering rule for `ops.AsTypeOp` that picks the cast-lowering helper from the
# configured dialect.
# NOTE(review): this text has lost its indentation and appears to interleave
# both sides of a diff -- confirm before treating it as runnable code.
@dataclasses.dataclass
class LowerAsTypeRule(op_lowering.OpLoweringRule):
# "polars" selects `_lower_cast_to_polars`; any other value is forwarded to
# `_lower_cast_to_substrait` as its `dialect` argument.
dialect: Literal["polars", "substrait-datafusion", "substrait-acero"]

@property
def op(self) -> type[ops.ScalarOp]:
return ops.AsTypeOp

def lower(self, expr: expression.OpExpression) -> expression.Expression:
assert isinstance(expr.op, ops.AsTypeOp)
# NOTE(review): the unconditional `_lower_cast` return on the next line looks
# like the removed side of the diff; if it were live, the dialect branch
# below it could never run -- confirm.
return _lower_cast(expr.op, expr.inputs[0])
if self.dialect == "polars":
return _lower_cast_to_polars(expr.op, expr.inputs[0])
else:
return _lower_cast_to_substrait(
expr.op, expr.inputs[0], dialect=self.dialect
)


def invert_bytes(byte_string):
Expand Down Expand Up @@ -392,6 +403,7 @@ def _coerce_comparables(
expr2: expression.Expression,
*,
bools_only: bool = False,
dialect: Literal["polars", "substrait"],
):
if bools_only:
if (
Expand All @@ -402,13 +414,19 @@ def _coerce_comparables(

target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type)
if expr1.output_type != target_type:
expr1 = _lower_cast(ops.AsTypeOp(target_type), expr1)
if dialect == "polars":
expr1 = _lower_cast_to_polars(ops.AsTypeOp(target_type), expr1)
elif dialect == "substrait":
expr1 = _lower_cast_to_substrait(ops.AsTypeOp(target_type), expr1)
if expr2.output_type != target_type:
expr2 = _lower_cast(ops.AsTypeOp(target_type), expr2)
if dialect == "polars":
expr2 = _lower_cast_to_polars(ops.AsTypeOp(target_type), expr2)
elif dialect == "substrait":
expr2 = _lower_cast_to_substrait(ops.AsTypeOp(target_type), expr2)
return expr1, expr2


def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
def _lower_cast_to_polars(cast_op: ops.AsTypeOp, arg: expression.Expression):
if arg.output_type == cast_op.to_type:
return arg
if (
Expand All @@ -435,8 +453,6 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
):
return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S%.6f%:::z").as_expr(arg)
if arg.output_type == dtypes.BOOL_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE:
# bool -> decimal needs two-step cast
new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)
is_true_cond = ops.eq_op.as_expr(arg, expression.const(True))
is_false_cond = ops.eq_op.as_expr(arg, expression.const(False))
return ops.CaseWhenOp().as_expr(
Expand All @@ -459,8 +475,99 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
return cast_op.as_expr(arg)


LOWER_COMPARISONS = tuple(
CoerceArgsRule(op)
def _lower_cast_to_substrait(
    cast_op: ops.AsTypeOp,
    arg: expression.Expression,
    dialect: Literal[
        "substrait-datafusion", "substrait-acero"
    ] = "substrait-datafusion",
):
    """Lower a cast of ``arg`` to ``cast_op.to_type`` for a substrait engine.

    Most casts are emitted unchanged as ``cast_op.as_expr(arg)``. The special
    cases are:

    * same source and target type: ``arg`` is returned untouched, matching
      the short-circuit in the polars cast lowering;
    * bool -> INT: a single direct cast;
    * bool -> other numeric: cast to INT first, then to the target type;
    * bool -> string: a case/when producing the literals "True" / "False";
    * datetime/timestamp -> string on "substrait-datafusion": the native
      cast result is post-processed with string replacements.

    Args:
        cast_op: The cast operation being lowered.
        arg: The expression being cast.
        dialect: Target substrait engine. Only consulted for
            datetime/timestamp -> string casts.

    Returns:
        An expression that evaluates ``arg`` converted to ``cast_op.to_type``.
    """
    source_type = arg.output_type
    target_type = cast_op.to_type

    # An identity cast needs no work; returning early also keeps bool -> bool
    # from being routed through the two-step numeric path below.
    if source_type == target_type:
        return arg

    if source_type == dtypes.BOOL_DTYPE:
        if target_type == dtypes.INT_DTYPE:
            # The intermediate type of the two-step cast *is* INT, so a
            # second INT -> INT cast on top would be redundant.
            return cast_op.as_expr(arg)
        if dtypes.is_numeric(target_type):
            # bool -> decimal/numeric needs two-step cast
            as_int = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)
            return cast_op.as_expr(as_int)
        if target_type == dtypes.STRING_DTYPE:
            # Spell the literals out explicitly rather than relying on the
            # engine's own bool formatting. Neither condition matches a null
            # input, so no branch value is selected for it.
            is_true_cond = ops.eq_op.as_expr(arg, expression.const(True))
            is_false_cond = ops.eq_op.as_expr(arg, expression.const(False))
            return ops.CaseWhenOp().as_expr(
                is_true_cond,
                expression.const("True"),
                is_false_cond,
                expression.const("False"),
            )

    if target_type == dtypes.STRING_DTYPE and dialect == "substrait-datafusion":
        # NOTE(review): the replacements suggest DataFusion's native cast
        # renders an ISO-style "T" separator and a "Z" suffix -- confirm.
        if source_type == dtypes.DATETIME_DTYPE:
            return string_ops.ReplaceStrOp(pat="T", repl=" ").as_expr(
                cast_op.as_expr(arg)
            )
        if source_type == dtypes.TIMESTAMP_DTYPE:
            replaced_t = string_ops.ReplaceStrOp(pat="T", repl=" ").as_expr(
                cast_op.as_expr(arg)
            )
            return string_ops.ReplaceStrOp(pat="Z", repl="+00").as_expr(replaced_t)

    # Everything else uses the engine's native cast. That includes TIME ->
    # string on both engines and, on Acero, datetime/timestamp -> string
    # (per the original notes, the compiler intercepts the Acero datetime
    # cast and the Acero time/timestamp cases are excluded in tests).
    return cast_op.as_expr(arg)


class SubstraitLowerEqNullsMatchRule(op_lowering.OpLoweringRule):
    """Expand ``EqNullsMatchOp`` into eq / isnull / and / or / where ops."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        return comparison_ops.EqNullsMatchOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Build the null-matching equality from simpler expressions.

        The operands are first coerced with the "substrait" dialect. The
        result is ``where(True, both_null, where(False, any_null, equal))``.
        NOTE(review): this reads as "True when both sides are null, False
        when exactly one is, else plain equality", assuming ``where_op``
        takes (value, condition, fallback) -- confirm against its definition.
        """
        assert isinstance(expr.op, comparison_ops.EqNullsMatchOp)
        left, right = _coerce_comparables(
            expr.children[0], expr.children[1], dialect="substrait"
        )

        left_is_null = ops.isnull_op.as_expr(left)
        right_is_null = ops.isnull_op.as_expr(right)

        # Inner layer: constant False under the "any null" condition,
        # otherwise the ordinary equality of the coerced operands.
        when_not_both_null = ops.where_op.as_expr(
            expression.const(False),
            ops.or_op.as_expr(left_is_null, right_is_null),
            ops.eq_op.as_expr(left, right),
        )

        # Outer layer: constant True under the "both null" condition,
        # otherwise defer to the inner layer.
        return ops.where_op.as_expr(
            expression.const(True),
            ops.and_op.as_expr(left_is_null, right_is_null),
            when_not_both_null,
        )


POLARS_LOWER_COMPARISONS = tuple(
CoerceArgsRule(op, dialect="polars")
for op in (
comparison_ops.EqOp,
comparison_ops.EqNullsMatchOp,
Expand All @@ -472,15 +579,27 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
)
)

# One operand-coercing rule per comparison op, using the "substrait" dialect.
# EqNullsMatchOp is not listed here; it has its own rule,
# SubstraitLowerEqNullsMatchRule.
SUBSTRAIT_LOWER_COMPARISONS = tuple(
    CoerceArgsRule(op_type=comparison_op, dialect="substrait")
    for comparison_op in (
        comparison_ops.EqOp,
        comparison_ops.NeOp,
        comparison_ops.LtOp,
        comparison_ops.GtOp,
        comparison_ops.LeOp,
        comparison_ops.GeOp,
    )
)

POLARS_LOWERING_RULES = (
*LOWER_COMPARISONS,
*POLARS_LOWER_COMPARISONS,
LowerAddRule(),
LowerSubRule(),
LowerMulRule(),
LowerDivRule(),
LowerFloorDivRule(),
LowerModRule(),
LowerAsTypeRule(),
LowerAsTypeRule(dialect="polars"),
LowerInvertOp(),
LowerIsinOp(),
LowerLenOp(),
Expand All @@ -493,6 +612,20 @@ def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFr
return op_lowering.lower_ops(root, rules=POLARS_LOWERING_RULES)


def lower_ops_to_substrait(
    root: bigframe_node.BigFrameNode,
    dialect: Literal[
        "substrait-datafusion", "substrait-acero"
    ] = "substrait-datafusion",
) -> bigframe_node.BigFrameNode:
    """Apply the substrait op-lowering rules to the plan rooted at ``root``.

    Args:
        root: Root node of the plan to lower.
        dialect: Target substrait engine; forwarded to the cast-lowering rule.

    Returns:
        The node produced by ``op_lowering.lower_ops`` with the substrait
        rule set.
    """
    # The comparison rules do not depend on the engine; only the cast rule
    # is parameterised by ``dialect``.
    comparison_rules = (
        SubstraitLowerEqNullsMatchRule(),
        *SUBSTRAIT_LOWER_COMPARISONS,
    )
    cast_rule = LowerAsTypeRule(dialect=dialect)
    return op_lowering.lower_ops(root, rules=(*comparison_rules, cast_rule))


def _numeric_to_timedelta(expr: expression.Expression) -> expression.Expression:
"""rounding logic used for emulating timedelta ops"""
rounded_value = ops.where_op.as_expr(
Expand Down
15 changes: 12 additions & 3 deletions packages/bigframes/bigframes/core/compile/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import bigframes.operations.numeric_ops as num_ops
import bigframes.operations.string_ops as string_ops
from bigframes.core import agg_expressions, identifiers, nodes, ordering, window_spec
from bigframes.core.compile.polars import lowering

polars_installed = True
if TYPE_CHECKING:
Expand Down Expand Up @@ -672,6 +671,8 @@ def compile(self, plan: nodes.BigFrameNode) -> pl.LazyFrame:
node = nodes.bottom_up(node, bigframes.core.rewrite.rewrite_slice)
node = bigframes.core.rewrite.pull_out_window_order(node)
node = bigframes.core.rewrite.schema_binding.bind_schema_to_tree(node)
from bigframes.core.compile import lowering

node = lowering.lower_ops_to_polars(node)
return self.compile_node(node)

Expand Down Expand Up @@ -763,7 +764,11 @@ def compile_join(self, node: nodes.JoinNode):
left_on = []
right_on = []
for left_ex, right_ex in node.conditions:
left_ex, right_ex = lowering._coerce_comparables(left_ex, right_ex)
from bigframes.core.compile import lowering

left_ex, right_ex = lowering._coerce_comparables(
left_ex, right_ex, dialect="polars"
)
left_on.append(self.expr_compiler.compile_expression(left_ex))
right_on.append(self.expr_compiler.compile_expression(right_ex))

Expand All @@ -782,7 +787,11 @@ def compile_isin(self, node: nodes.InNode):
right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql))

right_col = ex.ResolvedDerefOp.from_field(node.right_child.fields[0])
left_ex, right_ex = lowering._coerce_comparables(node.left_col, right_col)
from bigframes.core.compile import lowering

left_ex, right_ex = lowering._coerce_comparables(
node.left_col, right_col, dialect="polars"
)

left_pl_ex = self.expr_compiler.compile_expression(left_ex)
right_pl_ex = self.expr_compiler.compile_expression(right_ex)
Expand Down
19 changes: 19 additions & 0 deletions packages/bigframes/bigframes/core/compile/substrait/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from .compiler import SubstraitCompiler

__all__ = ["SubstraitCompiler"]
Loading
Loading