Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion bigframes/core/compile/sqlglot/expressions/array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@

@register_unary_op(ops.ArrayIndexOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
    """Compile positional element access (``op.index``) on a column.

    String columns are delegated to ``_string_index``; every other input is
    compiled to a bracket access on the underlying expression.
    """
    if expr.dtype == dtypes.STRING_DTYPE:
        return _string_index(expr, op)

    # NOTE(review): safe=True / offset=False presumably select a zero-based,
    # NULL-on-out-of-range access in the generated SQL -- confirm against the
    # rendered output.
    return sge.Bracket(
        this=expr.expr,
        expressions=[sge.convert(op.index)],
        safe=True,
        offset=False,
    )
Expand Down Expand Up @@ -115,3 +118,16 @@ def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
if typed_expr.dtype == dtypes.BOOL_DTYPE:
return sge.Cast(this=typed_expr.expr, to="INT64")
return typed_expr.expr


def _string_index(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
    """Compile ``string[op.index]`` as a one-character substring.

    Returns an expression that yields the character at ``op.index`` or NULL
    when the substring is empty (the index falls outside the string).
    """
    # SUBSTRING positions are 1-based when counting from the front, whereas a
    # negative position already counts from the end (-1 is the last
    # character). Only non-negative indices are therefore shifted by one.
    position = op.index + 1 if op.index >= 0 else op.index
    sub_str = sge.Substring(
        this=expr.expr,
        start=sge.convert(position),
        length=sge.convert(1),
    )
    # An empty substring means "no such character"; surface it as NULL.
    return sge.If(
        this=sge.NEQ(this=sub_str, expression=sge.convert("")),
        true=sub_str,
        false=sge.Null(),
    )
150 changes: 101 additions & 49 deletions bigframes/core/compile/sqlglot/expressions/string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import functools
import typing

import sqlglot.expressions as sge

Expand All @@ -29,7 +30,7 @@

@register_unary_op(ops.capitalize_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile string capitalization to ``INITCAP`` with no delimiters."""
    # The empty delimiter string is passed explicitly so the call renders as
    # INITCAP(col, '') rather than relying on the default delimiter set.
    no_delimiters = sge.convert("")
    return sge.Initcap(this=expr.expr, expression=no_delimiters)


@register_unary_op(ops.StrContainsOp, pass_op=True)
Expand All @@ -44,9 +45,17 @@ def _(expr: TypedExpr, op: ops.StrContainsRegexOp) -> sge.Expression:

@register_unary_op(ops.StrExtractOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrExtractOp) -> sge.Expression:
    """Compile extraction of regex group ``op.n`` of ``op.pat`` from a string.

    Group 0 stands for the whole match. Rows in which the pattern does not
    occur evaluate to NULL.
    """
    # Cannot use BigQuery's REGEXP_EXTRACT function, which only allows one
    # capturing group.
    pat_expr = sge.convert(op.pat)
    if op.n != 0:
        # The pattern's own groups keep their numbering, so refer to group n.
        group = op.n
        full_pat = sge.func(
            "CONCAT", sge.convert(".*?"), pat_expr, sge.convert(".*")
        )
    else:
        # Wrap the entire pattern in a new outermost group, which is group 1.
        group = 1
        full_pat = sge.func(
            "CONCAT", sge.convert(".*?("), pat_expr, sge.convert(").*")
        )

    # Replacing the whole string by a back-reference leaves only that group.
    rex_replace = sge.func(
        "REGEXP_REPLACE", expr.expr, full_pat, sge.convert(f"\\{group}")
    )
    rex_contains = sge.func("REGEXP_CONTAINS", expr.expr, sge.convert(op.pat))
    return sge.If(this=rex_contains, true=rex_replace, false=sge.null())


@register_unary_op(ops.StrFindOp, pass_op=True)
Expand Down Expand Up @@ -75,47 +84,43 @@ def _(expr: TypedExpr, op: ops.StrFindOp) -> sge.Expression:

@register_unary_op(ops.StrLstripOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrLstripOp) -> sge.Expression:
    """Compile removal of leading ``op.to_strip`` characters via ``LTRIM``."""
    # LTRIM is called by name: sge.Trim(..., side="LEFT") was rendered as a
    # plain two-sided TRIM(...) call.
    chars = sge.convert(op.to_strip)
    return sge.func("LTRIM", expr.expr, chars)


@register_unary_op(ops.StrRstripOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
    """Compile removal of trailing ``op.to_strip`` characters via ``RTRIM``."""
    chars = sge.convert(op.to_strip)
    return sge.func("RTRIM", expr.expr, chars)


@register_unary_op(ops.StrPadOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrPadOp) -> sge.Expression:
    """Compile padding of a string to ``op.length`` using ``op.fillchar``.

    ``op.side`` selects the behavior: ``"left"`` emits ``LPAD``, ``"right"``
    emits ``RPAD``, and any other value pads both sides by nesting ``LPAD``
    inside ``RPAD``.
    """
    expr_length = sge.Length(this=expr.expr)
    fillchar = sge.convert(op.fillchar)
    # Never pad to less than the current length.
    pad_length = sge.func("GREATEST", expr_length, sge.convert(op.length))

    if op.side == "left":
        return sge.func("LPAD", expr.expr, pad_length, fillchar)
    if op.side == "right":
        return sge.func("RPAD", expr.expr, pad_length, fillchar)

    # side == both
    # Half of the missing width goes on the left. FLOOR is applied before the
    # cast so an odd remainder is rounded down instead of being left to the
    # cast's own rounding.
    half_gap = sge.Floor(
        this=sge.func(
            "SAFE_DIVIDE",
            sge.Sub(this=pad_length, expression=expr_length),
            sge.convert(2),
        )
    )
    lpad_amount = sge.Cast(this=half_gap, to="INT64") + expr_length
    left_padded = sge.func("LPAD", expr.expr, lpad_amount, fillchar)
    return sge.func("RPAD", left_padded, pad_length, fillchar)


Expand Down Expand Up @@ -224,11 +229,6 @@ def _(expr: TypedExpr) -> sge.Expression:
return sge.func("REVERSE", expr.expr)


@register_unary_op(ops.StrRstripOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrRstripOp) -> sge.Expression:
return sge.Trim(this=expr.expr, expression=sge.convert(op.to_strip), side="RIGHT")


@register_unary_op(ops.StartsWithOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StartsWithOp) -> sge.Expression:
if not op.pat:
Expand All @@ -253,26 +253,78 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression:

@register_unary_op(ops.StrGetOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression:
    """Compile extraction of the single character at position ``op.i``.

    The result is NULL when the one-character substring is empty, i.e. when
    the position does not address a character of the string.
    """
    # SUBSTRING is 1-based from the front but treats a negative position as
    # an offset from the end (-1 is the last character), so only non-negative
    # positions are shifted by one.
    position = op.i + 1 if op.i >= 0 else op.i
    sub_str = sge.Substring(
        this=expr.expr,
        start=sge.convert(position),
        length=sge.convert(1),
    )

    return sge.If(
        this=sge.NEQ(this=sub_str, expression=sge.convert("")),
        true=sub_str,
        false=sge.Null(),
    )


@register_unary_op(ops.StrSliceOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
    """Compile ``string[op.start:op.end]`` to a ``SUBSTRING`` expression.

    ``op.start`` and ``op.end`` may each be ``None``, non-negative, or
    negative (counted from the end of the string). A missing start means 0;
    a missing end means "to the end of the string" (no length argument).
    """

    def at_least(minimum: int, value: sge.Expression) -> sge.Expression:
        # GREATEST(minimum, value): clamp a computed position or length.
        return sge.Greatest(expressions=[sge.convert(minimum), value])

    column_length = sge.Length(this=expr.expr)
    start = 0 if op.start is None else op.start

    # SUBSTRING is 1-based from the front; a negative position is passed
    # through unchanged because it already counts from the end.
    start_expr: sge.Expression = (
        sge.convert(start) if start < 0 else sge.convert(start + 1)
    )
    length_expr: typing.Optional[sge.Expression]
    if op.end is None:
        length_expr = None
    elif start < 0:
        # Translate the negative start into an absolute position so that a
        # length can be derived from it.
        start_expr = at_least(1, column_length + sge.convert(start + 1))
        start_offset = at_least(0, column_length + sge.convert(start))
        end_offset: sge.Expression
        if op.end < 0:
            end_offset = at_least(0, column_length + sge.convert(op.end))
        else:
            end_offset = sge.convert(op.end)
        # The end may lie before the start; SUBSTRING must not be handed a
        # negative length, so an inverted range collapses to length 0.
        length_expr = at_least(0, end_offset - start_offset)
    elif op.end < 0:
        length_expr = at_least(
            0, column_length + sge.convert(op.end - start)
        )
    else:  # start >= 0 and op.end >= 0
        # Both bounds are known here; clamp an inverted range to length 0.
        length_expr = sge.convert(max(op.end - start, 0))

    return sge.Substring(
        this=expr.expr,
        start=start_expr,
        length=length_expr,
    )


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
), `bfcte_1` AS (
SELECT
*,
INITCAP(`string_col`) AS `bfcol_1`
INITCAP(`string_col`, '') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
), `bfcte_1` AS (
SELECT
*,
TRIM(`string_col`, ' ') AS `bfcol_1`
LTRIM(`string_col`, ' ') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
), `bfcte_1` AS (
SELECT
*,
TRIM(`string_col`, ' ') AS `bfcol_1`
RTRIM(`string_col`, ' ') AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ WITH `bfcte_0` AS (
), `bfcte_1` AS (
SELECT
*,
REGEXP_EXTRACT(`string_col`, '([a-z]*)') AS `bfcol_1`
IF(
REGEXP_CONTAINS(`string_col`, '([a-z]*)'),
REGEXP_REPLACE(`string_col`, CONCAT('.*?', '([a-z]*)', '.*'), '\\1'),
NULL
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
), `bfcte_1` AS (
SELECT
*,
SUBSTRING(`string_col`, 2, 1) AS `bfcol_1`
IF(SUBSTRING(`string_col`, 2, 1) <> '', SUBSTRING(`string_col`, 2, 1), NULL) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ WITH `bfcte_0` AS (
RPAD(
LPAD(
`string_col`,
CAST(SAFE_DIVIDE(GREATEST(LENGTH(`string_col`), 10) - LENGTH(`string_col`), 2) AS INT64) + LENGTH(`string_col`),
CAST(FLOOR(SAFE_DIVIDE(GREATEST(LENGTH(`string_col`), 10) - LENGTH(`string_col`), 2)) AS INT64) + LENGTH(`string_col`),
'-'
),
GREATEST(LENGTH(`string_col`), 10),
Expand Down