Skip to content

Commit

Permalink
Add ignore_columns option to dataframe_comparer.py (#141)
Browse files Browse the repository at this point in the history
* Add ignore_columns option to dataframe_comparer.py

Add ignore_columns option to:
assert_df_equality()
assert_approx_df_equality()

* Update README.md

* Ruff pre-commit updates

* Update __init__.py
  • Loading branch information
kasztp authored Oct 24, 2024
1 parent 50c2411 commit 202ab2a
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 0 deletions.
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,46 @@ Here's the error message you'll see if you run `assert_df_equality(df1, df2)`, w

![ignore_column_order_false](https://raw.githubusercontent.com/MrPowers/chispa/main/images/ignore_column_order_false.png)

### Ignore specific columns

This section explains how to compare DataFrames, ignoring specific columns.

Suppose you have the following `df1`:

```
+------------+-------------+
| name | clean_name |
+------------+-------------+
| "matt7" | "matt7" |
| "bill&" | "bill" |
| "isabela*" | "isabela" |
| "None" | "None" |
+------------+-------------+
```

Here are the contents of `df2`:

```
+------------+-------------+
| name | clean_name |
+------------+-------------+
| "matt7" | "matt" |
| "bill&" | "bill" |
| "isabela*" | "isabela" |
| "None" | "None" |
+------------+-------------+
```

Here's how to compare the equality of `df1` and `df2`, ignoring the column `clean_name`:

```python
assert_df_equality(df1, df2, ignore_columns=["clean_name"])
```

Here's the error message you'll see if you run `assert_df_equality(df1, df2)`, without ignoring the column `clean_name`.

![ignore_columns_none](https://raw.githubusercontent.com/MrPowers/chispa/main/images/dfs_not_equal_error.png)

### Ignore nullability

Each column in a schema has three properties: a name, data type, and nullable property. The column can accept null values if `nullable` is set to true.
Expand Down
2 changes: 2 additions & 0 deletions chispa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def assert_df_equality(
ignore_row_order: bool = False,
underline_cells: bool = False,
ignore_metadata: bool = False,
ignore_columns: list[str] | None = None,
) -> None:
return assert_df_equality(
df1,
Expand All @@ -79,6 +80,7 @@ def assert_df_equality(
ignore_row_order,
underline_cells,
ignore_metadata,
ignore_columns,
self.formats,
)

Expand Down
12 changes: 12 additions & 0 deletions chispa/dataframe_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def assert_df_equality(
ignore_row_order: bool = False,
underline_cells: bool = False,
ignore_metadata: bool = False,
ignore_columns: list[str] | None = None,
formats: FormattingConfig | None = None,
) -> None:
if not formats:
Expand All @@ -43,9 +44,14 @@ def assert_df_equality(
transforms.append(lambda df: df.select(sorted(df.columns)))
if ignore_row_order:
transforms.append(lambda df: df.sort(df.columns))
if ignore_columns:
transforms.append(lambda df: df.drop(*ignore_columns))

df1 = reduce(lambda acc, fn: fn(acc), transforms, df1)
df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)

assert_schema_equality(df1.schema, df2.schema, ignore_nullable, ignore_metadata)

if allow_nan_equality:
assert_generic_rows_equality(
df1.collect(),
Expand Down Expand Up @@ -81,6 +87,7 @@ def assert_approx_df_equality(
allow_nan_equality: bool = False,
ignore_column_order: bool = False,
ignore_row_order: bool = False,
ignore_columns: list[str] | None = None,
formats: FormattingConfig | None = None,
) -> None:
if not formats:
Expand All @@ -94,9 +101,14 @@ def assert_approx_df_equality(
transforms.append(lambda df: df.select(sorted(df.columns)))
if ignore_row_order:
transforms.append(lambda df: df.sort(df.columns))
if ignore_columns:
transforms.append(lambda df: df.drop(*ignore_columns))

df1 = reduce(lambda acc, fn: fn(acc), transforms, df1)
df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)

assert_schema_equality(df1.schema, df2.schema, ignore_nullable)

if precision != 0:
assert_generic_rows_equality(
df1.collect(),
Expand Down
30 changes: 30 additions & 0 deletions tests/test_dataframe_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,21 @@ def it_catches_mismatched_metadata():
with pytest.raises(SchemasNotEqualError):
assert_df_equality(df1, df2)

def it_can_ignore_columns():
data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")]
df1 = spark.createDataFrame(data1, ["name", "expected_name"])
data2 = [("bob", "jose"), ("li", "boo"), ("luisa", "boo")]
df2 = spark.createDataFrame(data2, ["name", "expected_name"])
assert_df_equality(df1, df2, ignore_columns=["expected_name"])

def it_throws_when_dfs_are_not_same_with_ignored_columns():
data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")]
df1 = spark.createDataFrame(data1, ["name", "expected_name"])
data2 = [("bob", "jose"), ("li", "boo"), ("luisa", "boo")]
df2 = spark.createDataFrame(data2, ["name", "expected_name"])
with pytest.raises(DataFramesNotEqualError):
assert assert_df_equality(df1, df2, ignore_columns=["name"])


def describe_are_dfs_equal():
def it_returns_false_with_schema_mismatches():
Expand Down Expand Up @@ -206,3 +221,18 @@ def it_does_not_throw_with_nan_values():
]
df2 = spark.createDataFrame(data2, ["num", "expected_name"])
assert_approx_df_equality(df1, df2, 0.1, allow_nan_equality=True)

def it_can_ignore_columns():
data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")]
df1 = spark.createDataFrame(data1, ["name", "expected_name"])
data2 = [("bob", "jose"), ("li", "boo"), ("luisa", "boo")]
df2 = spark.createDataFrame(data2, ["name", "expected_name"])
assert_approx_df_equality(df1, df2, 0.1, ignore_columns=["expected_name"])

def it_throws_when_dfs_are_not_same_with_ignored_columns():
data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")]
df1 = spark.createDataFrame(data1, ["name", "expected_name"])
data2 = [("bob", "jose"), ("li", "boo"), ("luisa", "boo")]
df2 = spark.createDataFrame(data2, ["name", "expected_name"])
with pytest.raises(DataFramesNotEqualError):
assert assert_approx_df_equality(df1, df2, 0.1, ignore_columns=["name"])

0 comments on commit 202ab2a

Please sign in to comment.