Skip to content

Commit

Permalink
feat: collate function in embed function (#187)
Browse files Browse the repository at this point in the history
* feat: init add collate function

* fix: dict inputs

* fix: use batch

* fix: remove copy

* fix: batch inputs

* fix: revert batch iterator

* fix: add bert embeded test

* fix: address comments

* fix: add paddle test

* fix: clean
  • Loading branch information
numb3r3 authored Mar 14, 2022
1 parent f1287e3 commit 34ad3fd
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 14 deletions.
63 changes: 53 additions & 10 deletions docarray/array/mixins/embed.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import warnings
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Optional, Any, Mapping

if TYPE_CHECKING:
from ...types import T, AnyDNN
from ... import DocumentArray

CollateFnType = Callable[
[DocumentArray],
Any,
] #: The type of collate function


class EmbedMixin:
Expand All @@ -14,6 +20,7 @@ def embed(
device: str = 'cpu',
batch_size: int = 256,
to_numpy: bool = False,
collate_fn: Optional['CollateFnType'] = None,
) -> 'T':
"""Fill :attr:`.embedding` of Documents inplace by using `embed_model`
Expand All @@ -22,18 +29,28 @@ def embed(
`cpu` or `cuda`.
:param batch_size: number of Documents in a batch for embedding
:param to_numpy: if to store embeddings back to Document in ``numpy.ndarray`` or original framework format.
:param collate_fn: create a mini-batch of Input(s) from the given `DocumentArray`. Default built-in collate_fn
is to use the `tensors` of the documents.
:return: itself after modified.
"""

if collate_fn is None:

def default_collate_fn(da: 'DocumentArray'):
return da.tensors

collate_fn = default_collate_fn

fm = get_framework(embed_model)
getattr(self, f'_set_embeddings_{fm}')(
embed_model, device, batch_size, to_numpy
embed_model, collate_fn, device, batch_size, to_numpy
)
return self

def _set_embeddings_keras(
self: 'T',
embed_model: 'AnyDNN',
collate_fn: 'CollateFnType',
device: str = 'cpu',
batch_size: int = 256,
to_numpy: bool = False,
Expand All @@ -43,7 +60,12 @@ def _set_embeddings_keras(
device = tf.device('/GPU:0') if device == 'cuda' else tf.device('/CPU:0')
with device:
for b_ids in self.batch_ids(batch_size):
r = embed_model(self[b_ids, 'tensor'], training=False)
batch_inputs = collate_fn(self[b_ids])
if isinstance(batch_inputs, Mapping):
r = embed_model(**batch_inputs, training=False)
else:
r = embed_model(batch_inputs, training=False)

if not isinstance(r, tf.Tensor):
# NOTE: Transformers has own output class.
from transformers.modeling_outputs import ModelOutput
Expand All @@ -55,6 +77,7 @@ def _set_embeddings_keras(
def _set_embeddings_torch(
self: 'T',
embed_model: 'AnyDNN',
collate_fn: 'CollateFnType',
device: str = 'cpu',
batch_size: int = 256,
to_numpy: bool = False,
Expand All @@ -66,15 +89,24 @@ def _set_embeddings_torch(
embed_model.eval()
with torch.inference_mode():
for b_ids in self.batch_ids(batch_size):
batch_inputs = torch.tensor(self[b_ids, 'tensor'], device=device)
r = embed_model(batch_inputs)
batch_inputs = collate_fn(self[b_ids])

if isinstance(batch_inputs, Mapping):
for k, v in batch_inputs.items():
batch_inputs[k] = torch.tensor(v, device=device)
r = embed_model(**batch_inputs)
else:
batch_inputs = torch.tensor(batch_inputs, device=device)
r = embed_model(batch_inputs)

if isinstance(r, torch.Tensor):
r = r.cpu().detach()
else:
# NOTE: Transformers has own output class.
from transformers.modeling_outputs import ModelOutput

r = r.pooler_output.cpu().detach() # type: ModelOutput

self[b_ids, 'embedding'] = r.numpy() if to_numpy else r

if is_training_before:
Expand All @@ -83,6 +115,7 @@ def _set_embeddings_torch(
def _set_embeddings_paddle(
self: 'T',
embed_model,
collate_fn: 'CollateFnType',
device: str = 'cpu',
batch_size: int = 256,
to_numpy: bool = False,
Expand All @@ -93,8 +126,15 @@ def _set_embeddings_paddle(
embed_model.to(device=device)
embed_model.eval()
for b_ids in self.batch_ids(batch_size):
batch_inputs = paddle.to_tensor(self[b_ids, 'tensor'], place=device)
r = embed_model(batch_inputs)
batch_inputs = collate_fn(self[b_ids])
if isinstance(batch_inputs, Mapping):
for k, v in batch_inputs.items():
batch_inputs[k] = paddle.to_tensor(v, place=device)
r = embed_model(**batch_inputs)
else:
batch_inputs = paddle.to_tensor(batch_inputs, place=device)
r = embed_model(batch_inputs)

self[b_ids, 'embedding'] = r.numpy() if to_numpy else r

if is_training_before:
Expand All @@ -103,6 +143,7 @@ def _set_embeddings_paddle(
def _set_embeddings_onnx(
self: 'T',
embed_model,
collate_fn: 'CollateFnType',
device: str = 'cpu',
batch_size: int = 256,
*args,
Expand All @@ -119,9 +160,11 @@ def _set_embeddings_onnx(
)

for b_ids in self.batch_ids(batch_size):
self[b_ids, 'embedding'] = embed_model.run(
None, {embed_model.get_inputs()[0].name: self[b_ids, 'tensor']}
)[0]
batch_inputs = collate_fn(self[b_ids])
if not isinstance(batch_inputs, Mapping):
batch_inputs = {embed_model.get_inputs()[0].name: batch_inputs}

self[b_ids, 'embedding'] = embed_model.run(None, batch_inputs)[0]


def get_framework(dnn_model) -> str:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
'onnx',
'onnxruntime',
'jupyterlab',
'transformers==4.16.2',
'transformers>=4.16.2',
'weaviate-client~=3.3.0',
'annlite>=0.3.0',
'jina',
Expand Down
Binary file modified tests/unit/array/mixins/test-net.onnx
Binary file not shown.
60 changes: 57 additions & 3 deletions tests/unit/array/mixins/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
import pytest
import tensorflow as tf
import torch
from transformers import TFViTModel, ViTConfig, ViTModel
from transformers import (
TFViTModel,
ViTConfig,
ViTModel,
BertModel,
BertConfig,
TFBertModel,
)

from docarray import DocumentArray
from docarray.array.annlite import DocumentArrayAnnlite
from docarray.array.memory import DocumentArrayInMemory
from docarray.array.qdrant import DocumentArrayQdrant
from docarray.array.sqlite import DocumentArraySqlite
from docarray.array.annlite import DocumentArrayAnnlite, AnnliteConfig
from docarray.array.storage.weaviate import WeaviateConfig
from docarray.array.weaviate import DocumentArrayWeaviate

random_embed_models = {
Expand Down Expand Up @@ -111,3 +117,51 @@ def test_embedding_on_random_network(
da[: int(N / 2)].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
da[-int(N / 2) :].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
np.testing.assert_array_almost_equal(da.embeddings, embed1)


@pytest.fixture
def paddle_model():
import paddle.nn as nn

class DummyPaddleLayer(nn.Layer):
def forward(self, x, y):
return (x + y) / 2.0

return DummyPaddleLayer()


def test_embeded_paddle_model(paddle_model):
def collate_fn(da):
return {'x': da.tensors, 'y': da.tensors}

docs = DocumentArray.empty(3)
docs.tensors = np.random.random([3, 5]).astype(np.float32)
docs.embed(paddle_model, collate_fn=collate_fn, to_numpy=True)
assert (docs.tensors == docs.embeddings).all()


@pytest.fixture
def bert_tokenizer(tmpfile):
from transformers import BertTokenizer

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
with open(tmpfile, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
return BertTokenizer(tmpfile)


@pytest.mark.parametrize(
'bert_transformer, return_tensors',
[(BertModel(BertConfig()), 'pt'), (TFBertModel(BertConfig()), 'tf')],
)
def test_embed_bert_model(bert_transformer, bert_tokenizer, return_tensors):
def collate_fn(da):
return bert_tokenizer(
da.texts,
return_tensors=return_tensors,
)

docs = DocumentArray.empty(1)
docs[0].text = 'this is some random text to embed'
docs.embed(bert_transformer, collate_fn=collate_fn)
assert list(docs.embeddings.shape) == [1, 768]

0 comments on commit 34ad3fd

Please sign in to comment.