diff --git a/packages/bigframes/bigframes/core/compile/polars/lowering.py b/packages/bigframes/bigframes/core/compile/lowering.py similarity index 77% rename from packages/bigframes/bigframes/core/compile/polars/lowering.py rename to packages/bigframes/bigframes/core/compile/lowering.py index 5b3d9154b731..0b8235140c08 100644 --- a/packages/bigframes/bigframes/core/compile/polars/lowering.py +++ b/packages/bigframes/bigframes/core/compile/lowering.py @@ -13,7 +13,7 @@ # limitations under the License. import dataclasses -from typing import cast +from typing import Literal, cast import numpy as np import pandas as pd @@ -36,6 +36,7 @@ @dataclasses.dataclass class CoerceArgsRule(op_lowering.OpLoweringRule): op_type: type[ops.BinaryOp] + dialect: Literal["polars", "substrait"] @property def op(self) -> type[ops.ScalarOp]: @@ -43,7 +44,9 @@ def op(self) -> type[ops.ScalarOp]: def lower(self, expr: expression.OpExpression) -> expression.Expression: assert isinstance(expr.op, self.op_type) - larg, rarg = _coerce_comparables(expr.children[0], expr.children[1]) + larg, rarg = _coerce_comparables( + expr.children[0], expr.children[1], dialect=self.dialect + ) return expr.op.as_expr(larg, rarg) @@ -285,14 +288,22 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression: return wo_bools +@dataclasses.dataclass class LowerAsTypeRule(op_lowering.OpLoweringRule): + dialect: Literal["polars", "substrait-datafusion", "substrait-acero"] + @property def op(self) -> type[ops.ScalarOp]: return ops.AsTypeOp def lower(self, expr: expression.OpExpression) -> expression.Expression: assert isinstance(expr.op, ops.AsTypeOp) - return _lower_cast(expr.op, expr.inputs[0]) + if self.dialect == "polars": + return _lower_cast_to_polars(expr.op, expr.inputs[0]) + else: + return _lower_cast_to_substrait( + expr.op, expr.inputs[0], dialect=self.dialect + ) def invert_bytes(byte_string): @@ -392,6 +403,7 @@ def _coerce_comparables( expr2: expression.Expression, *, bools_only: bool = 
False, + dialect: Literal["polars", "substrait"], ): if bools_only: if ( @@ -402,13 +414,19 @@ def _coerce_comparables( target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type) if expr1.output_type != target_type: - expr1 = _lower_cast(ops.AsTypeOp(target_type), expr1) + if dialect == "polars": + expr1 = _lower_cast_to_polars(ops.AsTypeOp(target_type), expr1) + elif dialect == "substrait": + expr1 = _lower_cast_to_substrait(ops.AsTypeOp(target_type), expr1) if expr2.output_type != target_type: - expr2 = _lower_cast(ops.AsTypeOp(target_type), expr2) + if dialect == "polars": + expr2 = _lower_cast_to_polars(ops.AsTypeOp(target_type), expr2) + elif dialect == "substrait": + expr2 = _lower_cast_to_substrait(ops.AsTypeOp(target_type), expr2) return expr1, expr2 -def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): +def _lower_cast_to_polars(cast_op: ops.AsTypeOp, arg: expression.Expression): if arg.output_type == cast_op.to_type: return arg if ( @@ -435,8 +453,6 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): ): return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S%.6f%:::z").as_expr(arg) if arg.output_type == dtypes.BOOL_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE: - # bool -> decimal needs two-step cast - new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg) is_true_cond = ops.eq_op.as_expr(arg, expression.const(True)) is_false_cond = ops.eq_op.as_expr(arg, expression.const(False)) return ops.CaseWhenOp().as_expr( @@ -459,8 +475,99 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): return cast_op.as_expr(arg) -LOWER_COMPARISONS = tuple( - CoerceArgsRule(op) +def _lower_cast_to_substrait( + cast_op: ops.AsTypeOp, + arg: expression.Expression, + dialect: Literal[ + "substrait-datafusion", "substrait-acero" + ] = "substrait-datafusion", +): + if arg.output_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(cast_op.to_type): + # bool -> decimal/numeric needs two-step cast + new_arg = 
ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg) + return cast_op.as_expr(new_arg) + + if arg.output_type == dtypes.BOOL_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE: + is_true_cond = ops.eq_op.as_expr(arg, expression.const(True)) + is_false_cond = ops.eq_op.as_expr(arg, expression.const(False)) + return ops.CaseWhenOp().as_expr( + is_true_cond, + expression.const("True"), + is_false_cond, + expression.const("False"), + ) + + if cast_op.to_type == dtypes.STRING_DTYPE: + if arg.output_type == dtypes.DATETIME_DTYPE: + cast_expr = cast_op.as_expr(arg) + if dialect == "substrait-datafusion": + return string_ops.ReplaceStrOp(pat="T", repl=" ").as_expr(cast_expr) + else: + # Acero: let it cast natively, compiler will intercept and cast to precision_timestamp(0) + return cast_expr + + elif arg.output_type == dtypes.TIME_DTYPE: + # Both engines use native cast (Acero is excluded in test) + return cast_op.as_expr(arg) + + elif arg.output_type == dtypes.TIMESTAMP_DTYPE: + cast_expr = cast_op.as_expr(arg) + if dialect == "substrait-datafusion": + replaced_t = string_ops.ReplaceStrOp(pat="T", repl=" ").as_expr( + cast_expr + ) + return string_ops.ReplaceStrOp(pat="Z", repl="+00").as_expr(replaced_t) + else: + # Acero: native cast (excluded in test) + return cast_expr + + return cast_op.as_expr(arg) + + +class SubstraitLowerEqNullsMatchRule(op_lowering.OpLoweringRule): + @property + def op(self) -> type[ops.ScalarOp]: + return comparison_ops.EqNullsMatchOp + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, comparison_ops.EqNullsMatchOp) + arg1, arg2 = _coerce_comparables( + expr.children[0], expr.children[1], dialect="substrait" + ) + + # True constant + true_const = expression.const(True) + # False constant + false_const = expression.const(False) + + # equal = arg1 == arg2 + equal_expr = ops.eq_op.as_expr(arg1, arg2) + + # isnull1 = arg1.isnull() + isnull1_expr = ops.isnull_op.as_expr(arg1) + + # isnull2 = 
arg2.isnull() + isnull2_expr = ops.isnull_op.as_expr(arg2) + + # both_null = isnull1 & isnull2 + both_null_expr = ops.and_op.as_expr(isnull1_expr, isnull2_expr) + + # any_null = isnull1 | isnull2 + any_null_expr = ops.or_op.as_expr(isnull1_expr, isnull2_expr) + + # inner_where = where(false, any_null, equal) + inner_where_expr = ops.where_op.as_expr(false_const, any_null_expr, equal_expr) + + # outer_where = where(true, both_null, inner_where) + null_safe_eq_expr = ops.where_op.as_expr( + true_const, both_null_expr, inner_where_expr + ) + + return null_safe_eq_expr + + +POLARS_LOWER_COMPARISONS = tuple( + CoerceArgsRule(op, dialect="polars") for op in ( comparison_ops.EqOp, comparison_ops.EqNullsMatchOp, @@ -472,15 +579,27 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): ) ) +SUBSTRAIT_LOWER_COMPARISONS = tuple( + CoerceArgsRule(op, dialect="substrait") + for op in ( + comparison_ops.EqOp, + comparison_ops.NeOp, + comparison_ops.LtOp, + comparison_ops.GtOp, + comparison_ops.LeOp, + comparison_ops.GeOp, + ) +) + POLARS_LOWERING_RULES = ( - *LOWER_COMPARISONS, + *POLARS_LOWER_COMPARISONS, LowerAddRule(), LowerSubRule(), LowerMulRule(), LowerDivRule(), LowerFloorDivRule(), LowerModRule(), - LowerAsTypeRule(), + LowerAsTypeRule(dialect="polars"), LowerInvertOp(), LowerIsinOp(), LowerLenOp(), @@ -493,6 +612,20 @@ def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFr return op_lowering.lower_ops(root, rules=POLARS_LOWERING_RULES) +def lower_ops_to_substrait( + root: bigframe_node.BigFrameNode, + dialect: Literal[ + "substrait-datafusion", "substrait-acero" + ] = "substrait-datafusion", +) -> bigframe_node.BigFrameNode: + rules = ( + SubstraitLowerEqNullsMatchRule(), + *SUBSTRAIT_LOWER_COMPARISONS, + LowerAsTypeRule(dialect=dialect), + ) + return op_lowering.lower_ops(root, rules=rules) + + def _numeric_to_timedelta(expr: expression.Expression) -> expression.Expression: """rounding logic used for emulating timedelta ops""" 
rounded_value = ops.where_op.as_expr( diff --git a/packages/bigframes/bigframes/core/compile/polars/compiler.py b/packages/bigframes/bigframes/core/compile/polars/compiler.py index 2477f27b6432..9f2fe407b714 100644 --- a/packages/bigframes/bigframes/core/compile/polars/compiler.py +++ b/packages/bigframes/bigframes/core/compile/polars/compiler.py @@ -39,7 +39,6 @@ import bigframes.operations.numeric_ops as num_ops import bigframes.operations.string_ops as string_ops from bigframes.core import agg_expressions, identifiers, nodes, ordering, window_spec -from bigframes.core.compile.polars import lowering polars_installed = True if TYPE_CHECKING: @@ -672,6 +671,8 @@ def compile(self, plan: nodes.BigFrameNode) -> pl.LazyFrame: node = nodes.bottom_up(node, bigframes.core.rewrite.rewrite_slice) node = bigframes.core.rewrite.pull_out_window_order(node) node = bigframes.core.rewrite.schema_binding.bind_schema_to_tree(node) + from bigframes.core.compile import lowering + node = lowering.lower_ops_to_polars(node) return self.compile_node(node) @@ -763,7 +764,11 @@ def compile_join(self, node: nodes.JoinNode): left_on = [] right_on = [] for left_ex, right_ex in node.conditions: - left_ex, right_ex = lowering._coerce_comparables(left_ex, right_ex) + from bigframes.core.compile import lowering + + left_ex, right_ex = lowering._coerce_comparables( + left_ex, right_ex, dialect="polars" + ) left_on.append(self.expr_compiler.compile_expression(left_ex)) right_on.append(self.expr_compiler.compile_expression(right_ex)) @@ -782,7 +787,11 @@ def compile_isin(self, node: nodes.InNode): right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql)) right_col = ex.ResolvedDerefOp.from_field(node.right_child.fields[0]) - left_ex, right_ex = lowering._coerce_comparables(node.left_col, right_col) + from bigframes.core.compile import lowering + + left_ex, right_ex = lowering._coerce_comparables( + node.left_col, right_col, dialect="polars" + ) left_pl_ex = 
self.expr_compiler.compile_expression(left_ex) right_pl_ex = self.expr_compiler.compile_expression(right_ex) diff --git a/packages/bigframes/bigframes/core/compile/substrait/__init__.py b/packages/bigframes/bigframes/core/compile/substrait/__init__.py new file mode 100644 index 000000000000..13021b2fec3f --- /dev/null +++ b/packages/bigframes/bigframes/core/compile/substrait/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .compiler import SubstraitCompiler + +__all__ = ["SubstraitCompiler"] diff --git a/packages/bigframes/bigframes/core/compile/substrait/compiler.py b/packages/bigframes/bigframes/core/compile/substrait/compiler.py new file mode 100644 index 000000000000..2085e43616a6 --- /dev/null +++ b/packages/bigframes/bigframes/core/compile/substrait/compiler.py @@ -0,0 +1,1481 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from functools import singledispatchmethod +from typing import Any, Dict, Literal, Optional, Sequence + +import pandas as pd +import substrait.algebra_pb2 as algebra_pb2 +import substrait.plan_pb2 as plan_pb2 +from google.protobuf import json_format + +import bigframes.core.expression as ex +import bigframes.dtypes as dtypes +import bigframes.operations as ops +import bigframes.operations.bool_ops as bool_ops +import bigframes.operations.comparison_ops as comparison_ops +import bigframes.operations.generic_ops as generic_ops +import bigframes.operations.numeric_ops as numeric_ops +import bigframes.operations.string_ops as string_ops +import bigframes.operations.struct_ops as struct_ops +from bigframes.core import agg_expressions, bigframe_node, nodes, rewrite +from bigframes.core.compile import lowering + + +class SubstraitCompiler: + """ + Compiles BigFrameNode plans to Substrait plans (serialized protobuf bytes). + """ + + def __init__( + self, + duration_type: Literal["interval_day", "int"], + use_precision_types: bool = True, + dialect: Literal[ + "substrait-datafusion", "substrait-acero" + ] = "substrait-datafusion", + ): + self._duration_type = duration_type + self._use_precision_types = use_precision_types + self._dialect = dialect + + def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: + """ + Compiles a BigFrameNode to a Substrait Plan serialized as protobuf bytes; returns None if the plan is unsupported. 
+ """ + if not self.can_compile(plan): + return None + + # Need to bind types in before lowering + plan = rewrite.bind_schema_to_tree(plan) + plan = lowering.lower_ops_to_substrait(plan, dialect=self._dialect) + pb_rel = self._compile_node(plan) + + pb_plan = plan_pb2.Plan() + pb_plan.version.minor_number = 42 + + plan_rel = pb_plan.relations.add() + plan_rel.root.input.CopyFrom(pb_rel) + + for item in plan.schema.items: + plan_rel.root.names.extend( + self._get_substrait_names( + item.column if isinstance(item.column, str) else item.column.sql, + item.dtype, + ) + ) + + # Determine extensions dynamically based on execution engine + extensions = dict(self._EXTENSIONS) + if self._use_precision_types: + # DataFusion supports standard "replace" + extensions["replace"] = 76 + else: + # Acero expects Arrow simple extension function URN for "replace_substring" + extensions["replace_substring"] = 76 + + # Register Arrow simple extension URN at anchor 1 + arrow_uri = pb_plan.extension_urns.add() + arrow_uri.extension_urn_anchor = 1 + arrow_uri.urn = "urn:arrow:substrait_simple_extension_function" + + for name, anchor in extensions.items(): + ext = pb_plan.extensions.add() + ext.extension_function.function_anchor = anchor + ext.extension_function.name = name + if name == "replace_substring" and not self._use_precision_types: + ext.extension_function.extension_urn_reference = 1 + + return pb_plan.SerializeToString() + + def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool: + """ + Checks if the plan can be compiled to Substrait. 
+ """ + supported_nodes = ( + nodes.ReadLocalNode, + nodes.SelectionNode, + nodes.FilterNode, + nodes.ProjectionNode, + nodes.JoinNode, + nodes.AggregateNode, + nodes.WindowOpNode, + nodes.OrderByNode, + nodes.ConcatNode, + ) + return all(isinstance(n, supported_nodes) for n in plan.unique_nodes()) + + def _compile_node(self, node: bigframe_node.BigFrameNode) -> algebra_pb2.Rel: + if isinstance(node, nodes.ReadLocalNode): + return self._compile_read(node) + elif isinstance(node, nodes.SelectionNode): + return self._compile_selection(node) + elif isinstance(node, nodes.FilterNode): + return self._compile_filter(node) + elif isinstance(node, nodes.ProjectionNode): + return self._compile_projection(node) + elif isinstance(node, nodes.JoinNode): + return self._compile_join(node) + elif isinstance(node, nodes.AggregateNode): + return self._compile_aggregate(node) + elif isinstance(node, nodes.WindowOpNode): + return self._compile_window(node) + elif isinstance(node, nodes.OrderByNode): + return self._compile_orderby(node) + elif isinstance(node, nodes.ConcatNode): + return self._compile_concat(node) + else: + raise NotImplementedError( + f"Node type {type(node)} not supported in Substrait compiler yet" + ) + + def _compile_read(self, node: nodes.ReadLocalNode) -> algebra_pb2.Rel: + table_name = f"table_{id(node)}" + + rel = algebra_pb2.Rel() + read_rel = rel.read + read_rel.named_table.names.append(table_name) + + import bigframes.dtypes as dtypes + + fields = [] + types = [] + for item in node.scan_list.items: + col_dtype = node.local_data_source.schema.get_type(item.source_id) + fields.extend(self._get_substrait_names(item.id.sql, col_dtype)) + types.append(self._convert_type(col_dtype)) + + if node.offsets_col is not None: + fields.append(node.offsets_col.sql) + types.append(self._convert_type(dtypes.INT_DTYPE)) + + schema_dict = {"names": fields, "struct": {"types": types}} + json_format.ParseDict(schema_dict, read_rel.base_schema) + + return rel + + def 
_compile_selection(self, node: nodes.SelectionNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + project_rel = rel.project + project_rel.input.CopyFrom(input_rel) + + child_ids = list(node.child.ids) + num_exprs = 0 + for aliased_ref in node.input_output_pairs: + source_id = aliased_ref.ref.id + idx = child_ids.index(source_id) + expr = project_rel.expressions.add() + expr.selection.direct_reference.struct_field.field = idx + num_exprs += 1 + + project_rel.common.emit.output_mapping.extend( + [len(child_ids) + i for i in range(num_exprs)] + ) + + return rel + + def _compile_filter(self, node: nodes.FilterNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + filter_rel = rel.filter + filter_rel.input.CopyFrom(input_rel) + + condition_expr = self._compile_expression(node.predicate, node.child) + filter_rel.condition.CopyFrom(condition_expr) + + return rel + + def _compile_projection(self, node: nodes.ProjectionNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + project_rel = rel.project + project_rel.input.CopyFrom(input_rel) + + for expr, _ in node.assignments: + expr_pb = self._compile_expression(expr, node.child) + project_rel.expressions.add().CopyFrom(expr_pb) + + child_ids = list(node.child.ids) + num_exprs = len(node.assignments) + project_rel.common.emit.output_mapping.extend(range(len(child_ids) + num_exprs)) + + return rel + + def _compile_join(self, node: nodes.JoinNode) -> algebra_pb2.Rel: + left_rel = self._compile_node(node.left_child) + right_rel = self._compile_node(node.right_child) + + rel = algebra_pb2.Rel() + if node.type == "cross": + cross_rel = rel.cross + cross_rel.left.CopyFrom(left_rel) + cross_rel.right.CopyFrom(right_rel) + return rel + + join_rel = rel.join + join_rel.left.CopyFrom(left_rel) + join_rel.right.CopyFrom(right_rel) + + type_map = { + "inner": algebra_pb2.JoinRel.JOIN_TYPE_INNER, + 
"left": algebra_pb2.JoinRel.JOIN_TYPE_LEFT, + "right": algebra_pb2.JoinRel.JOIN_TYPE_RIGHT, + "outer": algebra_pb2.JoinRel.JOIN_TYPE_OUTER, + } + join_rel.type = type_map.get( + node.type, algebra_pb2.JoinRel.JOIN_TYPE_UNSPECIFIED + ) + + left_len = len(node.left_child.schema) + + eq_expressions = [] + for left_deref, right_deref in node.conditions: + left_idx = list(node.left_child.ids).index(left_deref.id) + right_idx = list(node.right_child.ids).index(right_deref.id) + left_len + + arg1 = algebra_pb2.Expression() + arg1.selection.direct_reference.struct_field.field = left_idx + + arg2 = algebra_pb2.Expression() + arg2.selection.direct_reference.struct_field.field = right_idx + + eq_expr = algebra_pb2.Expression() + eq_expr.scalar_function.function_reference = self._EXTENSIONS["equal"] + eq_expr.scalar_function.arguments.add().value.CopyFrom(arg1) + eq_expr.scalar_function.arguments.add().value.CopyFrom(arg2) + + isnull1_expr = algebra_pb2.Expression() + isnull1_expr.scalar_function.function_reference = self._EXTENSIONS[ + "is_null" + ] + isnull1_expr.scalar_function.arguments.add().value.CopyFrom(arg1) + + isnull2_expr = algebra_pb2.Expression() + isnull2_expr.scalar_function.function_reference = self._EXTENSIONS[ + "is_null" + ] + isnull2_expr.scalar_function.arguments.add().value.CopyFrom(arg2) + + both_null_expr = algebra_pb2.Expression() + both_null_expr.scalar_function.function_reference = self._EXTENSIONS["and"] + both_null_expr.scalar_function.arguments.add().value.CopyFrom(isnull1_expr) + both_null_expr.scalar_function.arguments.add().value.CopyFrom(isnull2_expr) + + null_safe_eq = algebra_pb2.Expression() + null_safe_eq.scalar_function.function_reference = self._EXTENSIONS["or"] + null_safe_eq.scalar_function.arguments.add().value.CopyFrom(eq_expr) + null_safe_eq.scalar_function.arguments.add().value.CopyFrom(both_null_expr) + + eq_expressions.append(null_safe_eq) + + if len(eq_expressions) > 1: + expr = eq_expressions[0] + for e in eq_expressions[1:]: 
+ and_expr = algebra_pb2.Expression() + and_expr.scalar_function.function_reference = self._EXTENSIONS["and"] + and_expr.scalar_function.arguments.add().value.CopyFrom(expr) + and_expr.scalar_function.arguments.add().value.CopyFrom(e) + expr = and_expr + elif len(eq_expressions) == 1: + expr = eq_expressions[0] + else: + expr = algebra_pb2.Expression() + expr.literal.boolean = True + + join_rel.expression.CopyFrom(expr) + + return rel + + def _compile_bound( + self, + val: Optional[int], + bound_msg: algebra_pb2.Expression.WindowFunction.Bound, + ): + if val is None: + bound_msg.unbounded.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.Unbounded() + ) + elif val == 0: + bound_msg.current_row.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.CurrentRow() + ) + elif val < 0: + bound_msg.preceding.offset = -val + else: + bound_msg.following.offset = val + + def _compile_concat(self, node: nodes.ConcatNode) -> algebra_pb2.Rel: + rel = algebra_pb2.Rel() + set_rel = rel.set + set_rel.op = algebra_pb2.SetRel.SetOp.SET_OP_UNION_ALL + + for child in node.children: + set_rel.inputs.append(self._compile_node(child)) + + return rel + + def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + import bigframes.operations.aggregations as agg_ops + + child_ids = list(node.child.ids) + + rel = algebra_pb2.Rel() + agg_rel = rel.aggregate + agg_rel.input.CopyFrom(input_rel) + + if node.by_column_ids: + grouping = agg_rel.groupings.add() + for deref in node.by_column_ids: + idx = child_ids.index(deref.id) + expr = grouping.grouping_expressions.add() + expr.selection.direct_reference.struct_field.field = idx + + for agg, _ in node.aggregations: + distinct = False + if isinstance(agg.op, agg_ops.SumOp): + func_ref = self._EXTENSIONS["sum"] + elif isinstance(agg.op, agg_ops.MaxOp): + func_ref = self._EXTENSIONS["max"] + elif isinstance(agg.op, agg_ops.MinOp): + func_ref = 
self._EXTENSIONS["min"] + elif isinstance(agg.op, agg_ops.MeanOp): + func_ref = self._EXTENSIONS["avg"] + elif isinstance(agg.op, agg_ops.CountOp): + func_ref = self._EXTENSIONS["count"] + elif isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)): + func_ref = self._EXTENSIONS["count"] + elif isinstance(agg.op, agg_ops.NuniqueOp): + func_ref = self._EXTENSIONS["count"] + distinct = True + elif isinstance(agg.op, agg_ops.StdOp): + func_ref = self._EXTENSIONS["stddev"] + elif isinstance(agg.op, agg_ops.VarOp): + func_ref = self._EXTENSIONS["var"] + elif isinstance(agg.op, agg_ops.PopVarOp): + func_ref = self._EXTENSIONS["var_pop"] + elif isinstance(agg.op, agg_ops.AnyValueOp): + func_ref = self._EXTENSIONS["min"] + elif isinstance(agg.op, agg_ops.AllOp): + func_ref = self._EXTENSIONS["bool_and"] + elif isinstance(agg.op, agg_ops.AnyOp): + func_ref = self._EXTENSIONS["bool_or"] + elif isinstance(agg.op, agg_ops.ProductOp): + func_ref = self._EXTENSIONS["product"] + elif isinstance(agg.op, agg_ops.MedianOp): + func_ref = self._EXTENSIONS["median"] + elif isinstance(agg.op, agg_ops.CovOp): + func_ref = self._EXTENSIONS["cov"] + elif isinstance(agg.op, agg_ops.CorrOp): + func_ref = self._EXTENSIONS["corr"] + else: + raise NotImplementedError( + f"Aggregation {type(agg.op)} not supported in Substrait compiler yet" + ) + + measure = agg_rel.measures.add() + measure.measure.function_reference = func_ref + measure.measure.phase = algebra_pb2.AGGREGATION_PHASE_INITIAL_TO_RESULT + + output_dtype = agg.output_type + type_dict = self._convert_type(output_dtype) + json_format.ParseDict(type_dict, measure.measure.output_type) + + if distinct or isinstance(agg.op, agg_ops.NuniqueOp): + measure.measure.invocation = ( + algebra_pb2.AggregateFunction.AGGREGATION_INVOCATION_DISTINCT + ) + + if hasattr(agg, "column_references"): + for col_id in agg.column_references: + try: + idx = child_ids.index(col_id) + field_expr = algebra_pb2.Expression() + 
field_expr.selection.direct_reference.struct_field.field = idx + + arg = measure.measure.arguments.add() + arg.value.CopyFrom(field_expr) + except ValueError: + pass + + if node.dropna and node.by_column_ids: + not_null_exprs = [] + for idx in range(len(node.by_column_ids)): + key_expr = algebra_pb2.Expression() + key_expr.selection.direct_reference.struct_field.field = idx + + not_null_op = algebra_pb2.Expression() + not_null_op.scalar_function.function_reference = self._EXTENSIONS[ + "is_not_null" + ] + json_format.ParseDict( + {"bool": {}}, not_null_op.scalar_function.output_type + ) + not_null_op.scalar_function.arguments.add().value.CopyFrom(key_expr) + not_null_exprs.append(not_null_op) + + if len(not_null_exprs) > 1: + expr = not_null_exprs[0] + for e in not_null_exprs[1:]: + and_expr = algebra_pb2.Expression() + and_expr.scalar_function.function_reference = self._EXTENSIONS[ + "and" + ] + json_format.ParseDict( + {"bool": {}}, and_expr.scalar_function.output_type + ) + and_expr.scalar_function.arguments.add().value.CopyFrom(expr) + and_expr.scalar_function.arguments.add().value.CopyFrom(e) + expr = and_expr + else: + expr = not_null_exprs[0] + + filter_rel = algebra_pb2.Rel() + filter_rel.filter.input.CopyFrom(rel) + filter_rel.filter.condition.CopyFrom(expr) + rel = filter_rel + + return rel + + def _compile_window(self, node: nodes.WindowOpNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + import bigframes.dtypes as dtypes + import bigframes.operations.aggregations as agg_ops + from bigframes.core import window_spec + + child_ids = list(node.child.ids) + + rel = algebra_pb2.Rel() + proj = rel.project + proj.input.CopyFrom(input_rel) + + # 1. Project all child columns first + for idx in range(len(child_ids)): + expr = proj.expressions.add() + expr.selection.direct_reference.struct_field.field = idx + + # 2. 
Map window frame bounds (RowsWindowBounds / RangeWindowBounds / None) + bounds_type = ( + algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_UNSPECIFIED + ) + lower_bound = algebra_pb2.Expression.WindowFunction.Bound() + upper_bound = algebra_pb2.Expression.WindowFunction.Bound() + + if node.window_spec.bounds is not None: + if isinstance(node.window_spec.bounds, window_spec.RowsWindowBounds): + bounds_type = ( + algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROWS + ) + + # Row offsets reuse the shared _compile_bound mapping: None -> unbounded, + # 0 -> current row, negative -> preceding, positive -> following. + self._compile_bound(node.window_spec.bounds.start, lower_bound) + self._compile_bound(node.window_spec.bounds.end, upper_bound) + + elif isinstance(node.window_spec.bounds, window_spec.RangeWindowBounds): + bounds_type = ( + algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_RANGE + ) + range_start = node.window_spec.bounds.start + if range_start is None: + lower_bound.unbounded.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.Unbounded() + ) + elif range_start == pd.Timedelta(0): + lower_bound.current_row.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.CurrentRow() + ) + else: + raise NotImplementedError( + "Range window bounds with non-zero offsets are not supported yet" + ) + + range_end = node.window_spec.bounds.end + if range_end is None: + upper_bound.unbounded.CopyFrom( + 
algebra_pb2.Expression.WindowFunction.Bound.Unbounded() + ) + elif range_end == pd.Timedelta(0): + upper_bound.current_row.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.CurrentRow() + ) + else: + raise NotImplementedError( + "Range window bounds with non-zero offsets are not supported yet" + ) + else: + bounds_type = ( + algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROWS + ) + lower_bound.unbounded.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.Unbounded() + ) + upper_bound.unbounded.CopyFrom( + algebra_pb2.Expression.WindowFunction.Bound.Unbounded() + ) + + # 3. Project each window aggregation expression as a WindowFunction expression + for agg_idx, col_def in enumerate(node.agg_exprs): + agg = col_def.expression + assert isinstance(agg, agg_expressions.Aggregation) + distinct = False + + if isinstance(agg.op, agg_ops.SumOp): + func_ref = self._EXTENSIONS["sum"] + elif isinstance(agg.op, agg_ops.MaxOp): + func_ref = self._EXTENSIONS["max"] + elif isinstance(agg.op, agg_ops.MinOp): + func_ref = self._EXTENSIONS["min"] + elif isinstance(agg.op, agg_ops.MeanOp): + func_ref = self._EXTENSIONS["avg"] + elif isinstance(agg.op, agg_ops.CountOp): + func_ref = self._EXTENSIONS["count"] + elif isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)): + func_ref = self._EXTENSIONS["count"] + elif isinstance(agg.op, agg_ops.NuniqueOp): + func_ref = self._EXTENSIONS["count"] + distinct = True + elif isinstance(agg.op, agg_ops.StdOp): + func_ref = self._EXTENSIONS["stddev"] + elif isinstance(agg.op, agg_ops.VarOp): + func_ref = self._EXTENSIONS["var"] + elif isinstance(agg.op, agg_ops.PopVarOp): + func_ref = self._EXTENSIONS["var_pop"] + elif isinstance(agg.op, agg_ops.AnyValueOp): + func_ref = self._EXTENSIONS["min"] + elif isinstance(agg.op, agg_ops.AllOp): + func_ref = self._EXTENSIONS["bool_and"] + elif isinstance(agg.op, agg_ops.AnyOp): + func_ref = self._EXTENSIONS["bool_or"] + elif isinstance(agg.op, agg_ops.ProductOp): + func_ref = 
self._EXTENSIONS["product"] + elif isinstance(agg.op, agg_ops.MedianOp): + func_ref = self._EXTENSIONS["median"] + elif isinstance(agg.op, agg_ops.CovOp): + func_ref = self._EXTENSIONS["cov"] + elif isinstance(agg.op, agg_ops.CorrOp): + func_ref = self._EXTENSIONS["corr"] + else: + raise NotImplementedError( + f"Aggregation {type(agg.op)} not supported in window function yet" + ) + + expr = proj.expressions.add() + win_func = expr.window_function + win_func.function_reference = func_ref + win_func.phase = algebra_pb2.AGGREGATION_PHASE_INITIAL_TO_RESULT + + bound_expr = ex.bind_schema_fields(agg, node.child.field_by_id) + type_dict = self._convert_type( + dtypes.dtype_for_etype(bound_expr.output_type) + ) + json_format.ParseDict(type_dict, win_func.output_type) + + if distinct or isinstance(agg.op, agg_ops.NuniqueOp): + win_func.invocation = ( + algebra_pb2.AggregateFunction.AGGREGATION_INVOCATION_DISTINCT + ) + + # Set bounds + win_func.lower_bound.CopyFrom(lower_bound) + win_func.upper_bound.CopyFrom(upper_bound) + win_func.bounds_type = bounds_type + + # Set partitioning keys (partitions) + for partition_expr in node.window_spec.grouping_keys: + partition_pb = self._compile_expression(partition_expr, node.child) + win_func.partitions.add().CopyFrom(partition_pb) + + # Set sorting keys (sorts) + for ord_expr in node.window_spec.ordering: + sort_field = win_func.sorts.add() + sort_pb = self._compile_expression( + ord_expr.scalar_expression, node.child + ) + sort_field.expr.CopyFrom(sort_pb) + + is_asc = ord_expr.direction.is_ascending + if is_asc: + if ord_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST + else: + if ord_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_LAST + else: + sort_field.direction = 
algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_FIRST + + # Set arguments + if hasattr(agg, "column_references"): + for col_id in agg.column_references: + try: + idx = child_ids.index(col_id) + field_expr = algebra_pb2.Expression() + field_expr.selection.direct_reference.struct_field.field = idx + + arg = win_func.arguments.add() + arg.value.CopyFrom(field_expr) + except ValueError: + pass + + # Emit all columns (child columns + new window columns) + proj.common.emit.output_mapping.extend( + range(len(child_ids) + len(node.agg_exprs)) + ) + + return rel + + def _compile_orderby(self, node: nodes.OrderByNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + sort_rel = rel.sort + sort_rel.input.CopyFrom(input_rel) + + for ord_expr in node.by: + sort_field = sort_rel.sorts.add() + + # Compile the expression: + expr_pb = self._compile_expression(ord_expr.scalar_expression, node.child) + sort_field.expr.CopyFrom(expr_pb) + + # Map sort direction: + is_asc = ord_expr.direction.is_ascending + if is_asc: + if ord_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST + else: + if ord_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_FIRST + + return rel + + _EXTENSIONS = { + "add": 1, + "subtract": 2, + "multiply": 3, + "divide": 4, + "equal": 5, + "not_equal": 6, + "lt": 7, + "gt": 8, + "lte": 9, + "gte": 10, + "sum": 11, + "max": 12, + "and": 13, + "min": 14, + "avg": 15, + "count": 16, + "stddev": 17, + "var": 18, + "any_value": 19, + "all": 20, + "any": 21, + "coalesce": 22, + "or": 23, + "least": 24, + "greatest": 25, + "is_null": 26, + "is_not_null": 27, + "nullif": 28, + "sqrt": 29, + "bool_and": 30, + 
"bool_or": 31, + "product": 32, + "not": 33, + "mod": 34, + "floor": 35, + "abs": 36, + "ceil": 37, + "median": 38, + "xor": 40, + "var_pop": 53, + "row_number": 60, + "rank": 61, + "dense_rank": 62, + "first_value": 63, + "last_value": 64, + "lag": 65, + "lead": 66, + "struct": 67, + "get_field": 68, + "pow": 69, + "cov": 70, + "corr": 71, + "bitwise_and": 72, + "bitwise_or": 73, + "bitwise_xor": 74, + } + + _OP_TO_EXTENSION = { + numeric_ops.AddOp: "add", + numeric_ops.SubOp: "subtract", + numeric_ops.MulOp: "multiply", + numeric_ops.DivOp: "divide", + numeric_ops.ModOp: "mod", + numeric_ops.PowOp: "pow", + numeric_ops.UnsafePowOp: "pow", + comparison_ops.EqOp: "equal", + comparison_ops.NeOp: "not_equal", + comparison_ops.LtOp: "lt", + comparison_ops.GtOp: "gt", + comparison_ops.LeOp: "lte", + comparison_ops.GeOp: "gte", + generic_ops.FillNaOp: "coalesce", + generic_ops.CoalesceOp: "coalesce", + bool_ops.AndOp: "and", + bool_ops.OrOp: "or", + bool_ops.XorOp: "xor", + generic_ops.InvertOp: "not", + numeric_ops.AbsOp: "abs", + numeric_ops.CeilOp: "ceil", + numeric_ops.FloorOp: "floor", + generic_ops.IsNullOp: "is_null", + generic_ops.NotNullOp: "is_not_null", + } + + @singledispatchmethod + def _compile_expression( + self, expr: ex.Expression, child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + raise NotImplementedError( + f"Expression type {type(expr)} not supported in Substrait compiler yet" + ) + + @_compile_expression.register + def _compile_scalar_constant( + self, expr: ex.ScalarConstantExpression, child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + import datetime + + pb_expr = algebra_pb2.Expression() + val = expr.value + if isinstance(val, bool): + pb_expr.literal.boolean = val + elif isinstance(val, int): + pb_expr.literal.i64 = val + elif isinstance(val, float): + pb_expr.literal.fp64 = val + elif isinstance(val, str): + pb_expr.literal.string = val + elif isinstance(val, (pd.Timestamp, datetime.datetime)): + if getattr(val, "tzinfo", None) 
is not None: + epoch = pd.Timestamp("1970-01-01", tz=val.tzinfo) + us = int((val - epoch).total_seconds() * 1_000_000) + pb_expr.literal.precision_timestamp_tz.precision = 6 + pb_expr.literal.precision_timestamp_tz.value = us + else: + epoch = pd.Timestamp("1970-01-01") + us = int((val - epoch).total_seconds() * 1_000_000) + pb_expr.literal.precision_timestamp.precision = 6 + pb_expr.literal.precision_timestamp.value = us + elif isinstance(val, datetime.date): + date_epoch = datetime.date(1970, 1, 1) + days = (val - date_epoch).days + pb_expr.literal.date = days + elif pd.isna(val): # type: ignore[call-overload] + pb_expr.literal.null.varchar.length = 0 + else: + pb_expr.literal.string = str(val) + return pb_expr + + @_compile_expression.register + def _compile_deref( + self, expr: ex.DerefOp, child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + try: + idx = list(child.ids).index(expr.id) + pb_expr.selection.direct_reference.struct_field.field = idx + return pb_expr + except ValueError: + raise ValueError(f"Column {expr.id} not found in child schema") + + @_compile_expression.register + def _compile_op_expr( + self, expr: ex.OpExpression, child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + pb_expr = self._compile_op(expr.op, expr.inputs, child) + if pb_expr.HasField("scalar_function"): + if not pb_expr.scalar_function.HasField("output_type"): + output_dtype = self._get_expression_dtype(expr, child) + type_dict = self._convert_type(output_dtype) + json_format.ParseDict(type_dict, pb_expr.scalar_function.output_type) + return pb_expr + + @singledispatchmethod + def _compile_op( + self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + raise NotImplementedError( + f"Op type {type(op)} not supported in Substrait compiler yet" + ) + + @_compile_op.register(ops.AsTypeOp) + def _compile_astype( + self, + op: ops.AsTypeOp, + inputs: Sequence[ex.Expression], + child: 
nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + arg_expr = self._compile_expression(inputs[0], child) + from_dtype = self._get_expression_dtype(inputs[0], child) + + if op.to_type == dtypes.STRING_DTYPE: + if from_dtype == dtypes.DATETIME_DTYPE: + # This is only reached for Acero (dialect="substrait-acero"), + # because DataFusion was lowered to ReplaceStrOp in lowering.py! + if not self._use_precision_types: + # Cast to precision_timestamp with precision 0 first, then to string + second_ts_expr = self._compile_cast_with_type_dict( + arg_expr, {"precision_timestamp": {"precision": 0}} + ) + return self._compile_cast(second_ts_expr, dtypes.STRING_DTYPE) + + return self._compile_cast(arg_expr, op.to_type) + + @_compile_op.register(string_ops.ReplaceStrOp) + def _compile_replace_str( + self, + op: string_ops.ReplaceStrOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + arg_expr = self._compile_expression(inputs[0], child) + return self._compile_replace(arg_expr, op.pat, op.repl) + + def _compile_cast_with_type_dict( + self, input_expr: algebra_pb2.Expression, type_dict: Dict[str, Any] + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + cast = pb_expr.cast + cast.input.CopyFrom(input_expr) + json_format.ParseDict(type_dict, cast.type) + cast.failure_behavior = ( + algebra_pb2.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION + ) + return pb_expr + + def _compile_replace( + self, + str_expr: algebra_pb2.Expression, + search: str, + replacement: str, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = ( + 76 # "replace" or "replace_substring" + ) + + pb_expr.scalar_function.arguments.add().value.CopyFrom(str_expr) + + search_expr = algebra_pb2.Expression() + search_expr.literal.string = search + pb_expr.scalar_function.arguments.add().value.CopyFrom(search_expr) + + replace_expr = algebra_pb2.Expression() + replace_expr.literal.string = 
replacement + pb_expr.scalar_function.arguments.add().value.CopyFrom(replace_expr) + + return pb_expr + + @_compile_op.register(struct_ops.StructOp) + def _compile_struct_op( + self, + op: struct_ops.StructOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["struct"] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr + + @_compile_op.register(struct_ops.StructFieldOp) + def _compile_struct_field_op( + self, + op: struct_ops.StructFieldOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["get_field"] + + # Arg 0: the struct + arg_expr = self._compile_expression(inputs[0], child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + + # Arg 1: the field name as string literal + literal_expr = algebra_pb2.Expression() + literal_expr.literal.string = str(op.name_or_index) + pb_expr.scalar_function.arguments.add().value.CopyFrom(literal_expr) + return pb_expr + + def _compile_cast( + self, input_expr: algebra_pb2.Expression, target_dtype: Any + ) -> algebra_pb2.Expression: + if input_expr.HasField("literal") and input_expr.literal.HasField("null"): + pb_expr = algebra_pb2.Expression() + type_dict = self._convert_type(target_dtype) + json_format.ParseDict(type_dict, pb_expr.literal.null) + return pb_expr + + pb_expr = algebra_pb2.Expression() + cast = pb_expr.cast + cast.input.CopyFrom(input_expr) + + type_dict = self._convert_type(target_dtype) + json_format.ParseDict(type_dict, cast.type) + + # alternative: FAILURE_BEHAVIOR_RETURN_NULL not supported by acero + cast.failure_behavior = ( + algebra_pb2.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION + ) + return 
pb_expr + + def _get_expression_dtype( + self, expr: ex.Expression, child: nodes.BigFrameNode + ) -> Any: + import bigframes.dtypes as dtypes + + if isinstance(expr, ex.ScalarConstantExpression): + if expr.value is None or pd.isna(expr.value): # type: ignore[call-overload] + return None + return expr.dtype or dtypes.infer_literal_type(expr.value) + elif isinstance(expr, ex.DerefOp): + try: + idx = list(child.ids).index(expr.id) + return child.schema.items[idx].dtype + except ValueError: + pass + elif isinstance(expr, ex.OpExpression): + try: + input_dtypes = [ + self._get_expression_dtype(inp, child) for inp in expr.inputs + ] + return expr.op.output_type(*input_dtypes) + except Exception: + pass + return dtypes.STRING_DTYPE + + def _get_common_type(self, dtypes_list: Sequence[Any]) -> Any: + import bigframes.dtypes as dtypes + + non_null_dtypes = [dt for dt in dtypes_list if dt is not None] + if not non_null_dtypes: + return dtypes.STRING_DTYPE + if len(set(non_null_dtypes)) == 1: + return non_null_dtypes[0] + if any(dt == dtypes.STRING_DTYPE for dt in non_null_dtypes): + return dtypes.STRING_DTYPE + if any(dt == dtypes.FLOAT_DTYPE for dt in non_null_dtypes): + return dtypes.FLOAT_DTYPE + if any(dt == dtypes.INT_DTYPE for dt in non_null_dtypes): + return dtypes.INT_DTYPE + return dtypes.STRING_DTYPE + + @_compile_op.register(ops.CaseWhenOp) + def _compile_casewhen( + self, + op: ops.CaseWhenOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + ifthen = pb_expr.if_then + + then_dtypes = [ + self._get_expression_dtype(inputs[idx], child) + for idx in range(1, len(inputs), 2) + ] + common_dtype = self._get_common_type(then_dtypes) + + for idx in range(0, len(inputs), 2): + pred = self._compile_expression(inputs[idx], child) + val_expr = self._compile_expression(inputs[idx + 1], child) + + val_dtype = then_dtypes[idx // 2] + if val_dtype != common_dtype: + val = 
self._compile_cast(val_expr, common_dtype) + else: + val = val_expr + + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(pred) + if_clause.then.CopyFrom(val) + + type_dict = self._convert_type(common_dtype) + json_format.ParseDict(type_dict, getattr(ifthen, "else").literal.null) + return pb_expr + + @_compile_op.register(generic_ops.WhereOp) + def _compile_where( + self, + op: generic_ops.WhereOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + ifthen = pb_expr.if_then + + pred = self._compile_expression(inputs[1], child) + then_val = self._compile_expression(inputs[0], child) + else_val = self._compile_expression(inputs[2], child) + + then_dtype = self._get_expression_dtype(inputs[0], child) + else_dtype = self._get_expression_dtype(inputs[2], child) + common_dtype = self._get_common_type([then_dtype, else_dtype]) + + casted_then = self._compile_cast(then_val, common_dtype) + casted_else = self._compile_cast(else_val, common_dtype) + + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(pred) + if_clause.then.CopyFrom(casted_then) + + getattr(ifthen, "else").CopyFrom(casted_else) + return pb_expr + + @_compile_op.register(numeric_ops.DivOp) + def _compile_div_op( + self, + op: numeric_ops.DivOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + casted_arg = self._compile_cast(arg_expr, dtypes.FLOAT_DTYPE) + pb_expr.scalar_function.arguments.add().value.CopyFrom(casted_arg) + return pb_expr + + @_compile_op.register(numeric_ops.FloorDivOp) + def _compile_floor_div_op( + self, + op: numeric_ops.FloorDivOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> 
algebra_pb2.Expression: + import bigframes.dtypes as dtypes + + dividend_expr = self._compile_expression(inputs[0], child) + divisor_expr = self._compile_expression(inputs[1], child) + + # Calculate standard floor division + div_expr = algebra_pb2.Expression() + div_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + + # Cast to float for standard division + casted_dividend = self._compile_cast(dividend_expr, dtypes.FLOAT_DTYPE) + casted_divisor = self._compile_cast(divisor_expr, dtypes.FLOAT_DTYPE) + + div_expr.scalar_function.arguments.add().value.CopyFrom(casted_dividend) + div_expr.scalar_function.arguments.add().value.CopyFrom(casted_divisor) + + floor_expr = algebra_pb2.Expression() + floor_expr.scalar_function.function_reference = self._EXTENSIONS["floor"] + floor_expr.scalar_function.arguments.add().value.CopyFrom(div_expr) + + # If both operands are integer/boolean, we short-circuit division by 0 to return 0 + left_dtype = self._get_expression_dtype(inputs[0], child) + right_dtype = self._get_expression_dtype(inputs[1], child) + + is_left_int = left_dtype == dtypes.INT_DTYPE or left_dtype == dtypes.BOOL_DTYPE + is_right_int = ( + right_dtype == dtypes.INT_DTYPE or right_dtype == dtypes.BOOL_DTYPE + ) + + if is_left_int and is_right_int: + # If divisor is 0, return 0 * dividend (to propagate nulls) + zero_i64 = algebra_pb2.Expression() + zero_i64.literal.i64 = 0 + + eq_expr = algebra_pb2.Expression() + eq_expr.scalar_function.function_reference = self._EXTENSIONS["equal"] + eq_expr.scalar_function.arguments.add().value.CopyFrom(divisor_expr) + eq_expr.scalar_function.arguments.add().value.CopyFrom(zero_i64) + + zero_result = algebra_pb2.Expression() + zero_result.scalar_function.function_reference = self._EXTENSIONS[ + "multiply" + ] + zero_result.scalar_function.arguments.add().value.CopyFrom(dividend_expr) + zero_result.scalar_function.arguments.add().value.CopyFrom(zero_i64) + + pb_expr = algebra_pb2.Expression() + ifthen = 
pb_expr.if_then + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(eq_expr) + if_clause.then.CopyFrom(zero_result) + + # Else, cast float floor_expr to int64 + casted_floor = self._compile_cast(floor_expr, dtypes.INT_DTYPE) + getattr(ifthen, "else").CopyFrom(casted_floor) + return pb_expr + + return floor_expr + + @_compile_op.register(generic_ops.IsInOp) + def _compile_isin( + self, + op: generic_ops.IsInOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.singular_or_list.value.CopyFrom( + self._compile_expression(inputs[0], child) + ) + for val in op.values: + opt_expr = self._compile_expression(ex.const(val), child) + pb_expr.singular_or_list.options.add().CopyFrom(opt_expr) + return pb_expr + + @_compile_op.register(generic_ops.FillNaOp) + def _compile_fillna_op( + self, + op: ops.BinaryOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + first_expr = self._compile_expression(inputs[0], child) + first_dtype = self._get_expression_dtype(inputs[0], child) + second_expr = self._compile_expression(inputs[1], child) + second_dtype = self._get_expression_dtype(inputs[1], child) + + if first_dtype is not None and second_dtype != first_dtype: + second_expr = self._compile_cast(second_expr, first_dtype) + + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["coalesce"] + pb_expr.scalar_function.arguments.add().value.CopyFrom(first_expr) + pb_expr.scalar_function.arguments.add().value.CopyFrom(second_expr) + return pb_expr + + @_compile_op.register(generic_ops.CoalesceOp) + @_compile_op.register(numeric_ops.AddOp) + @_compile_op.register(numeric_ops.SubOp) + @_compile_op.register(numeric_ops.MulOp) + @_compile_op.register(numeric_ops.PowOp) + @_compile_op.register(numeric_ops.UnsafePowOp) + @_compile_op.register(comparison_ops.EqOp) + 
@_compile_op.register(comparison_ops.NeOp) + @_compile_op.register(comparison_ops.LtOp) + @_compile_op.register(comparison_ops.GtOp) + @_compile_op.register(comparison_ops.LeOp) + @_compile_op.register(comparison_ops.GeOp) + def _compile_basic_binops( + self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + op_class = type(op) + ext_name = self._OP_TO_EXTENSION[op_class] + return self._compile_basic_binop(ext_name, inputs, child) + + @_compile_op.register(bool_ops.AndOp) + @_compile_op.register(bool_ops.OrOp) + @_compile_op.register(bool_ops.XorOp) + def _compile_logical_binops( + self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + + input_dtype = self._get_expression_dtype(inputs[0], child) + if input_dtype == dtypes.INT_DTYPE: + if isinstance(op, bool_ops.AndOp): + ext_name = "bitwise_and" + elif isinstance(op, bool_ops.OrOp): + ext_name = "bitwise_or" + elif isinstance(op, bool_ops.XorOp): + ext_name = "bitwise_xor" + else: + raise NotImplementedError(f"Unsupported binary bitwise op: {type(op)}") + else: + op_class = type(op) + ext_name = self._OP_TO_EXTENSION[op_class] + return self._compile_basic_binop(ext_name, inputs, child) + + def _compile_basic_binop( + self, ext_name: str, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS[ext_name] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr + + @_compile_op.register(numeric_ops.ModOp) + def _compile_mod_op( + self, + op: numeric_ops.ModOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + + a_expr = self._compile_expression(inputs[0], child) + b_expr 
= self._compile_expression(inputs[1], child) + + div_expr = algebra_pb2.Expression() + div_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + + a_float = self._compile_cast(a_expr, dtypes.FLOAT_DTYPE) + b_float = self._compile_cast(b_expr, dtypes.FLOAT_DTYPE) + div_expr.scalar_function.arguments.add().value.CopyFrom(a_float) + div_expr.scalar_function.arguments.add().value.CopyFrom(b_float) + + floor_expr = algebra_pb2.Expression() + floor_expr.scalar_function.function_reference = self._EXTENSIONS["floor"] + floor_expr.scalar_function.arguments.add().value.CopyFrom(div_expr) + + mul_expr = algebra_pb2.Expression() + mul_expr.scalar_function.function_reference = self._EXTENSIONS["multiply"] + mul_expr.scalar_function.arguments.add().value.CopyFrom(b_float) + mul_expr.scalar_function.arguments.add().value.CopyFrom(floor_expr) + + sub_expr = algebra_pb2.Expression() + sub_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + sub_expr.scalar_function.arguments.add().value.CopyFrom(a_float) + sub_expr.scalar_function.arguments.add().value.CopyFrom(mul_expr) + + a_dtype = self._get_expression_dtype(inputs[0], child) + b_dtype = self._get_expression_dtype(inputs[1], child) + common_dtype = self._get_common_type([a_dtype, b_dtype]) + + if common_dtype == dtypes.INT_DTYPE: + return self._compile_cast(sub_expr, dtypes.INT_DTYPE) + return sub_expr + + @_compile_op.register(numeric_ops.AbsOp) + @_compile_op.register(numeric_ops.CeilOp) + @_compile_op.register(numeric_ops.FloorOp) + @_compile_op.register(generic_ops.IsNullOp) + @_compile_op.register(generic_ops.NotNullOp) + def _compile_standard_unaryops( + self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + op_class = type(op) + ext_name = self._OP_TO_EXTENSION[op_class] + return self._compile_basic_unaryop(ext_name, inputs, child) + + @_compile_op.register(numeric_ops.PosOp) + def _compile_pos_op( + self, + op: ops.UnaryOp, + inputs: 
Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + # Unary plus is a no-op + return self._compile_expression(inputs[0], child) + + @_compile_op.register(numeric_ops.NegOp) + def _compile_neg_op( + self, + op: ops.UnaryOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + # Compile negation as subtraction: 0 - x + arg_expr = self._compile_expression(inputs[0], child) + arg_dtype = self._get_expression_dtype(inputs[0], child) + + zero_expr = algebra_pb2.Expression() + if arg_dtype == dtypes.FLOAT_DTYPE: + zero_expr.literal.fp64 = 0.0 + else: + zero_expr.literal.i64 = 0 + + sub_expr = algebra_pb2.Expression() + sub_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + sub_expr.scalar_function.arguments.add().value.CopyFrom(zero_expr) + sub_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return sub_expr + + @_compile_op.register(generic_ops.InvertOp) + def _compile_invert_op( + self, + op: ops.UnaryOp, + inputs: Sequence[ex.Expression], + child: nodes.BigFrameNode, + ) -> algebra_pb2.Expression: + arg_expr = self._compile_expression(inputs[0], child) + arg_dtype = self._get_expression_dtype(inputs[0], child) + + if arg_dtype == dtypes.BOOL_DTYPE: + # Logical negation + not_expr = algebra_pb2.Expression() + not_expr.scalar_function.function_reference = self._EXTENSIONS["not"] + not_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return not_expr + else: + # Bitwise negation (two's complement mathematically equivalent to: -x - 1) + zero_i64 = algebra_pb2.Expression() + zero_i64.literal.i64 = 0 + + neg_expr = algebra_pb2.Expression() + neg_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + neg_expr.scalar_function.arguments.add().value.CopyFrom(zero_i64) + neg_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + + one_i64 = algebra_pb2.Expression() + one_i64.literal.i64 = 1 + + result_expr = 
algebra_pb2.Expression() + result_expr.scalar_function.function_reference = self._EXTENSIONS[ + "subtract" + ] + result_expr.scalar_function.arguments.add().value.CopyFrom(neg_expr) + result_expr.scalar_function.arguments.add().value.CopyFrom(one_i64) + return result_expr + + def _compile_basic_unaryop( + self, ext_name: str, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode + ) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS[ext_name] + arg_expr = self._compile_expression(inputs[0], child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr + + def _convert_schema(self, schema: Any) -> Dict[str, Any]: + # Convert bigframes schema to Substrait Type.NamedStruct + fields = [] + types = [] + for item in schema.items: + col = item.column + name = col.name if hasattr(col, "name") else str(col) + fields.append(name) + types.append(self._convert_type(item.dtype)) + + return {"names": fields, "struct": {"types": types}} + + def _get_substrait_names(self, name: str, dtype: Any) -> list[str]: + import bigframes.dtypes as dtypes + + names = [name] + if dtypes.is_struct_like(dtype): + fields_dict = dtypes.get_struct_fields(dtype) + for f_name, f_dtype in fields_dict.items(): + names.extend(self._get_substrait_names(f_name, f_dtype)) + return names + + def _convert_type(self, dtype: Any) -> Dict[str, Any]: + import bigframes.dtypes + + if dtype == bigframes.dtypes.INT_DTYPE: + return {"i64": {}} + elif dtype == bigframes.dtypes.FLOAT_DTYPE: + return {"fp64": {}} + elif dtype == bigframes.dtypes.BOOL_DTYPE: + return {"bool": {}} + elif dtype == bigframes.dtypes.STRING_DTYPE: + return {"string": {}} + elif dtype == bigframes.dtypes.BYTES_DTYPE: + return {"binary": {}} + elif dtype == bigframes.dtypes.DATE_DTYPE: + return {"date": {}} + elif dtype == bigframes.dtypes.DATETIME_DTYPE: + if self._use_precision_types: + return {"precision_timestamp": {"precision": 
6}} + else: + return {"timestamp": {}} + elif dtype == bigframes.dtypes.TIMESTAMP_DTYPE: + if self._use_precision_types: + return {"precision_timestamp_tz": {"precision": 6}} + else: + return {"timestamp_tz": {}} + elif dtype == bigframes.dtypes.TIME_DTYPE: + if self._use_precision_types: + # type_variation_reference 1 is for time64, precision 6 is for microseconds + return { + "precision_time": {"precision": 6, "type_variation_reference": 1} + } + else: + return {"time": {}} + elif dtype in ( + bigframes.dtypes.NUMERIC_DTYPE, + bigframes.dtypes.BIGNUMERIC_DTYPE, + ): + arrow_dtype = dtype.pyarrow_dtype + return { + "decimal": { + "precision": arrow_dtype.precision, + "scale": arrow_dtype.scale, + } + } + elif dtype == bigframes.dtypes.TIMEDELTA_DTYPE: + if self._duration_type == "interval_day": + return {"interval_day": {"precision": 6, "type_variation_reference": 1}} + else: + return {"i64": {}} + elif bigframes.dtypes.is_struct_like(dtype): + fields_dict = bigframes.dtypes.get_struct_fields(dtype) + return { + "struct": { + "types": [ + self._convert_type(f_dtype) for f_dtype in fields_dict.values() + ] + } + } + elif bigframes.dtypes.is_array_like(dtype): + inner_dtype = bigframes.dtypes.get_array_inner_type(dtype) + return {"list": {"type": self._convert_type(inner_dtype)}} + else: + # Fallback to string for now + return {"string": {}} diff --git a/packages/bigframes/bigframes/core/rewrite/__init__.py b/packages/bigframes/bigframes/core/rewrite/__init__.py index ae4b142b1a46..057b1ec0bc5e 100644 --- a/packages/bigframes/bigframes/core/rewrite/__init__.py +++ b/packages/bigframes/bigframes/core/rewrite/__init__.py @@ -25,8 +25,13 @@ try_reduce_to_local_scan, try_reduce_to_table_scan, ) +from bigframes.core.rewrite.schema_binding import bind_schema_to_tree from bigframes.core.rewrite.select_pullup import defer_selection from bigframes.core.rewrite.slices import pull_out_limit, pull_up_limits, rewrite_slice +from bigframes.core.rewrite.substrait_agg import ( + 
rewrite_substrait_aggregations, + rewrite_substrait_windows, +) from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions from bigframes.core.rewrite.udfs import lower_udfs from bigframes.core.rewrite.windows import ( @@ -37,10 +42,13 @@ __all__ = [ "as_sql_nodes", + "bind_schema_to_tree", "extract_ctes", "legacy_join_as_projection", "try_row_join", "rewrite_slice", + "rewrite_substrait_aggregations", + "rewrite_substrait_windows", "rewrite_timedelta_expressions", "pull_up_limits", "pull_out_limit", diff --git a/packages/bigframes/bigframes/core/rewrite/substrait_agg.py b/packages/bigframes/bigframes/core/rewrite/substrait_agg.py new file mode 100644 index 000000000000..5017c52dcec9 --- /dev/null +++ b/packages/bigframes/bigframes/core/rewrite/substrait_agg.py @@ -0,0 +1,287 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses + +import bigframes.dtypes as dtypes +import bigframes.operations as ops +import bigframes.operations.aggregations as agg_ops +from bigframes.core import ( + agg_expressions, + expression, + identifiers, + nodes, +) + + +def rewrite_substrait_aggregations(node: nodes.BigFrameNode) -> nodes.BigFrameNode: + """ + Rewrites AggregateNodes for Substrait compatibility: + 1. Pre-projects casts (like bool->float) and size literals, ensuring aggregate arguments + are direct references. + 2. 
Post-projects casts to enforce correct Python output schema/types. + """ + if not isinstance(node, nodes.AggregateNode): + return node + + child = node.child + child_ids = list(child.ids) + + # Collect size aggregations + size_aggs = [ + (i, agg) + for i, (agg, _) in enumerate(node.aggregations) + if isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)) + ] + + # Collect cast aggregations (bool->float for mean/stddev/var, bool->int for sum) + cast_aggs: list[tuple[int, identifiers.ColumnId, dtypes.Dtype]] = [] + for agg_idx, (agg, _) in enumerate(node.aggregations): + if hasattr(agg, "column_references"): + for col_id in agg.column_references: + idx = child_ids.index(col_id) + col_dtype = child.schema.items[idx].dtype + is_bool = col_dtype == dtypes.BOOL_DTYPE + if isinstance( + agg.op, (agg_ops.StdOp, agg_ops.VarOp, agg_ops.PopVarOp) + ) or (isinstance(agg.op, agg_ops.MeanOp) and is_bool): + cast_aggs.append((agg_idx, col_id, dtypes.FLOAT_DTYPE)) + elif isinstance(agg.op, agg_ops.SumOp) and is_bool: + cast_aggs.append((agg_idx, col_id, dtypes.INT_DTYPE)) + + # If we need pre-projection (casts or size constants) + if size_aggs or cast_aggs: + assignments = [] + + cast_agg_to_col_id = {} + for agg_idx, col_id, target_dtype in cast_aggs: + new_id = identifiers.ColumnId(f"bf_cast_{col_id.name}_{agg_idx}") + cast_expr = ops.AsTypeOp(to_type=target_dtype).as_expr( + expression.deref(col_id.name) + ) + assignments.append((cast_expr, new_id)) + cast_agg_to_col_id[(agg_idx, col_id)] = new_id + + size_agg_to_col_id = {} + for size_idx, (agg_idx, _) in enumerate(size_aggs): + new_id = identifiers.ColumnId(f"bf_size_const_{agg_idx}") + const_expr = expression.const(size_idx + 1) + assignments.append((const_expr, new_id)) + size_agg_to_col_id[agg_idx] = new_id + + # Wrap child in ProjectionNode + pre_project = nodes.ProjectionNode( + child, + assignments=tuple(assignments), + ) + child = pre_project + + # Rewrite aggregations to use the projected columns + rewritten_aggs = 
def rewrite_substrait_windows(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
    """
    Rewrites WindowOpNode for Substrait compatibility:
    1. Pre-projects casts (like bool->float) and size literals, ensuring aggregate
       arguments are direct references.
    2. Post-projects casts to enforce correct Python output schema/types for window
       columns.

    Args:
        node: Any plan node. Anything that is not a ``WindowOpNode`` is returned
            unchanged.

    Returns:
        The input node when it is not a window node. Otherwise a ``SelectionNode``
        that exposes the window node's input columns followed by the window output
        columns, each window output being cast to the dtype reported by
        ``node.added_fields``.
    """
    if not isinstance(node, nodes.WindowOpNode):
        return node

    child = node.child
    # Snapshot of the input columns, taken before any pre-projection is stacked
    # on top of the child. Reused below to build the passthrough selection.
    child_ids = list(child.ids)

    # Collect size and cast requirements for the window agg expressions
    size_aggs = []
    cast_aggs: list[tuple[int, identifiers.ColumnId, dtypes.Dtype]] = []

    for agg_idx, col_def in enumerate(node.agg_exprs):
        agg = col_def.expression
        assert isinstance(agg, agg_expressions.Aggregation)
        if isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)):
            size_aggs.append((agg_idx, agg))
        elif hasattr(agg, "column_references"):
            for col_id in agg.column_references:
                idx = child_ids.index(col_id)
                col_dtype = child.schema.items[idx].dtype
                is_bool = col_dtype == dtypes.BOOL_DTYPE
                # std/var always get a float input; mean and sum only need a
                # cast when the input column is boolean.
                if isinstance(
                    agg.op, (agg_ops.StdOp, agg_ops.VarOp, agg_ops.PopVarOp)
                ) or (isinstance(agg.op, agg_ops.MeanOp) and is_bool):
                    cast_aggs.append((agg_idx, col_id, dtypes.FLOAT_DTYPE))
                elif isinstance(agg.op, agg_ops.SumOp) and is_bool:
                    cast_aggs.append((agg_idx, col_id, dtypes.INT_DTYPE))

    # If we need pre-projection (casts or size constants)
    if size_aggs or cast_aggs:
        assignments = []

        # (aggregation index, source column) -> id of the pre-projected cast column
        cast_agg_to_col_id = {}
        for agg_idx, col_id, target_dtype in cast_aggs:
            new_id = identifiers.ColumnId(f"bf_window_cast_{col_id.name}_{agg_idx}")
            cast_expr = ops.AsTypeOp(to_type=target_dtype).as_expr(
                expression.deref(col_id.name)
            )
            assignments.append((cast_expr, new_id))
            cast_agg_to_col_id[(agg_idx, col_id)] = new_id

        # aggregation index -> id of the literal column the size op will reference
        size_agg_to_col_id = {}
        for size_idx, (agg_idx, _) in enumerate(size_aggs):
            new_id = identifiers.ColumnId(f"bf_window_size_const_{agg_idx}")
            # Each size aggregation gets its own literal (1, 2, ...).
            const_expr = expression.const(size_idx + 1)
            assignments.append((const_expr, new_id))
            size_agg_to_col_id[agg_idx] = new_id

        # Wrap child in ProjectionNode
        pre_project = nodes.ProjectionNode(
            child,
            assignments=tuple(assignments),
        )
        child = pre_project

        # Rewrite window expressions to use the projected columns
        rewritten_agg_exprs = []
        for agg_idx, col_def in enumerate(node.agg_exprs):
            agg = col_def.expression
            assert isinstance(agg, agg_expressions.Aggregation)
            out_col_id = col_def.id
            rewritten_agg: agg_expressions.Aggregation

            if isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)):
                new_col_id = size_agg_to_col_id[agg_idx]
                rewritten_agg = agg_expressions.UnaryAggregation(
                    agg_ops.SizeUnaryOp(), expression.deref(new_col_id.name)
                )
            elif hasattr(agg, "column_references"):
                new_exprs = []
                for col_id in agg.column_references:
                    if (agg_idx, col_id) in cast_agg_to_col_id:
                        new_exprs.append(
                            expression.deref(cast_agg_to_col_id[(agg_idx, col_id)].name)
                        )
                    else:
                        new_exprs.append(expression.deref(col_id.name))

                rewritten_agg = agg.replace_args(*new_exprs)
            else:
                rewritten_agg = agg

            rewritten_agg_exprs.append(nodes.ColumnDef(rewritten_agg, out_col_id))

        node = dataclasses.replace(
            node,
            child=child,
            agg_exprs=tuple(rewritten_agg_exprs),
        )

    # Post-projection to enforce output schema types for newly introduced window
    # columns.
    #
    # The passthrough columns are exactly the ids captured before the optional
    # pre-projection. Checking `isinstance(child, nodes.ProjectionNode)` instead
    # cannot distinguish a pre-projection added above from an input that already
    # was a ProjectionNode; in the latter case `child.child.ids` would silently
    # drop the columns that projection defines.
    child_output_ids = child_ids

    assignments = []
    selection_pairs = []

    for child_id in child_output_ids:
        selection_pairs.append((nodes.AliasedRef.identity(child_id), child_id))

    for col_def, field in zip(node.agg_exprs, node.added_fields):
        out_id = col_def.id
        out_dtype = field.dtype

        # Cast into a temporary column, then alias it back to the original
        # output id in the final selection.
        cast_id = identifiers.ColumnId(f"bf_window_out_cast_{out_id.name}")
        cast_expr = ops.AsTypeOp(to_type=out_dtype).as_expr(
            expression.deref(out_id.name)
        )
        assignments.append((cast_expr, cast_id))

        selection_pairs.append(
            (nodes.AliasedRef(expression.deref(cast_id.name), out_id), out_id)
        )

    post_project = nodes.ProjectionNode(
        node,
        assignments=tuple(assignments),
    )
    post_selection = nodes.SelectionNode(
        post_project,
        input_output_pairs=tuple(ref for ref, _ in selection_pairs),
    )

    return post_selection
class DataFusionSubstraitConsumer(SubstraitConsumer):
    """
    Executes Substrait plans using Apache DataFusion.
    """

    def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table:
        """
        Runs a serialized Substrait plan in a fresh DataFusion session.

        Args:
            plan_proto: The serialized Substrait plan.
            tables: Local tables to register in the session, keyed by the name
                they are registered under.

        Returns:
            The query result as a PyArrow Table.

        Raises:
            ImportError: If the ``datafusion`` package is not installed.
        """
        # Import datafusion lazily to avoid hard dependency
        try:
            import datafusion
        except ImportError as exc:
            # Chain the original failure so the underlying import problem stays
            # visible as the direct cause in the traceback.
            raise ImportError(
                "The datafusion package is required to use DataFusionSubstraitConsumer. "
                "Install it with `pip install datafusion`."
            ) from exc

        ctx = datafusion.SessionContext()

        # Register every local table under its own name in the session.
        for name, table in tables.items():
            df = ctx.from_arrow(table)
            ctx.register_table(name, df)

        import datafusion.substrait

        datafusion_substrait_plan = datafusion.substrait.Serde.deserialize_bytes(
            plan_proto
        )
        logical_plan = datafusion.substrait.Consumer.from_substrait_plan(
            ctx, datafusion_substrait_plan
        )
        df = ctx.create_dataframe_from_logical_plan(logical_plan)
        return df.to_arrow_table()


class AceroSubstraitConsumer(SubstraitConsumer):
    """
    Executes Substrait plans using Apache Arrow Acero.
    """

    def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table:
        """
        Runs a serialized Substrait plan through ``pyarrow.substrait.run_query``.

        Args:
            plan_proto: The serialized Substrait plan.
            tables: Local tables, looked up by the first component of the name
                the table provider is called with.

        Returns:
            The fully materialized query result as a PyArrow Table.
        """
        import pyarrow.substrait as pa_substrait

        def provide_table(name: list[str], schema: pa.Schema) -> pa.Table:
            # Only the first name component is used as the lookup key; the
            # requested schema is ignored.
            return tables[name[0]]

        batch_reader = pa_substrait.run_query(plan_proto, table_provider=provide_table)
        return batch_reader.read_all()
class SubstraitExecutor(semi_executor.SemiExecutor):
    """
    Executes plans by compiling them to Substrait and running them via a consumer.
    """

    def __init__(
        self,
        consumer: SubstraitConsumer,
        compiler: substrait_compiler.SubstraitCompiler,
    ):
        """
        Args:
            consumer: Engine adapter that runs the serialized Substrait plan.
            compiler: Compiler that turns a plan into a serialized Substrait plan.
        """
        self._consumer = consumer
        self._compiler = compiler

    @classmethod
    def default_for_engine(cls, engine_name: str) -> SubstraitExecutor:
        """
        Builds an executor with the consumer/compiler pairing for a named engine.

        Args:
            engine_name: Either ``"acero"`` or ``"datafusion"``.

        Raises:
            ValueError: If ``engine_name`` is anything else.
        """
        if engine_name == "acero":
            return cls(
                AceroSubstraitConsumer(),
                substrait_compiler.SubstraitCompiler(
                    duration_type="int",
                    use_precision_types=False,
                    dialect="substrait-acero",
                ),
            )
        elif engine_name == "datafusion":
            return cls(
                DataFusionSubstraitConsumer(),
                substrait_compiler.SubstraitCompiler(
                    duration_type="int",
                    dialect="substrait-datafusion",
                ),
            )
        else:
            raise ValueError(f"Unknown engine: {engine_name}")

    async def execute(
        self,
        plan: bigframe_node.BigFrameNode,
        execution_spec: execution_spec.ExecutionSpec,
    ) -> Optional[executor.ExecuteResult]:
        """
        Rewrites, compiles and runs ``plan`` through the configured consumer.

        Args:
            plan: The plan to execute.
            execution_spec: Only ``ordered`` and ``peek`` are read here.

        Returns:
            A ``LocalExecuteResult`` wrapping the resulting PyArrow table, or
            ``None`` when the compiler reports it cannot compile the rewritten
            plan or produces no plan.
        """
        ordered = execution_spec.ordered
        peek = execution_spec.peek
        plan = plan.bottom_up(rewrite.rewrite_slice)
        # Only needed for acero technically, datafusion can handle timedeltas
        plan = plan.bottom_up(rewrite.rewrite_timedelta_expressions)
        plan = plan.bottom_up(rewrite.rewrite_substrait_aggregations)
        plan = plan.bottom_up(rewrite.rewrite_substrait_windows)

        from bigframes.core import expression

        # Materialize the output ids once. Loop variables are deliberately not
        # called `id`, because the builtin `id()` is used further down.
        original_ids = tuple(plan.ids)
        output_cols = tuple(
            (expression.DerefOp(col_id), col_id.name) for col_id in original_ids
        )
        result_node = nodes.ResultNode(
            plan,
            output_cols=output_cols,
        )
        result_node = cast(nodes.ResultNode, rewrite.column_pruning(result_node))
        result_node = rewrite.defer_order(result_node, output_hidden_row_keys=False)

        rewritten_plan = result_node.child

        if (
            ordered
            and result_node.order_by
            and result_node.order_by.all_ordering_columns
        ):
            rewritten_plan = nodes.OrderByNode(
                rewritten_plan,
                by=tuple(result_node.order_by.all_ordering_columns),
            )

        # Restore the original column set/order if the rewrites changed it.
        # Both sides are normalized to tuples: comparing a tuple against any
        # other iterable type (list, generator, ...) is never equal, which would
        # add the selection unconditionally.
        if tuple(rewritten_plan.ids) != original_ids:
            rewritten_plan = nodes.SelectionNode(
                rewritten_plan,
                input_output_pairs=tuple(
                    nodes.AliasedRef.identity(col_id) for col_id in original_ids
                ),
            )

        if not self._can_execute(rewritten_plan):
            return None

        substrait_plan_proto = self._compiler.compile(rewritten_plan)
        if substrait_plan_proto is None:
            return None

        # Collect the local tables the plan reads, keyed by a name derived from
        # the identity of the reading node.
        # NOTE(review): presumably the compiler emits the same `table_{id(node)}`
        # names for its read relations - confirm against the compiler.
        tables = {}
        for node in rewritten_plan.unique_nodes():
            if isinstance(node, nodes.ReadLocalNode):
                table_name = f"table_{id(node)}"
                table = node.local_data_source.to_pyarrow_table(duration_type="int")
                table = table.select([item.source_id for item in node.scan_list.items])
                table = table.rename_columns(
                    [item.id.sql for item in node.scan_list.items]
                )
                if node.offsets_col is not None:
                    from bigframes.core import pyarrow_utils

                    table = pyarrow_utils.append_offsets(table, node.offsets_col.sql)
                tables[table_name] = table

        # The consumer call is synchronous; run it in a worker thread.
        pa_table = await asyncio.to_thread(
            self._consumer.consume, substrait_plan_proto, tables
        )

        # Peek is applied to the finished result, not pushed into the plan.
        if peek is not None:
            pa_table = pa_table.slice(0, peek)

        return executor.LocalExecuteResult(
            data=pa_table,
            bf_schema=rewritten_plan.schema,
        )

    def _can_execute(self, plan: bigframe_node.BigFrameNode) -> bool:
        """Returns whether the configured compiler reports it can compile ``plan``."""
        return self._compiler.can_compile(plan)
class SubstraitTestExecutor(bigframes.session.executor.Executor):
    """Test-only executor that runs every plan through a ``SubstraitExecutor``."""

    def __init__(
        self, consumer: bigframes.session.substrait_executor.SubstraitConsumer
    ):
        """Pairs ``consumer`` with a compiler configured for it."""
        from bigframes.core.compile.substrait.compiler import SubstraitCompiler
        from bigframes.session.substrait_executor import (
            AceroSubstraitConsumer,
            SubstraitExecutor,
        )

        # The Acero consumer gets a compiler with precision types switched off;
        # every other consumer uses the compiler's default for that option.
        targets_acero = isinstance(consumer, AceroSubstraitConsumer)
        compiler = (
            SubstraitCompiler(duration_type="int", use_precision_types=False)
            if targets_acero
            else SubstraitCompiler(duration_type="int")
        )

        self.executor = SubstraitExecutor(consumer, compiler)

    def execute(
        self,
        array_value: bigframes.core.ArrayValue,
        execution_spec: bigframes.session.execution_spec.ExecutionSpec,
    ):
        """
        Runs ``array_value`` synchronously and returns the execution result.

        Raises:
            ValueError: If ``execution_spec`` carries a destination spec.
            NotImplementedError: If the wrapped executor returns no result.
        """
        if execution_spec.destination_spec is not None:
            raise ValueError(
                f"SubstraitTestExecutor does not support destination spec: {execution_spec.destination_spec}"
            )

        # Always request ordered results; only `peek` is forwarded from the
        # caller's spec.
        forwarded_spec = bigframes.session.execution_spec.ExecutionSpec(
            ordered=True, peek=execution_spec.peek
        )
        pending = self.executor.execute(array_value.node, forwarded_spec)
        outcome = asyncio.run(pending)

        if outcome is not None:
            return outcome
        raise NotImplementedError("SubstraitExecutor cannot execute this plan")

    def cached(
        self,
        array_value: bigframes.core.ArrayValue,
        *,
        config,
    ) -> None:
        """No-op: this executor does not cache anything."""
        return None
class TestSession(bigframes.session.Session):
    """Session stand-in for tests whose data is loaded locally via ``Block.from_local``."""

    def __init__(self, executor: SubstraitTestExecutor):
        """Sets the session attributes directly; ``executor`` runs all plans."""
        # Cloud-backed collaborators are left unset.
        self._location = None  # type: ignore
        self._bq_kms_key_name = None  # type: ignore
        self._clients_provider = None  # type: ignore
        self._bq_connection = None  # type: ignore
        self._skip_bq_connection_check = True
        self._session_id: str = "substrait_test_session"
        # Starts empty; holds weak references to index/series/dataframe objects.
        self._objects: list[
            weakref.ReferenceType[
                Union[
                    bigframes.core.indexes.Index,
                    bigframes.series.Series,
                    bigframes.dataframe.DataFrame,
                ]
            ]
        ] = []
        self._strictly_ordered: bool = True
        self._allow_ambiguity = False  # type: ignore
        self._default_index_type = bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64
        self._metrics = bigframes.session.metrics.ExecutionMetrics()
        self._function_session = None  # type: ignore
        self._temp_storage_manager = None  # type: ignore
        self._executor = executor
        self._loader = None  # type: ignore

    def read_pandas(self, pandas_dataframe, write_engine="default"):
        """
        Wraps a pandas object in the matching bigframes object, kept local.

        A ``pandas.Series`` yields a series carrying the original name, a
        ``pandas.Index`` yields an index, and anything else yields a DataFrame.
        ``write_engine`` is accepted but not used.
        """
        came_as_series = isinstance(pandas_dataframe, pandas.Series)
        came_as_index = isinstance(pandas_dataframe, pandas.Index)

        # Series and Index inputs are promoted to a frame so a block can be built.
        if came_as_series or came_as_index:
            frame = pandas_dataframe.to_frame()
        else:
            frame = pandas_dataframe

        block = bigframes.core.blocks.Block.from_local(frame, self)
        bf_df = bigframes.dataframe.DataFrame(block)

        if came_as_series:
            # Pull the single column back out and restore the caller's name.
            first_column = bf_df.columns[0]
            series = bf_df[first_column]
            series.name = pandas_dataframe.name
            return series

        if came_as_index:
            return bf_df.index

        return bf_df

    @property
    def bqclient(self):
        """Always ``None``."""
        return None
UNIT_TEST_EXTRAS: List[str] = ["tests"] UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = { - "3.10": ["tests", "scikit-learn", "anywidget"], - "3.11": ["tests", "polars", "scikit-learn", "anywidget"], + "3.10": ["tests", "polars", "datafusion", "substrait", "scikit-learn", "anywidget"], + "3.11": ["tests", "polars", "datafusion", "substrait", "scikit-learn", "anywidget"], # Make sure we leave some versions without "extras" so we know those # dependencies are actually optional. - "3.13": ["tests", "polars", "scikit-learn", "anywidget"], - "3.14": ["tests", "polars", "scikit-learn", "anywidget"], + "3.13": ["tests", "polars", "datafusion", "substrait", "scikit-learn", "anywidget"], + "3.14": ["tests", "polars", "datafusion", "substrait", "scikit-learn", "anywidget"], } # 3.11 is used by colab. @@ -106,9 +106,9 @@ # Make sure we leave some versions without "extras" so we know those # dependencies are actually optional. "3.10": ["tests", "scikit-learn", "anywidget"], - "3.12": ["tests", "scikit-learn", "polars", "anywidget"], - "3.13": ["tests", "polars", "anywidget"], - "3.14": ["tests", "polars", "anywidget"], + "3.12": ["tests", "scikit-learn", "polars", "datafusion", "substrait", "anywidget"], + "3.13": ["tests", "polars", "datafusion", "substrait", "anywidget"], + "3.14": ["tests", "polars", "datafusion", "substrait", "anywidget"], } LOGGING_NAME_ENV_VAR = "BIGFRAMES_PERFORMANCE_LOG_NAME" diff --git a/packages/bigframes/setup.py b/packages/bigframes/setup.py index 76b98b88d312..f4037e0d911a 100644 --- a/packages/bigframes/setup.py +++ b/packages/bigframes/setup.py @@ -42,9 +42,9 @@ "google-cloud-bigquery[bqstorage,pandas] >=3.36.0", # 2.30 needed for arrow support. 
"google-cloud-bigquery-storage >= 2.30.0, < 3.0.0", - "google-cloud-functions >=1.12.0", - "google-cloud-bigquery-connection >=1.12.0", - "google-cloud-resource-manager >=1.10.3", + "google-cloud-functions >=1.20.0", + "google-cloud-bigquery-connection >=1.20.0", + "google-cloud-resource-manager >=1.14.0", "google-cloud-storage >=2.0.0", "google-crc32c >=1.0.0,<2.0.0", "grpc-google-iam-v1 >= 0.14.2", @@ -73,12 +73,14 @@ "tests": [ "freezegun", "pytest-snapshot", - "google-cloud-bigtable >=2.24.0", - "google-cloud-pubsub >=2.21.4", + "google-cloud-bigtable >=2.28.0", + "google-cloud-pubsub >=2.28.0", "tzdata", ], # used for local engine "polars": ["polars >= 1.21.0"], + "datafusion": ["datafusion >= 45.2.0", "substrait >= 0.29"], + "substrait": ["substrait >= 0.29"], "scikit-learn": ["scikit-learn>=1.2.2"], # Packages required for basic development flow. "dev": [ diff --git a/packages/bigframes/testing/constraints-3.10.txt b/packages/bigframes/testing/constraints-3.10.txt index 0c76f1dda750..3c1646987ff3 100644 --- a/packages/bigframes/testing/constraints-3.10.txt +++ b/packages/bigframes/testing/constraints-3.10.txt @@ -4,18 +4,19 @@ fsspec==2023.3.0 gcsfs==2023.3.0 geopandas==0.12.2 google-auth==2.15.0 -google-cloud-bigtable==2.24.0 -google-cloud-pubsub==2.21.4 google-cloud-bigquery==3.36.0 -google-cloud-functions==1.12.0 -google-cloud-bigquery-connection==1.12.0 +google-cloud-functions==1.20.0 +google-cloud-bigquery-connection==1.20.0 google-cloud-iam==2.12.1 -google-cloud-resource-manager==1.10.3 +google-cloud-resource-manager==1.14.0 google-cloud-storage==2.0.0 grpc-google-iam-v1==0.14.2 numpy==1.24.0 pandas==1.5.3 pandas-gbq==0.26.1 +polars==1.21.0 +substrait==0.29.0 +datafusion==45.2.0 pyarrow==23.0.1 pydata-google-auth==1.8.2 pyiceberg==0.7.1 @@ -95,7 +96,7 @@ pluggy==1.6.0 prompt_toolkit==3.0.52 propcache==0.4.1 proto-plus==1.27.1 -protobuf==4.25.8 +protobuf==5.26.0 psygnal==0.15.1 ptyprocess==0.7.0 pure_eval==0.2.3 diff --git 
a/packages/bigframes/testing/constraints-3.11.txt b/packages/bigframes/testing/constraints-3.11.txt index 1f569a4f244c..eae6cb7caa1c 100644 --- a/packages/bigframes/testing/constraints-3.11.txt +++ b/packages/bigframes/testing/constraints-3.11.txt @@ -153,7 +153,6 @@ google-auth-httplib2==0.2.0 google-auth-oauthlib==1.2.2 google-cloud-aiplatform==1.106.0 google-cloud-bigquery==3.36.0 -google-cloud-bigquery-connection==1.18.3 google-cloud-bigquery-storage==2.32.0 google-cloud-core==2.4.3 google-cloud-dataproc==5.21.0 diff --git a/packages/bigframes/tests/system/small/engines/conftest.py b/packages/bigframes/tests/system/small/engines/conftest.py index 823ba9806d58..bdaa0d74749a 100644 --- a/packages/bigframes/tests/system/small/engines/conftest.py +++ b/packages/bigframes/tests/system/small/engines/conftest.py @@ -17,6 +17,18 @@ import google.cloud.bigquery_storage_v1 import pandas as pd import pytest + +# Skip the entire engines test directory if required libraries are missing +try: + import datafusion # noqa: F401 + import polars # noqa: F401 + import substrait # noqa: F401 +except ImportError as e: + pytest.skip( + f"Skipping engines tests because dependencies are missing: {e}", + allow_module_level=True, + ) + from google.cloud import bigquery import bigframes @@ -26,6 +38,7 @@ local_scan_executor, polars_executor, semi_executor, + substrait_executor, ) CURRENT_DIR = pathlib.Path(__file__).parent @@ -80,9 +93,35 @@ def sqlglot_engine( ) -@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot"]) +@pytest.fixture(scope="session") +def substrait_datafusion_engine() -> semi_executor.SemiExecutor: + return substrait_executor.SubstraitExecutor.default_for_engine("datafusion") + + +@pytest.fixture(scope="session") +def substrait_acero_engine() -> semi_executor.SemiExecutor: + return substrait_executor.SubstraitExecutor.default_for_engine("acero") + + +@pytest.fixture( + scope="session", + params=[ + "pyarrow", + "polars", + "bq", + 
"bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], +) def engine( - request, pyarrow_engine, polars_engine, bq_engine, sqlglot_engine + request, + pyarrow_engine, + polars_engine, + bq_engine, + sqlglot_engine, + substrait_datafusion_engine, + substrait_acero_engine, ) -> semi_executor.SemiExecutor: if request.param == "pyarrow": return pyarrow_engine @@ -92,6 +131,10 @@ def engine( return bq_engine if request.param == "bq-sqlglot": return sqlglot_engine + if request.param == "substrait-datafusion": + return substrait_datafusion_engine + if request.param == "substrait-acero": + return substrait_acero_engine raise ValueError(f"Unrecognized param: {request.param}") diff --git a/packages/bigframes/tests/system/small/engines/test_aggregation.py b/packages/bigframes/tests/system/small/engines/test_aggregation.py index 669eae9ebf75..5816b8d7e6e5 100644 --- a/packages/bigframes/tests/system/small/engines/test_aggregation.py +++ b/packages/bigframes/tests/system/small/engines/test_aggregation.py @@ -73,7 +73,11 @@ def test_engines_aggregate_post_filter_size( assert_equivalence_execution(plan, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + ["polars", "bq", "bq-sqlglot", "substrait-datafusion", "substrait-acero"], + indirect=True, +) def test_engines_aggregate_size( scalars_array_value: array_value.ArrayValue, engine, @@ -96,7 +100,11 @@ def test_engines_aggregate_size( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + ["polars", "bq", "bq-sqlglot", "substrait-datafusion", "substrait-acero"], + indirect=True, +) @pytest.mark.parametrize( "op", [agg_ops.min_op, agg_ops.max_op, agg_ops.mean_op, agg_ops.sum_op, agg_ops.count_op], @@ -137,7 +145,11 @@ def test_sql_engines_median_op_aggregates( assert_equivalence_execution(node, 
bq_engine, sqlglot_engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + ["polars", "bq", "bq-sqlglot", "substrait-datafusion", "substrait-acero"], + indirect=True, +) @pytest.mark.parametrize( "grouping_cols", [ diff --git a/packages/bigframes/tests/system/small/engines/test_bool_ops.py b/packages/bigframes/tests/system/small/engines/test_bool_ops.py index a6ef702885b0..b3a66526515b 100644 --- a/packages/bigframes/tests/system/small/engines/test_bool_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_bool_ops.py @@ -46,7 +46,17 @@ def apply_op_pairwise( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) @pytest.mark.parametrize( "op", [ diff --git a/packages/bigframes/tests/system/small/engines/test_comparison_ops.py b/packages/bigframes/tests/system/small/engines/test_comparison_ops.py index de0f110fa9b1..223e3068cb40 100644 --- a/packages/bigframes/tests/system/small/engines/test_comparison_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_comparison_ops.py @@ -48,7 +48,17 @@ def apply_op_pairwise( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) @pytest.mark.parametrize( "op", [ diff --git a/packages/bigframes/tests/system/small/engines/test_filtering.py b/packages/bigframes/tests/system/small/engines/test_filtering.py index fcb85aa8859b..5f0c4ad052ca 100644 --- a/packages/bigframes/tests/system/small/engines/test_filtering.py +++ b/packages/bigframes/tests/system/small/engines/test_filtering.py @@ -24,7 +24,17 @@ REFERENCE_ENGINE = polars_executor.PolarsExecutor() 
-@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_filter_bool_col( scalars_array_value: array_value.ArrayValue, engine, @@ -35,7 +45,17 @@ def test_engines_filter_bool_col( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_filter_expr_cond( scalars_array_value: array_value.ArrayValue, engine, @@ -47,7 +67,17 @@ def test_engines_filter_expr_cond( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_filter_true( scalars_array_value: array_value.ArrayValue, engine, @@ -57,7 +87,17 @@ def test_engines_filter_true( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_filter_false( scalars_array_value: array_value.ArrayValue, engine, diff --git a/packages/bigframes/tests/system/small/engines/test_generic_ops.py b/packages/bigframes/tests/system/small/engines/test_generic_ops.py index 05739a1c1b63..9937d80610ad 100644 --- a/packages/bigframes/tests/system/small/engines/test_generic_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_generic_ops.py @@ -52,7 +52,17 @@ def apply_op( return new_arr 
-@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine): polars_version = tuple([int(part) for part in polars.__version__.split(".")]) if polars_version >= (1, 34, 0): @@ -69,7 +79,11 @@ def test_engines_astype_int(scalars_array_value: array_value.ArrayValue, engine) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + ["polars", "bq", "bq-sqlglot", "substrait-datafusion", "substrait-acero"], + indirect=True, +) def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, engine): vals = ["1", "100", "-3"] arr, _ = scalars_array_value.compute_values( @@ -84,7 +98,17 @@ def test_engines_astype_string_int(scalars_array_value: array_value.ArrayValue, assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -95,7 +119,17 @@ def test_engines_astype_float(scalars_array_value: array_value.ArrayValue, engin assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_float( scalars_array_value: array_value.ArrayValue, engine ): @@ -112,7 
+146,17 @@ def test_engines_astype_string_float( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.BOOL_DTYPE) @@ -121,19 +165,51 @@ def test_engines_astype_bool(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string(scalars_array_value: array_value.ArrayValue, engine): # floats work slightly different with trailing zeroes rn + excluded_cols = ["float64_col"] + + # Acero's Substrait consumer lacks support for string functions like replace/replace_substring + # and precision_time, so we cannot format time_col and timestamp_col inside the Substrait plan. 
+ from bigframes.session.substrait_executor import SubstraitExecutor + + if ( + isinstance(engine, SubstraitExecutor) + and not engine._compiler._use_precision_types + ): + excluded_cols.extend(["time_col", "timestamp_col"]) + arr = apply_op( scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.STRING_DTYPE), - excluded_cols=["float64_col"], + excluded_cols=excluded_cols, ) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -144,7 +220,17 @@ def test_engines_astype_numeric(scalars_array_value: array_value.ArrayValue, eng assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_numeric( scalars_array_value: array_value.ArrayValue, engine ): @@ -161,7 +247,17 @@ def test_engines_astype_string_numeric( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -172,7 +268,17 @@ def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", 
"bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_date( scalars_array_value: array_value.ArrayValue, engine ): @@ -189,7 +295,17 @@ def test_engines_astype_string_date( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -200,7 +316,17 @@ def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, en assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_datetime( scalars_array_value: array_value.ArrayValue, engine ): @@ -217,7 +343,17 @@ def test_engines_astype_string_datetime( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -228,7 +364,17 @@ def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, e assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) 
+@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_string_timestamp( scalars_array_value: array_value.ArrayValue, engine ): @@ -249,7 +395,17 @@ def test_engines_astype_string_timestamp( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, @@ -260,7 +416,15 @@ def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + ], + indirect=True, +) def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine): exprs = [ ops.JSONDecode(to_type=bigframes.dtypes.INT_DTYPE).as_expr( @@ -281,7 +445,15 @@ def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, e assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + ], + indirect=True, +) def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, engine): exprs = [ ops.ToJSON().as_expr(expression.deref("int64_col")), @@ -300,7 +472,17 @@ def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, eng assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) 
+@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + "substrait-acero", + ], + indirect=True, +) def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine): arr = apply_op( scalars_array_value, diff --git a/packages/bigframes/tests/system/small/engines/test_numeric_ops.py b/packages/bigframes/tests/system/small/engines/test_numeric_ops.py index c188e37370c5..e573b848e0e3 100644 --- a/packages/bigframes/tests/system/small/engines/test_numeric_ops.py +++ b/packages/bigframes/tests/system/small/engines/test_numeric_ops.py @@ -94,7 +94,15 @@ def test_engines_project_floor( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + ], + indirect=True, +) def test_engines_project_add( scalars_array_value: array_value.ArrayValue, engine, diff --git a/packages/bigframes/tests/system/small/engines/test_slicing.py b/packages/bigframes/tests/system/small/engines/test_slicing.py index 022758893d29..99e2e2f1bdd3 100644 --- a/packages/bigframes/tests/system/small/engines/test_slicing.py +++ b/packages/bigframes/tests/system/small/engines/test_slicing.py @@ -24,7 +24,16 @@ REFERENCE_ENGINE = polars_executor.PolarsExecutor() -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + [ + "polars", + "bq", + "bq-sqlglot", + "substrait-datafusion", + ], + indirect=True, +) @pytest.mark.parametrize( ("start", "stop", "step"), [ diff --git a/packages/bigframes/tests/system/small/engines/test_windowing.py b/packages/bigframes/tests/system/small/engines/test_windowing.py index 8235fe0ef6bf..a866368e4a68 100644 --- a/packages/bigframes/tests/system/small/engines/test_windowing.py +++ b/packages/bigframes/tests/system/small/engines/test_windowing.py @@ -32,7 +32,11 @@ 
REFERENCE_ENGINE = polars_executor.PolarsExecutor() -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize( + "engine", + ["polars", "bq", "bq-sqlglot", "substrait-datafusion", "substrait-acero"], + indirect=True, +) def test_engines_with_offsets( scalars_array_value: array_value.ArrayValue, engine, @@ -41,13 +45,25 @@ def test_engines_with_offsets( assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) +@pytest.mark.parametrize( + "engine", + ["bq", "bq-sqlglot", "substrait-datafusion", "substrait-acero"], + indirect=True, +) @pytest.mark.parametrize("agg_op", [agg_ops.sum_op, agg_ops.count_op]) def test_engines_with_rows_window( scalars_array_value: array_value.ArrayValue, agg_op, + engine, bq_engine, - sqlglot_engine, ): + from bigframes.session.substrait_executor import SubstraitExecutor + + if isinstance(engine, SubstraitExecutor): + pytest.skip( + f"Substrait engine ({type(engine._consumer).__name__}) does not support windowing execution" + ) + window = window_spec.WindowSpec( bounds=window_spec.RowsWindowBounds.from_window_size(3, "left"), ) @@ -61,4 +77,4 @@ def test_engines_with_rows_window( ), window_spec=window, ) - assert_equivalence_execution(window_node, bq_engine, sqlglot_engine) + assert_equivalence_execution(window_node, bq_engine, engine)