Skip to content

Commit

Permalink
fix: support partial functions in parallel methods (#285)
Browse files Browse the repository at this point in the history
* fix: fix parallel apply for partial functions

* refactor: rename function

* test: cover partial function
  • Loading branch information
alaeddine-13 authored Apr 20, 2022
1 parent f7de996 commit 7d84ad5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
12 changes: 7 additions & 5 deletions docarray/array/mixins/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def map(
:yield: anything return from ``func``
"""
if _is_lambda_or_local_function(func) and backend == 'process':
if _is_lambda_or_partial_or_local_function(func) and backend == 'process':
func = _globalize_lambda_function(func)

from rich.progress import track
Expand Down Expand Up @@ -204,7 +204,7 @@ def map_batch(
:yield: anything return from ``func``
"""

if _is_lambda_or_local_function(func) and backend == 'process':
if _is_lambda_or_partial_or_local_function(func) and backend == 'process':
func = _globalize_lambda_function(func)

from rich.progress import track
Expand Down Expand Up @@ -240,9 +240,11 @@ def _get_pool(backend, num_worker):
)


def _is_lambda_or_local_function(func):
return (isinstance(func, LambdaType) and func.__name__ == '<lambda>') or (
'<locals>' in func.__qualname__
def _is_lambda_or_partial_or_local_function(func):
return (
(isinstance(func, LambdaType) and func.__name__ == '<lambda>')
or not hasattr(func, '__qualname__')
or ('<locals>' in func.__qualname__)
)


Expand Down
36 changes: 36 additions & 0 deletions tests/unit/array/mixins/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool

Expand Down Expand Up @@ -28,6 +29,12 @@ def foo_batch(da: DocumentArray):
return da


def foo_batch_with_args(da: DocumentArray, arg1, arg2):
    """Apply ``foo`` to every document in ``da`` and return ``da``.

    ``arg1`` and ``arg2`` are deliberately unused: they exist only so this
    function can be wrapped with ``functools.partial`` in the
    partial-function tests.
    """
    for document in da:
        foo(document)
    return da


@pytest.mark.parametrize('pool', [None, Pool(), ThreadPool()])
def test_parallel_map_apply_external_pool(pytestconfig, pool):
da = DocumentArray.from_files(f'{pytestconfig.rootdir}/**/*.jpeg')
Expand Down Expand Up @@ -191,6 +198,35 @@ def test_map_lambda(pytestconfig, da_cls, config, start_storage):
assert d.tensor is not None


@pytest.mark.parametrize(
    'da_cls, config',
    [
        (DocumentArray, None),
        (DocumentArraySqlite, None),
        (DocumentArrayAnnlite, AnnliteConfig(n_dim=10)),
        (DocumentArrayWeaviate, WeaviateConfig(n_dim=10)),
        (DocumentArrayQdrant, QdrantConfig(n_dim=10)),
        (DocumentArrayElastic, ElasticConfig(n_dim=10)),
    ],
)
def test_apply_partial(pytestconfig, da_cls, config, start_storage):
    """``apply_batch`` must accept a ``functools.partial`` as the callable.

    Regression test: partial objects have no ``__qualname__`` and must be
    globalized before being dispatched to a process pool.
    """
    # BUG FIX: the body was previously wrapped in ``if __name__ == '__main__':``,
    # which is never true when pytest imports the test module, so the test
    # silently executed nothing. The guard is removed so the assertions run.
    if config:
        da = da_cls.from_files(f'{pytestconfig.rootdir}/**/*.jpeg', config=config)[
            :10
        ]
    else:
        da = da_cls.from_files(f'{pytestconfig.rootdir}/**/*.jpeg')[:10]

    # Tensors start out unset ...
    for d in da:
        assert d.tensor is None

    da.apply_batch(partial(foo_batch_with_args, arg1=None, arg2=None), batch_size=4)

    # ... and are populated by ``foo`` once the partial has been applied.
    for d in da:
        assert d.tensor is not None


@pytest.mark.parametrize(
'storage,config',
[
Expand Down

0 comments on commit 7d84ad5

Please sign in to comment.