Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 73 additions & 20 deletions diffly/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,22 @@ def _compare_columns(
elif isinstance(dtype_left, pl.List | pl.Array) and isinstance(
dtype_right, pl.List | pl.Array
):
return _compare_sequence_columns(
col_left=col_left,
col_right=col_right,
dtype_left=dtype_left,
dtype_right=dtype_right,
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
)

if (
isinstance(dtype_left, pl.Enum)
and isinstance(dtype_right, pl.Enum)
and dtype_left != dtype_right
) or _enum_and_categorical(dtype_left, dtype_right):
if _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner):
return _compare_sequence_columns(
col_left=col_left,
col_right=col_right,
dtype_left=dtype_left,
dtype_right=dtype_right,
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
)
return col_left.eq_missing(col_right)

if _different_enums(dtype_left, dtype_right) or _enum_and_categorical(
dtype_left, dtype_right
):
# Enums with different categories as well as enums and categoricals
# can't be compared directly.
# Fall back to comparison of strings.
Expand Down Expand Up @@ -237,6 +237,55 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
return _eq_missing(has_same_length & elements_match, col_left, col_right)


def _is_float_numeric_pair(
dtype_left: DataType | DataTypeClass,
dtype_right: DataType | DataTypeClass,
) -> bool:
return (dtype_left.is_float() or dtype_right.is_float()) and (
dtype_left.is_numeric() and dtype_right.is_numeric()
)


def _is_temporal_pair(
dtype_left: DataType | DataTypeClass,
dtype_right: DataType | DataTypeClass,
) -> bool:
return dtype_left.is_temporal() and dtype_right.is_temporal()


def _needs_element_wise_comparison(
dtype_left: DataType | DataTypeClass,
dtype_right: DataType | DataTypeClass,
) -> bool:
"""Check if two dtypes require element-wise comparison (tolerances or special
handling).

Returns False when eq_missing() on the whole column would produce identical results,
allowing us to skip the expensive element-wise iteration for list/array columns.
"""
if _is_float_numeric_pair(dtype_left, dtype_right):
return True
if _is_temporal_pair(dtype_left, dtype_right):
return True
if _different_enums(dtype_left, dtype_right) or _enum_and_categorical(
dtype_left, dtype_right
):
return True
if isinstance(dtype_left, pl.Struct) and isinstance(dtype_right, pl.Struct):
fields_left = {f.name: f.dtype for f in dtype_left.fields}
fields_right = {f.name: f.dtype for f in dtype_right.fields}
return any(
_needs_element_wise_comparison(fields_left[name], fields_right[name])
for name in fields_left
if name in fields_right
)
if isinstance(dtype_left, pl.List | pl.Array) and isinstance(
dtype_right, pl.List | pl.Array
):
return _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner)
return False


def _compare_primitive_columns(
col_left: pl.Expr,
col_right: pl.Expr,
Expand All @@ -246,13 +295,11 @@ def _compare_primitive_columns(
rel_tol: float,
abs_tol_temporal: dt.timedelta,
) -> pl.Expr:
if (dtype_left.is_float() or dtype_right.is_float()) and (
dtype_left.is_numeric() and dtype_right.is_numeric()
):
if _is_float_numeric_pair(dtype_left, dtype_right):
return col_left.is_close(col_right, abs_tol=abs_tol, rel_tol=rel_tol).pipe(
_eq_missing_with_nan, lhs=col_left, rhs=col_right
)
elif dtype_left.is_temporal() and dtype_right.is_temporal():
elif _is_temporal_pair(dtype_left, dtype_right):
diff_less_than_tolerance = (col_left - col_right).abs() <= abs_tol_temporal
return diff_less_than_tolerance.pipe(_eq_missing, lhs=col_left, rhs=col_right)

Expand All @@ -270,6 +317,12 @@ def _eq_missing_with_nan(expr: pl.Expr, lhs: pl.Expr, rhs: pl.Expr) -> pl.Expr:
return _eq_missing(expr, lhs, rhs) | both_nan


def _different_enums(
left: DataType | DataTypeClass, right: DataType | DataTypeClass
) -> bool:
return isinstance(left, pl.Enum) and isinstance(right, pl.Enum) and left != right


def _enum_and_categorical(
left: DataType | DataTypeClass, right: DataType | DataTypeClass
) -> bool:
Expand Down
115 changes: 114 additions & 1 deletion tests/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import polars as pl
import pytest

from diffly._conditions import _can_compare_dtypes, condition_equal_columns
from diffly._conditions import (
_can_compare_dtypes,
_needs_element_wise_comparison,
condition_equal_columns,
)
from diffly.comparison import compare_frames


Expand Down Expand Up @@ -512,6 +516,45 @@ def test_condition_equal_columns_lists_only_inner() -> None:
assert actual.to_list() == [True, False]


def test_condition_equal_columns_list_of_different_enums() -> None:
# Arrange
first_enum = pl.Enum(["one", "two"])
second_enum = pl.Enum(["one", "two", "three"])

lhs = pl.DataFrame(
{"pk": [1, 2], "a": [["one", "two"], ["one", "one"]]},
schema_overrides={"a": pl.List(first_enum)},
)
rhs = pl.DataFrame(
{"pk": [1, 2], "a": [["one", "two"], ["one", "three"]]},
schema_overrides={"a": pl.List(second_enum)},
)
c = compare_frames(lhs, rhs, primary_key="pk")

# Act
lhs = lhs.rename({"a": "a_left"})
rhs = rhs.rename({"a": "a_right"})
actual = (
lhs.join(rhs, on="pk", maintain_order="left")
.select(
condition_equal_columns(
"a",
dtype_left=lhs.schema["a_left"],
dtype_right=rhs.schema["a_right"],
max_list_length=c._max_list_lengths_by_column.get("a"),
abs_tol=c.abs_tol_by_column["a"],
rel_tol=c.rel_tol_by_column["a"],
)
)
.to_series()
)

# Assert
assert c._max_list_lengths_by_column == {"a": 2}
assert _needs_element_wise_comparison(first_enum, second_enum)
assert actual.to_list() == [True, False]


@pytest.mark.parametrize(
("dtype_left", "dtype_right", "can_compare_dtypes"),
[
Expand All @@ -534,3 +577,73 @@ def test_can_compare_dtypes(
dtype_left=dtype_left, dtype_right=dtype_right
)
assert can_compare_dtypes_actual == can_compare_dtypes


@pytest.mark.parametrize(
("dtype_left", "dtype_right", "expected"),
[
# Primitives that don't need element-wise comparison
(pl.Int64, pl.Int64, False),
(pl.String, pl.String, False),
(pl.Boolean, pl.Boolean, False),
# Float/numeric pairs
(pl.Float64, pl.Float64, True),
(pl.Int64, pl.Float64, True),
(pl.Float32, pl.Int32, True),
# Temporal pairs
(pl.Datetime, pl.Datetime, True),
(pl.Date, pl.Date, True),
(pl.Datetime, pl.Date, True),
# Enum/categorical
(pl.Enum(["a", "b"]), pl.Enum(["a", "b"]), False),
(pl.Enum(["a", "b"]), pl.Enum(["a", "b", "c"]), True),
(pl.Enum(["a"]), pl.Categorical(), True),
(pl.Categorical(), pl.Enum(["a"]), True),
# Struct with no tolerance-requiring fields
(
pl.Struct({"x": pl.Int64, "y": pl.String}),
pl.Struct({"x": pl.Int64, "y": pl.String}),
False,
),
# Struct with a float field
(
pl.Struct({"x": pl.Int64, "y": pl.Float64}),
pl.Struct({"x": pl.Int64, "y": pl.Float64}),
True,
),
# Struct with different-category enums
(
pl.Struct({"x": pl.Enum(["a"])}),
pl.Struct({"x": pl.Enum(["b"])}),
True,
),
# List/Array with non-tolerance inner type
(pl.List(pl.Int64), pl.List(pl.Int64), False),
(pl.Array(pl.String, shape=3), pl.Array(pl.String, shape=3), False),
# List/Array with tolerance-requiring inner type
(pl.List(pl.Float64), pl.List(pl.Float64), True),
(pl.Array(pl.Datetime, shape=2), pl.Array(pl.Datetime, shape=2), True),
# Nested: list of structs with a float field
(
pl.List(pl.Struct({"x": pl.Float64})),
pl.List(pl.Struct({"x": pl.Float64})),
True,
),
# Nested: list of structs without tolerance-requiring fields
(
pl.List(pl.Struct({"x": pl.Int64})),
pl.List(pl.Struct({"x": pl.Int64})),
False,
),
# Deeply nested: struct with a list of structs with a float field
(
pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})),
pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})),
True,
),
],
)
def test_needs_element_wise_comparison(
dtype_left: pl.DataType, dtype_right: pl.DataType, expected: bool
) -> None:
assert _needs_element_wise_comparison(dtype_left, dtype_right) == expected
14 changes: 7 additions & 7 deletions tests/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def expensive_computation(col: pl.Expr) -> pl.Expr:
)


def test_element_wise_comparison_slower_than_eq_missing_for_list_columns() -> None:
"""Confirm that comparing list columns with non-tolerance inner types via
eq_missing() is significantly faster than the element-wise
_compare_sequence_columns() path."""
def test_eq_missing_not_slower_than_element_wise_for_list_columns() -> None:
"""Ensure that comparing list columns with non-tolerance inner types via
eq_missing() is not slower than the element-wise _compare_sequence_columns()
path."""
n_rows = 500_000
list_len = 20
num_runs_measured = 10
Expand Down Expand Up @@ -126,10 +126,10 @@ def test_element_wise_comparison_slower_than_eq_missing_for_list_columns() -> No
mean_time_cond = statistics.mean(times_cond[num_runs_warmup:])

ratio = mean_time_cond / mean_time_eq
assert ratio > 2.0, (
f"Element-wise comparison was only {ratio:.1f}x slower than eq_missing "
assert ratio < 1.25, (
f"condition_equal_columns was {ratio:.1f}x slower than eq_missing "
f"({mean_time_cond:.3f}s vs {mean_time_eq:.3f}s). "
f"Expected at least 2x slowdown to justify the optimization."
f"Expected comparable performance since list<i64> should use eq_missing directly."
)


Expand Down
Loading