Skip to content

Commit

Permalink
De-duplicate union schemas (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
faph authored Sep 3, 2024
2 parents 42b69d2 + 7616a3a commit a68745c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ requires-python = ">=3.9"
dependencies = [
"avro~=1.10",
"memoization~=0.4",
"more-itertools~=10.0",
"orjson~=3.5",
"typeguard~=4.0",
]
Expand Down
13 changes: 11 additions & 2 deletions src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
get_type_hints,
)

import more_itertools
import orjson
import typeguard

Expand Down Expand Up @@ -696,9 +697,17 @@ def __init__(self, py_type: Type[Union[Any]], namespace: Optional[str] = None, o
args = get_args(py_type)
self.item_schemas = [_schema_obj(arg, namespace=namespace, options=options) for arg in args]

def data(self, names: NamesType) -> JSONArray:
def data(self, names: NamesType) -> JSONType:
"""Return the schema data"""
return [schema.data(names=names) for schema in self.item_schemas]
# Render the item schemas
schemas = (item_schema.data(names=names) for item_schema in self.item_schemas)
# We need to deduplicate the schemas **after** rendering. This is because **different** Python types might
# result in the **same** Avro schema. Preserving order as order may be significant in an Avro schema.
unique_schemas = list(more_itertools.unique_everseen(schemas))
if len(unique_schemas) > 1:
return unique_schemas
else:
return unique_schemas[0]

def sort_item_schemas(self, default_value: Any) -> None:
"""Re-order the union's schemas such that the first item corresponds with a record field's default value"""
Expand Down
12 changes: 12 additions & 0 deletions tests/test_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,18 @@ def test_union_of_union_string_int():
assert_schema(py_type, expected)


def test_union_str_str():
py_type = Union[str, str]
expected = "string"
assert_schema(py_type, expected)


def test_union_str_annotated_str():
py_type = Union[str, Annotated[str, ...]]
expected = "string"
assert_schema(py_type, expected)


def test_literal_different_types():
py_type = Literal["", 42]
with pytest.raises(
Expand Down

0 comments on commit a68745c

Please sign in to comment.