Skip to content

Commit

Permalink
feat(embed): transformer embed support (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
bwanglzu authored Feb 8, 2022
1 parent 646ace8 commit d1184e4
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 13 deletions.
15 changes: 14 additions & 1 deletion docarray/array/mixins/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def _set_embeddings_keras(
with device:
for b_ids in self.batch_ids(batch_size):
r = embed_model(self[b_ids, 'tensor'], training=False)
if not isinstance(r, tf.Tensor):
# NOTE: Transformers has own output class.
from transformers.modeling_outputs import ModelOutput

r = r.pooler_output # type: ModelOutput

self[b_ids, 'embedding'] = r.numpy() if to_numpy else r

def _set_embeddings_torch(
Expand All @@ -61,7 +67,14 @@ def _set_embeddings_torch(
with torch.inference_mode():
for b_ids in self.batch_ids(batch_size):
batch_inputs = torch.tensor(self[b_ids, 'tensor'], device=device)
r = embed_model(batch_inputs).cpu().detach()
r = embed_model(batch_inputs)
if isinstance(r, torch.Tensor):
r = r.cpu().detach()
else:
# NOTE: Transformers has own output class.
from transformers.modeling_outputs import ModelOutput

r = r.pooler_output.cpu().detach() # type: ModelOutput
self[b_ids, 'embedding'] = r.numpy() if to_numpy else r

if is_training_before:
Expand Down
Empty file added docarray/array/mixins/io/tt.py
Empty file.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
'onnx',
'onnxruntime',
'jupyterlab',
'transformers==4.16.2',
],
},
classifiers=[
Expand Down
43 changes: 31 additions & 12 deletions tests/unit/array/mixins/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import tensorflow as tf
import torch
from transformers import TFViTModel, ViTConfig, ViTModel

from docarray import DocumentArray
from docarray.array.memory import DocumentArrayInMemory
Expand All @@ -23,6 +24,8 @@
'paddle': lambda: paddle.nn.Sequential(
paddle.nn.Dropout(0.5), paddle.nn.BatchNorm1D(128)
),
'transformers_torch': lambda: ViTModel(ViTConfig()),
'transformers_tf': lambda: TFViTModel(ViTConfig()),
}
cur_dir = os.path.dirname(os.path.abspath(__file__))
torch.onnx.export(
Expand All @@ -43,26 +46,42 @@
)


@pytest.mark.parametrize('framework', ['onnx', 'keras', 'pytorch'])
@pytest.mark.parametrize(
'da_cls,config',
'framework, input_shape, embedding_shape',
[
(DocumentArray, None),
(DocumentArraySqlite, None),
(DocumentArrayWeaviate, WeaviateConfig(n_dim=128)),
('onnx', (128,), 128),
('keras', (128,), 128),
('pytorch', (128,), 128),
('transformers_torch', (3, 224, 224), 768),
('transformers_tf', (3, 224, 224), 768),
],
)
@pytest.mark.parametrize(
'da_cls',
[
DocumentArray,
DocumentArraySqlite,
# DocumentArrayWeaviate, # TODO enable this
],
)
@pytest.mark.parametrize('N', [2, 10])
@pytest.mark.parametrize('batch_size', [1, 256])
@pytest.mark.parametrize('to_numpy', [True, False])
def test_embedding_on_random_network(
framework, da_cls, config, N, batch_size, to_numpy, start_weaviate
framework,
input_shape,
da_cls,
embedding_shape,
N,
batch_size,
to_numpy,
start_weaviate,
):
if config:
da = da_cls.empty(N, config=config)
if da_cls == DocumentArrayWeaviate:
da = da_cls.empty(N, config=WeaviateConfig(n_dim=embedding_shape))
else:
da = da_cls.empty(N)
da.tensors = np.random.random([N, 128]).astype(np.float32)
da = da_cls.empty(N, config=None)
da.tensors = np.random.random([N, *input_shape]).astype(np.float32)
embed_model = random_embed_models[framework]()
da.embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)

Expand All @@ -73,7 +92,7 @@ def test_embedding_on_random_network(
embed1 = r.copy()

# reset
da.embeddings = np.random.random([N, 128]).astype(np.float32)
da.embeddings = np.random.random([N, embedding_shape]).astype(np.float32)

# docs[a: b].embed is only supported for DocumentArrayInMemory
if isinstance(da, DocumentArrayInMemory):
Expand All @@ -82,7 +101,7 @@ def test_embedding_on_random_network(
np.testing.assert_array_almost_equal(da.embeddings, embed1)

# reset
da.embeddings = np.random.random([N, 128]).astype(np.float32)
da.embeddings = np.random.random([N, embedding_shape]).astype(np.float32)

# now do this one by one
da[: int(N / 2)].embed(embed_model, batch_size=batch_size, to_numpy=to_numpy)
Expand Down

0 comments on commit d1184e4

Please sign in to comment.