Skip to content

Commit 097fcfa

Browse files
authored
refactor: fix some string ops in the sqlglot compiler (part 3) (#2336)
This change aims to fix some string-related tests failing in #2248. Fixes internal issue 417774347 🦕
1 parent 0b14b17 commit 097fcfa

File tree

4 files changed

+126
-89
lines changed

4 files changed

+126
-89
lines changed

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
import sqlglot.expressions as sge
2121

2222
from bigframes import operations as ops
23+
from bigframes.core.compile.sqlglot.expressions.string_ops import (
24+
string_index,
25+
string_slice,
26+
)
2327
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2428
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2529
import bigframes.dtypes as dtypes
@@ -31,7 +35,7 @@
3135
@register_unary_op(ops.ArrayIndexOp, pass_op=True)
3236
def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
3337
if expr.dtype == dtypes.STRING_DTYPE:
34-
return _string_index(expr, op)
38+
return string_index(expr, op.index)
3539

3640
return sge.Bracket(
3741
this=expr.expr,
@@ -71,29 +75,10 @@ def _(expr: TypedExpr, op: ops.ArrayReduceOp) -> sge.Expression:
7175

7276
@register_unary_op(ops.ArraySliceOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
    """Compile ``ArraySliceOp`` for either a string or an array operand.

    String-typed operands are handed to ``string_slice`` together with
    ``op.start`` and ``op.stop``; every other operand is sliced by
    ``_array_slice``.
    """
    if expr.dtype == dtypes.STRING_DTYPE:
        return string_slice(expr, op.start, op.stop)
    return _array_slice(expr, op)
9782

9883

9984
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
@@ -120,14 +105,51 @@ def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
120105
return typed_expr.expr
121106

122107

123-
def _string_index(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
124-
sub_str = sge.Substring(
125-
this=expr.expr,
126-
start=sge.convert(op.index + 1),
127-
length=sge.convert(1),
108+
def _string_slice(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
    """Slice ``expr`` using the same UNNEST-based query as ``_array_slice``.

    NOTE(review): despite the name, this helper never did string slicing; its
    body was a line-for-line copy of ``_array_slice``. String operands are
    routed to ``string_ops.string_slice`` by the ``ArraySliceOp`` handler, and
    nothing in the visible part of this module calls this function. It is kept
    as a thin alias so existing behavior is unchanged; consider deleting it
    once it is confirmed to be unused.

    Args:
        expr: The typed expression to slice.
        op: The slice operation; its ``start`` and ``stop`` bounds are
            consumed by ``_array_slice``.

    Returns:
        The expression built by ``_array_slice(expr, op)``.
    """
    return _array_slice(expr, op)
131+
132+
133+
def _array_slice(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
    """Build an array expression holding the elements selected by ``op``.

    The array in ``expr`` is unnested with an offset column, rows whose offset
    is at least ``op.start`` (and, when ``op.stop`` is given, below
    ``op.stop``) are kept, and the surviving elements are wrapped back into an
    array.

    NOTE(review): the subquery carries no explicit ordering on the offset;
    confirm that element order is preserved as intended.
    """
    # Aliases used inside the subquery: one for the element, one for its offset.
    element = sg.to_identifier("el")
    offset = sg.to_identifier("slice_idx")

    lower_bound = offset >= op.start
    bounds: typing.List[sge.Predicate] = (
        [lower_bound] if op.stop is None else [lower_bound, offset < op.stop]
    )

    unnested = sge.Unnest(
        expressions=[expr.expr],
        alias=sge.TableAlias(columns=[element]),
        offset=offset,
    )
    subquery = sge.select(element).from_(unnested).where(*bounds)

    return sge.array(subquery)

bigframes/core/compile/sqlglot/expressions/string_ops.py

Lines changed: 66 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,15 @@ def _(expr: TypedExpr) -> sge.Expression:
153153

154154
@register_unary_op(ops.isdecimal_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``isdecimal`` as a regular-expression match.

    The pattern is anchored at both ends and requires one or more characters
    from the Unicode ``Nd`` (decimal digit) category.
    """
    decimal_pattern = sge.convert(r"^(\p{Nd})+$")
    return sge.RegexpLike(this=expr.expr, expression=decimal_pattern)
157157

158158

159159
@register_unary_op(ops.isdigit_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``isdigit`` as a regular-expression match.

    The pattern is anchored at both ends and requires one or more characters
    from a class that is wider than the one used for ``isdecimal``.
    """
    # Character class: Unicode decimal digits (Nd), the superscript digits
    # U+00B9, U+00B2, U+00B3, U+2070 and U+2074-U+2079, and the subscript
    # digits U+2080-U+2089.
    digit_pattern = (
        r"^[\p{Nd}\x{00B9}\x{00B2}\x{00B3}\x{2070}\x{2074}-\x{2079}\x{2080}-\x{2089}]+$"
    )
    pattern_literal = sge.convert(digit_pattern)
    return sge.RegexpLike(this=expr.expr, expression=pattern_literal)
162165

163166

164167
@register_unary_op(ops.islower_op)
@@ -253,32 +256,81 @@ def _(expr: TypedExpr, op: ops.StringSplitOp) -> sge.Expression:
253256

254257
@register_unary_op(ops.StrGetOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrGetOp) -> sge.Expression:
    """Compile ``StrGetOp`` by delegating to ``string_index`` with ``op.i``."""
    position = op.i
    return string_index(expr, position)
260+
261+
262+
@register_unary_op(ops.StrSliceOp, pass_op=True)
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
    """Compile ``StrSliceOp`` by delegating to ``string_slice``.

    The slice bounds are taken from ``op.start`` and ``op.end``.
    """
    begin, end = op.start, op.end
    return string_slice(expr, begin, end)
265+
266+
267+
@register_unary_op(ops.upper_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``upper_op`` by wrapping the operand in an ``Upper`` node."""
    operand = expr.expr
    return sge.Upper(this=operand)
270+
271+
272+
@register_binary_op(ops.strconcat_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``strconcat_op`` as a ``Concat`` of the left and right operands."""
    operands = [left.expr, right.expr]
    return sge.Concat(expressions=operands)
275+
276+
277+
@register_unary_op(ops.ZfillOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression:
    """Compile ``ZfillOp`` as a sign-aware left pad with ``'0'``.

    The target width is the greater of the value's length and ``op.width``.
    When the value starts with ``'-'``, the sign is emitted first and the
    remainder of the value is padded to one character less than the target
    width; otherwise the whole value is padded to the target width.
    """
    target_width = sge.Greatest(
        expressions=[sge.Length(this=expr.expr), sge.convert(op.width)]
    )

    # Negative branch: keep the leading "-" and pad what follows it.
    starts_with_minus = sge.func("STARTS_WITH", expr.expr, sge.convert("-"))
    after_sign = sge.Substring(this=expr.expr, start=sge.convert(2))
    padded_after_sign = sge.func(
        "LPAD", after_sign, target_width - 1, sge.convert("0")
    )
    negative_case = sge.If(
        this=starts_with_minus,
        true=sge.Concat(expressions=[sge.convert("-"), padded_after_sign]),
    )

    # Default branch: pad the whole value.
    padded_whole = sge.func("LPAD", expr.expr, target_width, sge.convert("0"))

    return sge.Case(ifs=[negative_case], default=padded_whole)
305+
306+
307+
def string_index(expr: TypedExpr, index: int) -> sge.Expression:
    """Build an expression for the single character of ``expr`` at ``index``.

    Args:
        expr: The string-typed expression to index into.
        index: Zero-based position of the character. A negative value counts
            back from the end of the string.

    Returns:
        An ``IF`` expression that yields the one-character substring, or NULL
        when that substring is empty (for example, a position past the end).
    """
    # SUBSTRING positions are 1-based, so non-negative indexes are shifted by
    # one. Negative positions already count from the end of the string and
    # must not be shifted: shifting turned -1 into position 0, which selects
    # the first character rather than the last. This mirrors how
    # ``string_slice`` computes its start position.
    # NOTE(review): a negative index whose magnitude exceeds the string length
    # is presumably clamped to position 1 by the engine (BigQuery SUBSTR
    # documents this), so it still yields the first character rather than
    # NULL - confirm whether that case needs an explicit length guard.
    position = index if index < 0 else index + 1

    sub_str = sge.Substring(
        this=expr.expr,
        start=sge.convert(position),
        length=sge.convert(1),
    )
    return sge.If(
        this=sge.NEQ(this=sub_str, expression=sge.convert("")),
        true=sub_str,
        false=sge.Null(),
    )
267318

268319

269-
@register_unary_op(ops.StrSliceOp, pass_op=True)
270-
def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
320+
def string_slice(
321+
expr: TypedExpr, op_start: typing.Optional[int], op_end: typing.Optional[int]
322+
) -> sge.Expression:
271323
column_length = sge.Length(this=expr.expr)
272-
if op.start is None:
324+
if op_start is None:
273325
start = 0
274326
else:
275-
start = op.start
327+
start = op_start
276328

277329
start_expr = sge.convert(start) if start < 0 else sge.convert(start + 1)
278330
length_expr: typing.Optional[sge.Expression]
279-
if op.end is None:
331+
if op_end is None:
280332
length_expr = None
281-
elif op.end < 0:
333+
elif op_end < 0:
282334
if start < 0:
283335
start_expr = sge.Greatest(
284336
expressions=[
@@ -289,7 +341,7 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
289341
length_expr = sge.Greatest(
290342
expressions=[
291343
sge.convert(0),
292-
column_length + sge.convert(op.end),
344+
column_length + sge.convert(op_end),
293345
]
294346
) - sge.Greatest(
295347
expressions=[
@@ -301,7 +353,7 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
301353
length_expr = sge.Greatest(
302354
expressions=[
303355
sge.convert(0),
304-
column_length + sge.convert(op.end - start),
356+
column_length + sge.convert(op_end - start),
305357
]
306358
)
307359
else: # op_end >= 0
@@ -312,57 +364,17 @@ def _(expr: TypedExpr, op: ops.StrSliceOp) -> sge.Expression:
312364
column_length + sge.convert(start + 1),
313365
]
314366
)
315-
length_expr = sge.convert(op.end) - sge.Greatest(
367+
length_expr = sge.convert(op_end) - sge.Greatest(
316368
expressions=[
317369
sge.convert(0),
318370
column_length + sge.convert(start),
319371
]
320372
)
321373
else:
322-
length_expr = sge.convert(op.end - start)
374+
length_expr = sge.convert(op_end - start)
323375

324376
return sge.Substring(
325377
this=expr.expr,
326378
start=start_expr,
327379
length=length_expr,
328380
)
329-
330-
331-
@register_unary_op(ops.upper_op)
332-
def _(expr: TypedExpr) -> sge.Expression:
333-
return sge.Upper(this=expr.expr)
334-
335-
336-
@register_binary_op(ops.strconcat_op)
337-
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
338-
return sge.Concat(expressions=[left.expr, right.expr])
339-
340-
341-
@register_unary_op(ops.ZfillOp, pass_op=True)
342-
def _(expr: TypedExpr, op: ops.ZfillOp) -> sge.Expression:
343-
length_expr = sge.Greatest(
344-
expressions=[sge.Length(this=expr.expr), sge.convert(op.width)]
345-
)
346-
return sge.Case(
347-
ifs=[
348-
sge.If(
349-
this=sge.func(
350-
"STARTS_WITH",
351-
expr.expr,
352-
sge.convert("-"),
353-
),
354-
true=sge.Concat(
355-
expressions=[
356-
sge.convert("-"),
357-
sge.func(
358-
"LPAD",
359-
sge.Substring(this=expr.expr, start=sge.convert(2)),
360-
length_expr - 1,
361-
sge.convert("0"),
362-
),
363-
]
364-
),
365-
)
366-
],
367-
default=sge.func("LPAD", expr.expr, length_expr, sge.convert("0")),
368-
)

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdecimal/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
REGEXP_CONTAINS(`string_col`, '^\\d+$') AS `bfcol_1`
8+
REGEXP_CONTAINS(`string_col`, '^(\\p{Nd})+$') AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_isdigit/out.sql

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
REGEXP_CONTAINS(`string_col`, '^\\p{Nd}+$') AS `bfcol_1`
8+
REGEXP_CONTAINS(
9+
`string_col`,
10+
'^[\\p{Nd}\\x{00B9}\\x{00B2}\\x{00B3}\\x{2070}\\x{2074}-\\x{2079}\\x{2080}-\\x{2089}]+$'
11+
) AS `bfcol_1`
912
FROM `bfcte_0`
1013
)
1114
SELECT

0 commit comments

Comments
 (0)