Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 48 additions & 25 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,21 +726,6 @@ def verify_element(result):
return lambda k, v, s: [(wrapped(k, v, s), return_type)]


def wrap_grouped_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
    """Wrap a grouped-aggregate pandas UDF for the worker's evaluation loop.

    The user function produces one scalar per group; the wrapper boxes that
    scalar into a one-element ``pd.Series`` and pairs it with the declared
    ``return_type``.  Returns ``(offsets, mapper)`` where ``offsets`` is the
    combined positional/keyword offset list produced by ``wrap_kwargs_support``.
    """
    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

    def as_single_row(*cols):
        # Deferred import: pandas is only needed once a batch is evaluated.
        import pandas as pd

        return pd.Series([func(*cols)])

    def mapper(*a):
        return as_single_row(*a), return_type

    return args_kwargs_offsets, mapper


def wrap_grouped_agg_pandas_iter_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

Expand Down Expand Up @@ -1031,11 +1016,8 @@ def read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost when wrapping it
return func, args_offsets, return_type, len(argspec.args)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return wrap_grouped_agg_pandas_udf(
func, args_offsets, kwargs_offsets, return_type, runner_conf
)
elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
):
Expand Down Expand Up @@ -2239,14 +2221,14 @@ def read_udfs(pickleSer, udf_info_list, eval_type, runner_conf, eval_conf):
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
):
ser = ArrowStreamGroupSerializer(write_start_stream=True)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
ser = ArrowStreamGroupSerializer(write_start_stream=True)
elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
):
Expand Down Expand Up @@ -2572,6 +2554,50 @@ def grouped_func(
# profiling is not supported for UDF
return grouped_func, None, ser, ser

if eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
    import pyarrow as pa
    import pandas as pd

    # One output column per UDF; the names are synthetic ("_0", "_1", ...)
    # because only positional order matters downstream.
    col_names = ["_%d" % i for i in range(len(udfs))]
    # Each udf tuple unpacks as (func, args_offsets, kwargs_offsets, return_type);
    # the return type (rt) supplies the field type of the output schema.
    output_schema = StructType(
        [StructField(name, rt) for name, (_, _, _, rt) in zip(col_names, udfs)]
    )

    def grouped_func(
        split_index: int, data: Iterator["GroupedBatch"]
    ) -> Iterator[pa.RecordBatch]:
        """Evaluate all grouped-agg pandas UDFs once per group.

        For each group: concatenate its Arrow batches into one table, convert
        the columns to pandas, call every UDF on its offset-selected columns,
        and emit the scalar results as a single one-row Arrow batch.
        """
        for group in data:
            batch_list = list(group)
            if not batch_list:
                # Skip groups that produced no Arrow batches.
                continue
            # combine_chunks() flattens the per-batch chunking so each column
            # converts to a single contiguous pandas Series.
            table = pa.Table.from_batches(batch_list).combine_chunks()
            all_series = ArrowBatchTransformer.to_pandas(
                table,
                timezone=runner_conf.timezone,
                prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,
            )
            # Positional args come from args_offsets, keyword args from
            # kwargs_offsets (name -> column offset); each UDF yields one scalar.
            results = [
                udf_func(
                    *[all_series[o] for o in args_offsets],
                    **{k: all_series[v] for k, v in kwargs_offsets.items()},
                )
                for udf_func, args_offsets, kwargs_offsets, _ in udfs
            ]
            # Box each scalar into a one-row Series so the aggregate row can be
            # serialized as a regular Arrow record batch.
            result_series = [pd.Series([r]) for r in results]
            yield PandasToArrowConversion.convert(
                result_series,
                output_schema,
                timezone=runner_conf.timezone,
                safecheck=runner_conf.safecheck,
                arrow_cast=True,
                prefers_large_types=runner_conf.use_large_var_types,
                assign_cols_by_name=runner_conf.assign_cols_by_name,
                int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
            )

    # profiling is not supported for UDF
    return grouped_func, None, ser, ser

if eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
import pyarrow as pa

Expand Down Expand Up @@ -3523,13 +3549,10 @@ def mapper(batch_iter):
)
return f(series_iter)

elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
):
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
import pandas as pd

# For SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF,
# For SQL_WINDOW_AGG_PANDAS_UDF,
# convert iterator of batch tuples to concatenated pandas Series
def mapper(batch_iter):
# batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch
Expand Down