diff --git a/bigframes/core/compile/sqlglot/expressions/array_ops.py b/bigframes/core/compile/sqlglot/expressions/array_ops.py index f7b96d0418..28b3693caf 100644 --- a/bigframes/core/compile/sqlglot/expressions/array_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/array_ops.py @@ -20,6 +20,10 @@ import sqlglot.expressions as sge from bigframes import operations as ops +from bigframes.core.compile.sqlglot.expressions.string_ops import ( + string_index, + string_slice, +) from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler import bigframes.dtypes as dtypes @@ -31,7 +35,7 @@ @register_unary_op(ops.ArrayIndexOp, pass_op=True) def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression: if expr.dtype == dtypes.STRING_DTYPE: - return _string_index(expr, op) + return string_index(expr, op.index) return sge.Bracket( this=expr.expr, @@ -71,29 +75,10 @@ def _(expr: TypedExpr, op: ops.ArrayReduceOp) -> sge.Expression: @register_unary_op(ops.ArraySliceOp, pass_op=True) def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression: - slice_idx = sg.to_identifier("slice_idx") - - conditions: typing.List[sge.Predicate] = [slice_idx >= op.start] - - if op.stop is not None: - conditions.append(slice_idx < op.stop) - - # local name for each element in the array - el = sg.to_identifier("el") - - selected_elements = ( - sge.select(el) - .from_( - sge.Unnest( - expressions=[expr.expr], - alias=sge.TableAlias(columns=[el]), - offset=slice_idx, - ) - ) - .where(*conditions) - ) - - return sge.array(selected_elements) + if expr.dtype == dtypes.STRING_DTYPE: + return string_slice(expr, op.start, op.stop) + else: + return _array_slice(expr, op) @register_unary_op(ops.ArrayToStringOp, pass_op=True) @@ -120,14 +105,51 @@ def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression: return typed_expr.expr -def _string_index(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression: - sub_str = sge.Substring( - this=expr.expr, - start=sge.convert(op.index + 1), - length=sge.convert(1), +def _string_slice(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression: + # local name for each element in the array + el = sg.to_identifier("el") + # local name for the index in the array + slice_idx = sg.to_identifier("slice_idx") + + conditions: typing.List[sge.Predicate] = [slice_idx >= op.start] + if op.stop is not None: + conditions.append(slice_idx < op.stop) + + selected_elements = ( + sge.select(el) + .from_( + sge.Unnest( + expressions=[expr.expr], + alias=sge.TableAlias(columns=[el]), + offset=slice_idx, + ) + ) + .where(*conditions) ) - return sge.If( - this=sge.NEQ(this=sub_str, expression=sge.convert("")), - true=sub_str, - false=sge.Null(), + + return sge.array(selected_elements) + + +def _array_slice(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression: + # local name for each element in the array + el = sg.to_identifier("el") + # local name for the index in the array + slice_idx = sg.to_identifier("slice_idx") + + conditions: typing.List[sge.Predicate] = [slice_idx >= op.start] + if op.stop is not None: + conditions.append(slice_idx < op.stop) + + selected_elements = ( + sge.select(el) + .from_( + sge.Unnest( + expressions=[expr.expr], + alias=sge.TableAlias(columns=[el]), + offset=slice_idx, + ) + ) + .where(*conditions) ) + + return sge.array(selected_elements) diff --git a/bigframes/core/compile/sqlglot/expressions/string_ops.py b/bigframes/core/compile/sqlglot/expressions/string_ops.py index 3f0578f843..6af9b6a526 100644 --- a/bigframes/core/compile/sqlglot/expressions/string_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/string_ops.py @@ -153,12 +153,15 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.isdecimal_op) def _(expr: TypedExpr) -> sge.Expression: - return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^\d+$")) + return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^(\p{Nd})+$")) @register_unary_op(ops.isdigit_op) def _(expr: TypedExpr) -> sge.Expression: - return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^\p{Nd}+$")) + regexp_pattern = ( + r"^[\p{Nd}\x{00B9}\x{00B2}\x{00B3}\x{2070}\x{2074}-\x{2079}\x{2080}-\x{2089}]+$" + ) + return sge.RegexpLike(this=expr.expr, expression=sge.convert(regexp_pattern)) @register_unary_op(ops.islower_op) @@ -253,12 +256,60 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression: @register_unary_op(ops.StrGetOp, pass_op=True) def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression: + return string_index(expr, op.i) + + +@register_unary_op(ops.StrSliceOp, pass_op=True) +def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression: + return string_slice(expr, op.start, op.end) + + +@register_unary_op(ops.upper_op) +def _(expr: TypedExpr) -> sge.Expression: + return sge.Upper(this=expr.expr) + + +@register_binary_op(ops.strconcat_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + return sge.Concat(expressions=[left.expr, right.expr]) + + +@register_unary_op(ops.ZfillOp, pass_op=True) +def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression: + length_expr = sge.Greatest( + expressions=[sge.Length(this=expr.expr), sge.convert(op.width)] + ) + return sge.Case( + ifs=[ + sge.If( + this=sge.func( + "STARTS_WITH", + expr.expr, + sge.convert("-"), + ), + true=sge.Concat( + expressions=[ + sge.convert("-"), + sge.func( + "LPAD", + sge.Substring(this=expr.expr, start=sge.convert(2)), + length_expr - 1, + sge.convert("0"), + ), + ] + ), + ) + ], + default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")), + ) + + +def string_index(expr: TypedExpr, index: int) -> sge.Expression: sub_str = sge.Substring( this=expr.expr, - start=sge.convert(op.i + 1), + start=sge.convert(index + 1), length=sge.convert(1), ) - return sge.If( this=sge.NEQ(this=sub_str, expression=sge.convert("")), true=sub_str, @@ -266,19 +317,20 @@ def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression: ) -@register_unary_op(ops.StrSliceOp, pass_op=True) -def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression: +def string_slice( + expr: TypedExpr, op_start: typing.Optional[int], op_end: typing.Optional[int] +) -> sge.Expression: column_length = sge.Length(this=expr.expr) - if op.start is None: + if op_start is None: start = 0 else: - start = op.start + start = op_start start_expr = sge.convert(start) if start < 0 else sge.convert(start + 1) length_expr: typing.Optional[sge.Expression] - if op.end is None: + if op_end is None: length_expr = None - elif op.end < 0: + elif op_end < 0: if start < 0: start_expr = sge.Greatest( expressions=[ @@ -289,7 +341,7 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression: length_expr = sge.Greatest( expressions=[ sge.convert(0), - column_length + sge.convert(op.end), + column_length + sge.convert(op_end), ] ) - sge.Greatest( expressions=[ @@ -301,7 +353,7 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression: length_expr = sge.Greatest( expressions=[ sge.convert(0), - column_length + sge.convert(op.end - start), + column_length + sge.convert(op_end - start), ] ) else: # op.end >= 0 @@ -312,57 +364,17 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression: column_length + sge.convert(start + 1), ] ) - length_expr = sge.convert(op.end) - sge.Greatest( + length_expr = sge.convert(op_end) - sge.Greatest( expressions=[ sge.convert(0), column_length + sge.convert(start), ] ) else: - length_expr = sge.convert(op.end - start) + length_expr = sge.convert(op_end - start) return sge.Substring( this=expr.expr, start=start_expr, length=length_expr, ) - - -@register_unary_op(ops.upper_op) -def _(expr: TypedExpr) -> sge.Expression: - return sge.Upper(this=expr.expr) - - -@register_binary_op(ops.strconcat_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - return sge.Concat(expressions=[left.expr, right.expr]) - - -@register_unary_op(ops.ZfillOp, pass_op=True) -def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression: - length_expr = sge.Greatest( - expressions=[sge.Length(this=expr.expr), sge.convert(op.width)] - ) - return sge.Case( - ifs=[ - sge.If( - this=sge.func( - "STARTS_WITH", - expr.expr, - sge.convert("-"), - ), - true=sge.Concat( - expressions=[ - sge.convert("-"), - sge.func( - "LPAD", - sge.Substring(this=expr.expr, start=sge.convert(2)), - length_expr - 1, - sge.convert("0"), - ), - ] - ), - ) - ], - default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")), - ) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdecimal/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdecimal/out.sql index 7355ab7aa7..d4dddc348f 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdecimal/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdecimal/out.sql @@ -5,7 +5,7 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - REGEXP_CONTAINS(`string_col`, '^\\d+$') AS `bfcol_1` + REGEXP_CONTAINS(`string_col`, '^(\\p{Nd})+$') AS `bfcol_1` FROM `bfcte_0` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdigit/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdigit/out.sql index d7dd8c0729..eba0e51ed0 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdigit/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdigit/out.sql @@ -5,7 +5,10 @@ WITH `bfcte_0` AS ( ), `bfcte_1` AS ( SELECT *, - REGEXP_CONTAINS(`string_col`, '^\\p{Nd}+$') AS `bfcol_1` + REGEXP_CONTAINS( + `string_col`, + '^[\\p{Nd}\\x{00B9}\\x{00B2}\\x{00B3}\\x{2070}\\x{2074}-\\x{2079}\\x{2080}-\\x{2089}]+$' + ) AS `bfcol_1` FROM `bfcte_0` ) SELECT