Skip to content

Commit

Permalink
chore: add static typing with mypy (#130)
Browse files Browse the repository at this point in the history
* remove bcolors

remove bcolors

fix

add mypy

add typing

fix mypy

fix mypy

* update with master

* fix

* fix mypy
  • Loading branch information
fpgmaas authored Sep 9, 2024
1 parent 0e513ed commit cccff07
Show file tree
Hide file tree
Showing 16 changed files with 195 additions and 97 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ jobs:
- name: Check lock file
run: poetry lock --check

- name: Run pre-commit hooks
run: poetry run pre-commit run -a
- name: Run code quality checks
run: poetry run make check

test:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ install: ## Install the Poetry environment
check: ## Run code quality checks
@echo "Running pre-commit hooks"
@poetry run pre-commit run -a
@poetry run mypy chispa

.PHONY: test
test: ## Run unit tests
Expand Down
29 changes: 15 additions & 14 deletions chispa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import os
import sys
from glob import glob
from typing import Callable

from pyspark.sql import DataFrame

# Add PySpark to the library path based on the value of SPARK_HOME if pyspark is not already in our path
try:
from pyspark import context # noqa: F401
except ImportError:
# We need to add PySpark: try to use findspark, or fall back to finding it "manually"
try:
import findspark
import findspark # type: ignore[import-untyped]

findspark.init()
except ImportError:
Expand Down Expand Up @@ -46,28 +49,26 @@


class Chispa:
def __init__(self, formats: FormattingConfig | None = None, default_output=None):
def __init__(self, formats: FormattingConfig | None = None) -> None:
if not formats:
self.formats = FormattingConfig()
elif isinstance(formats, FormattingConfig):
self.formats = formats
else:
self.formats = FormattingConfig._from_arbitrary_dataclass(formats)

self.default_outputs = default_output

def assert_df_equality(
self,
df1,
df2,
ignore_nullable=False,
transforms=None,
allow_nan_equality=False,
ignore_column_order=False,
ignore_row_order=False,
underline_cells=False,
ignore_metadata=False,
):
df1: DataFrame,
df2: DataFrame,
ignore_nullable: bool = False,
transforms: list[Callable] | None = None, # type: ignore[type-arg]
allow_nan_equality: bool = False,
ignore_column_order: bool = False,
ignore_row_order: bool = False,
underline_cells: bool = False,
ignore_metadata: bool = False,
) -> None:
return assert_df_equality(
df1,
df2,
Expand Down
2 changes: 1 addition & 1 deletion chispa/bcolors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class bcolors:
Bold = "\033[1m"
Underline = "\033[4m"

def __init__(self):
def __init__(self) -> None:
warnings.warn("The `bcolors` class is deprecated and will be removed in a future version.", DeprecationWarning)


Expand Down
25 changes: 13 additions & 12 deletions chispa/column_comparer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from prettytable import PrettyTable
from pyspark.sql import DataFrame

from chispa.formatting import blue

Expand All @@ -11,12 +12,12 @@ class ColumnsNotEqualError(Exception):
pass


def assert_column_equality(df, col_name1, col_name2):
elements = df.select(col_name1, col_name2).collect()
colName1Elements = list(map(lambda x: x[0], elements))
colName2Elements = list(map(lambda x: x[1], elements))
if colName1Elements != colName2Elements:
zipped = list(zip(colName1Elements, colName2Elements))
def assert_column_equality(df: DataFrame, col_name1: str, col_name2: str) -> None:
rows = df.select(col_name1, col_name2).collect()
col_name_1_elements = [x[0] for x in rows]
col_name_2_elements = [x[1] for x in rows]
if col_name_1_elements != col_name_2_elements:
zipped = list(zip(col_name_1_elements, col_name_2_elements))
t = PrettyTable([col_name1, col_name2])
for elements in zipped:
if elements[0] == elements[1]:
Expand All @@ -26,18 +27,18 @@ def assert_column_equality(df, col_name1, col_name2):
raise ColumnsNotEqualError("\n" + t.get_string())


def assert_approx_column_equality(df, col_name1, col_name2, precision):
elements = df.select(col_name1, col_name2).collect()
colName1Elements = list(map(lambda x: x[0], elements))
colName2Elements = list(map(lambda x: x[1], elements))
def assert_approx_column_equality(df: DataFrame, col_name1: str, col_name2: str, precision: float) -> None:
rows = df.select(col_name1, col_name2).collect()
col_name_1_elements = [x[0] for x in rows]
col_name_2_elements = [x[1] for x in rows]
all_rows_equal = True
zipped = list(zip(colName1Elements, colName2Elements))
zipped = list(zip(col_name_1_elements, col_name_2_elements))
t = PrettyTable([col_name1, col_name2])
for elements in zipped:
first = blue(str(elements[0]))
second = blue(str(elements[1]))
# when one is None and the other isn't, they're not equal
if (elements[0] is None and elements[1] is not None) or (elements[0] is not None and elements[1] is None):
if (elements[0] is None) != (elements[1] is None):
all_rows_equal = False
t.add_row([str(elements[0]), str(elements[1])])
# when both are None, they're equal
Expand Down
55 changes: 30 additions & 25 deletions chispa/dataframe_comparer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from functools import reduce
from typing import Callable

from pyspark.sql import DataFrame

from chispa.formatting import FormattingConfig
from chispa.row_comparer import are_rows_approx_equal, are_rows_equal_enhanced
Expand All @@ -18,17 +21,17 @@ class DataFramesNotEqualError(Exception):


def assert_df_equality(
df1,
df2,
ignore_nullable=False,
transforms=None,
allow_nan_equality=False,
ignore_column_order=False,
ignore_row_order=False,
underline_cells=False,
ignore_metadata=False,
df1: DataFrame,
df2: DataFrame,
ignore_nullable: bool = False,
transforms: list[Callable] | None = None, # type: ignore[type-arg]
allow_nan_equality: bool = False,
ignore_column_order: bool = False,
ignore_row_order: bool = False,
underline_cells: bool = False,
ignore_metadata: bool = False,
formats: FormattingConfig | None = None,
):
) -> None:
if not formats:
formats = FormattingConfig()
elif not isinstance(formats, FormattingConfig):
Expand All @@ -48,7 +51,7 @@ def assert_df_equality(
df1.collect(),
df2.collect(),
are_rows_equal_enhanced,
[True],
{"allow_nan_equality": True},
underline_cells=underline_cells,
formats=formats,
)
Expand All @@ -61,7 +64,7 @@ def assert_df_equality(
)


def are_dfs_equal(df1, df2):
def are_dfs_equal(df1: DataFrame, df2: DataFrame) -> bool:
if df1.schema != df2.schema:
return False
if df1.collect() != df2.collect():
Expand All @@ -70,16 +73,16 @@ def are_dfs_equal(df1, df2):


def assert_approx_df_equality(
df1,
df2,
precision,
ignore_nullable=False,
transforms=None,
allow_nan_equality=False,
ignore_column_order=False,
ignore_row_order=False,
df1: DataFrame,
df2: DataFrame,
precision: float,
ignore_nullable: bool = False,
transforms: list[Callable] | None = None, # type: ignore[type-arg]
allow_nan_equality: bool = False,
ignore_column_order: bool = False,
ignore_row_order: bool = False,
formats: FormattingConfig | None = None,
):
) -> None:
if not formats:
formats = FormattingConfig()
elif not isinstance(formats, FormattingConfig):
Expand All @@ -99,10 +102,12 @@ def assert_approx_df_equality(
df1.collect(),
df2.collect(),
are_rows_approx_equal,
[precision, allow_nan_equality],
formats,
{"precision": precision, "allow_nan_equality": allow_nan_equality},
formats=formats,
)
elif allow_nan_equality:
assert_generic_rows_equality(df1.collect(), df2.collect(), are_rows_equal_enhanced, [True], formats)
assert_generic_rows_equality(
df1.collect(), df2.collect(), are_rows_equal_enhanced, {"allow_nan_equality": True}, formats=formats
)
else:
assert_basic_rows_equality(df1.collect(), df2.collect(), formats)
assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats)
2 changes: 1 addition & 1 deletion chispa/default_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class DefaultFormats:
mismatched_cells: list[str] = field(default_factory=lambda: ["red", "underline"])
matched_cells: list[str] = field(default_factory=lambda: ["blue"])

def __post_init__(self):
def __post_init__(self) -> None:
warnings.warn(
"DefaultFormats is deprecated. Use `chispa.formatting.FormattingConfig` instead.", DeprecationWarning
)
2 changes: 1 addition & 1 deletion chispa/formatting/format_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ def format_string(input_string: str, format: Format) -> str:
return formatted_string


def blue(string: str):
def blue(string: str) -> str:
return Color.LIGHT_BLUE + string + Color.LIGHT_RED
9 changes: 6 additions & 3 deletions chispa/formatting/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Format:
style: list[Style] | None = None

@classmethod
def from_dict(cls, format_dict: dict) -> Format:
def from_dict(cls, format_dict: dict[str, str | list[str]]) -> Format:
"""
Create a Format instance from a dictionary.
Expand All @@ -72,7 +72,10 @@ def from_dict(cls, format_dict: dict) -> Format:
if invalid_keys:
raise ValueError(f"Invalid keys in format dictionary: {invalid_keys}. Valid keys are {valid_keys}")

color = cls._get_color_enum(format_dict.get("color"))
if isinstance(format_dict.get("color"), list):
raise TypeError("The value for key 'color' should be a string, not a list!")
color = cls._get_color_enum(format_dict.get("color")) # type: ignore[arg-type]

style = format_dict.get("style")
if isinstance(style, str):
styles = [cls._get_style_enum(style)]
Expand All @@ -81,7 +84,7 @@ def from_dict(cls, format_dict: dict) -> Format:
else:
styles = None

return cls(color=color, style=styles)
return cls(color=color, style=styles) # type: ignore[arg-type]

@classmethod
def from_list(cls, values: list[str]) -> Format:
Expand Down
10 changes: 5 additions & 5 deletions chispa/formatting/formatting_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ class FormattingConfig:

def __init__(
self,
mismatched_rows: Format | dict = Format(Color.RED),
matched_rows: Format | dict = Format(Color.BLUE),
mismatched_cells: Format | dict = Format(Color.RED, [Style.UNDERLINE]),
matched_cells: Format | dict = Format(Color.BLUE),
mismatched_rows: Format | dict[str, str | list[str]] = Format(Color.RED),
matched_rows: Format | dict[str, str | list[str]] = Format(Color.BLUE),
mismatched_cells: Format | dict[str, str | list[str]] = Format(Color.RED, [Style.UNDERLINE]),
matched_cells: Format | dict[str, str | list[str]] = Format(Color.BLUE),
):
"""
Initializes the FormattingConfig with given or default formatting.
Expand All @@ -46,7 +46,7 @@ def __init__(
self.mismatched_cells: Format = self._parse_format(mismatched_cells)
self.matched_cells: Format = self._parse_format(matched_cells)

def _parse_format(self, format: Format | dict) -> Format:
def _parse_format(self, format: Format | dict[str, str | list[str]]) -> Format:
if isinstance(format, Format):
return format
elif isinstance(format, dict):
Expand Down
8 changes: 5 additions & 3 deletions chispa/number_helpers.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from __future__ import annotations

import math
from decimal import Decimal
from typing import Any


def isnan(x):
def isnan(x: Any) -> bool:
try:
return math.isnan(x)
except TypeError:
return False


def nan_safe_equality(x, y) -> bool:
def nan_safe_equality(x: int | float, y: int | float | Decimal) -> bool:
return (x == y) or (isnan(x) and isnan(y))


def nan_safe_approx_equality(x, y, precision) -> bool:
def nan_safe_approx_equality(x: int | float, y: int | float, precision: float | Decimal) -> bool:
return (abs(x - y) <= precision) or (isnan(x) and isnan(y))
8 changes: 4 additions & 4 deletions chispa/row_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def are_rows_equal(r1: Row, r2: Row) -> bool:
return r1 == r2


def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool:
def are_rows_equal_enhanced(r1: Row | None, r2: Row | None, allow_nan_equality: bool) -> bool:
if r1 is None and r2 is None:
return True
if (r1 is None and r2 is not None) or (r2 is None and r1 is not None):
if r1 is None or r2 is None:
return False
d1 = r1.asDict()
d2 = r2.asDict()
Expand All @@ -27,10 +27,10 @@ def are_rows_equal_enhanced(r1: Row, r2: Row, allow_nan_equality: bool) -> bool:
return r1 == r2


def are_rows_approx_equal(r1: Row, r2: Row, precision: float, allow_nan_equality=False) -> bool:
def are_rows_approx_equal(r1: Row | None, r2: Row | None, precision: float, allow_nan_equality: bool = False) -> bool:
if r1 is None and r2 is None:
return True
if (r1 is None and r2 is not None) or (r2 is None and r1 is not None):
if r1 is None or r2 is None:
return False
d1 = r1.asDict()
d2 = r2.asDict()
Expand Down
Loading

0 comments on commit cccff07

Please sign in to comment.