Skip to content

Commit

Permalink
fix(document): serialize blob with base64 in dict/json (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Jan 18, 2022
1 parent 3415670 commit 79c13a0
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 3 deletions.
2 changes: 0 additions & 2 deletions docarray/document/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

if TYPE_CHECKING:
from ..types import ArrayType, StructValueType, DocumentContentType
from .. import DocumentArray
from ..score import NamedScore


class Document(AllMixins, BaseDCType):
Expand Down
10 changes: 9 additions & 1 deletion docarray/document/mixins/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from collections import defaultdict
from typing import TYPE_CHECKING, Type

Expand Down Expand Up @@ -41,7 +42,6 @@ def from_pydantic_model(cls: Type['T'], model: 'BaseModel') -> 'T':
"""Build a Document object from a Pydantic model
:param model: the pydantic data model object that represents a Document
:param ndarray_as_list: if set to True, `embedding` and `tensor` are auto-casted to ndarray.
:return: a Document object
"""
from ... import Document
Expand All @@ -65,6 +65,14 @@ def from_pydantic_model(cls: Type['T'], model: 'BaseModel') -> 'T':
fields[f_name][k] = NamedScore(v)
elif f_name == 'embedding' or f_name == 'tensor':
fields[f_name] = np.array(value)
elif f_name == 'blob':
# NOTE: this is a deliberate (if unpleasant) workaround.
# The original bytes end up being base64-encoded twice:
# the first encoding is the real one, done in `to_dict/to_json`, which turns the bytes into a base64 string;
# the second happens in `from_dict/from_json`; it is unnecessary yet unavoidable: the resulting string gets
# converted into a binary string and encoded again.
# Consequently, we need to decode twice here.
fields[f_name] = base64.b64decode(base64.b64decode(value))
else:
fields[f_name] = value

Expand Down
9 changes: 9 additions & 0 deletions docarray/document/pydantic_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from typing import Optional, List, Dict, Any, TYPE_CHECKING, Union

from pydantic import BaseModel, validator
Expand Down Expand Up @@ -43,6 +44,14 @@ class PydanticDocument(BaseModel):
_tensor2list = validator('tensor', allow_reuse=True)(_convert_ndarray_to_list)
_embedding2list = validator('embedding', allow_reuse=True)(_convert_ndarray_to_list)

@validator('blob')
def _blob2base64(cls, v):
    """Encode a ``bytes`` blob as a base64 text string.

    :param v: the blob value handed to the validator
    :return: the base64 encoding of ``v`` decoded as a UTF-8 string, or
        ``None`` when ``v`` is ``None``
    :raises ValueError: if ``v`` is neither ``None`` nor ``bytes``
    """
    # Nothing to encode: pass the missing blob through untouched.
    if v is None:
        return None
    if not isinstance(v, bytes):
        raise ValueError('must be bytes')
    encoded = base64.b64encode(v)
    return encoded.decode('utf8')


PydanticDocument.update_forward_refs()

Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from collections import defaultdict
from typing import List, Optional

Expand Down Expand Up @@ -142,3 +143,19 @@ def test_tags_int_float_str_bool(tag_type, tag_value, protocol):
dd = d.to_dict(protocol=protocol)['tags']['hello'][-1]
assert dd == tag_value
assert isinstance(dd, tag_type)


@pytest.mark.parametrize(
    'blob', [None, b'123', bytes(Document()), bytes(bytearray(os.urandom(512 * 4)))]
)
@pytest.mark.parametrize('protocol', ['jsonschema', 'protobuf'])
@pytest.mark.parametrize('to_fn', ['dict', 'json'])
def test_to_from_with_blob(protocol, to_fn, blob):
    """Check that ``blob`` survives a ``to_<fn>`` / ``from_<fn>`` round trip.

    :param protocol: serialization protocol forwarded to both directions
    :param to_fn: suffix selecting the ``to_*`` / ``from_*`` method pair
    :param blob: blob value stored on the source Document
    """
    source_doc = Document(blob=blob)

    # Resolve the matching serializer / deserializer pair by name.
    dump = getattr(source_doc, f'to_{to_fn}')
    load = getattr(Document, f'from_{to_fn}')

    payload = dump(protocol=protocol)
    restored_doc = load(payload, protocol=protocol)

    assert source_doc.blob == restored_doc.blob
    # Only a non-empty blob is required to come back as real bytes.
    if source_doc.blob:
        assert isinstance(restored_doc.blob, bytes)

0 comments on commit 79c13a0

Please sign in to comment.