Skip to content

Commit

Permalink
fix: default dict from/to protobuf (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Jan 14, 2022
1 parent d764341 commit e022463
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
6 changes: 3 additions & 3 deletions docarray/document/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

if TYPE_CHECKING:
from ..score import NamedScore
Expand Down Expand Up @@ -51,8 +51,8 @@ class DocumentData:
location: Optional[List[float]] = None
embedding: Optional['ArrayType'] = field(default=None, hash=False, compare=False)
modality: Optional[str] = None
evaluations: Optional[Dict[str, 'NamedScore']] = None
scores: Optional[Dict[str, 'NamedScore']] = None
evaluations: Optional[Dict[str, Union['NamedScore', Dict]]] = None
scores: Optional[Dict[str, Union['NamedScore', Dict]]] = None
chunks: Optional['DocumentArray'] = None
matches: Optional['DocumentArray'] = None

Expand Down
3 changes: 2 additions & 1 deletion docarray/proto/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import TYPE_CHECKING

from google.protobuf.json_format import MessageToDict
Expand Down Expand Up @@ -26,7 +27,7 @@ def parse_proto(pb_msg: 'DocumentProto') -> 'Document':
elif f_name == 'location':
fields[f_name] = list(value)
elif f_name == 'scores' or f_name == 'evaluations':
fields[f_name] = {}
fields[f_name] = defaultdict()
for k, v in value.items():
fields[f_name][k] = NamedScore(
{ff.name: vv for (ff, vv) in v.ListFields()}
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/document/test_protobuf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict

import numpy as np
import pytest

Expand Down Expand Up @@ -32,3 +34,14 @@ def test_to_protobuf():
)
assert Document(tags={'hello': 'world'}).to_protobuf().tags
assert len(Document(chunks=[Document(), Document()]).to_protobuf().chunks) == 2


@pytest.mark.parametrize('meth', ['protobuf', 'dict'])
@pytest.mark.parametrize('attr', ['scores', 'evaluations'])
def test_from_to_namescore_default_dict(attr, meth):
d = Document()
getattr(d, attr)['relevance'].value = 3.0
assert isinstance(d.scores, defaultdict)

r_d = getattr(Document, f'from_{meth}')(getattr(d, f'to_{meth}')())
assert isinstance(r_d.scores, defaultdict)

0 comments on commit e022463

Please sign in to comment.