Skip to content

Commit

Permalink
feat: set prefetch in client for traffic control (#897)
Browse files Browse the repository at this point in the history
* fix: set prefetch in client for traffic control

* fix: set prefetch in client for traffic control

* fix: pass prefetch by param

* chore: remove uploading progress bar

* fix: address comments

* fix: address comments

* docs: prefetch

* fix: update pgbar gif
  • Loading branch information
ZiniuYu authored Mar 7, 2023
1 parent d70f238 commit cce3b05
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 37 deletions.
76 changes: 41 additions & 35 deletions client/clip_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,8 @@ def _prepare_streaming(self, disable, total):
os.environ['JINA_GRPC_SEND_BYTES'] = '0'
os.environ['JINA_GRPC_RECV_BYTES'] = '0'

self._s_task = self._pbar.add_task(
':arrow_up: Send', total=total, total_size=0, start=False
)
self._r_task = self._pbar.add_task(
':arrow_down: Recv', total=total, total_size=0, start=False
':arrow_down: Progress', total=total, total_size=0, start=False
)

@staticmethod
Expand All @@ -171,12 +168,8 @@ def _gather_result(
def _iter_doc(
self, content, results: Optional['DocumentArray'] = None
) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, str):
_mime = mimetypes.guess_type(c)[0]
Expand All @@ -199,17 +192,6 @@ def _iter_doc(
else:
raise TypeError(f'unsupported input type {c!r}')

if hasattr(self, '_pbar'):
self._pbar.update(
self._s_task,
advance=1,
total_size=str(
filesize.decimal(
int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
)
),
)

if results is not None:
results.append(d)
yield d
Expand Down Expand Up @@ -251,6 +233,7 @@ def encode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'np.ndarray':
"""Encode images and texts into embeddings where the input is an iterable of raw strings.
Each image and text must be represented as a string. The following strings are acceptable:
Expand All @@ -268,6 +251,8 @@ def encode(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after completion of each request.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
Expand All @@ -283,6 +268,7 @@ def encode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.
:param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
Expand All @@ -295,6 +281,8 @@ def encode(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after completion of each request.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
Expand All @@ -314,6 +302,7 @@ def encode(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
Expand All @@ -334,6 +323,7 @@ def encode(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

unbox = hasattr(content, '__len__') and isinstance(content[0], str)
Expand All @@ -350,6 +340,7 @@ async def aencode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'np.ndarray':
...

Expand All @@ -364,6 +355,7 @@ async def aencode(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
...

Expand All @@ -382,6 +374,7 @@ async def aencode(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
Expand All @@ -402,6 +395,7 @@ async def aencode(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

Expand All @@ -411,29 +405,13 @@ async def aencode(self, content, **kwargs):
def _iter_rank_docs(
self, content, results: Optional['DocumentArray'] = None, source='matches'
) -> Generator['Document', None, None]:
from rich import filesize
from docarray import Document

if hasattr(self, '_pbar'):
self._pbar.start_task(self._s_task)

for c in content:
if isinstance(c, Document):
d = self._prepare_rank_doc(c, source)
else:
raise TypeError(f'Unsupported input type {c!r}')

if hasattr(self, '_pbar'):
self._pbar.update(
self._s_task,
advance=1,
total_size=str(
filesize.decimal(
int(os.environ.get('JINA_GRPC_SEND_BYTES', '0'))
)
),
)

if results is not None:
results.append(d)
yield d
Expand Down Expand Up @@ -498,6 +476,7 @@ def rank(
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
Expand All @@ -516,6 +495,7 @@ def rank(
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
Expand All @@ -533,6 +513,7 @@ async def arank(
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
Expand All @@ -551,6 +532,7 @@ async def arank(
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

Expand All @@ -567,6 +549,7 @@ def index(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
"""Index the images or texts where their embeddings are computed by the server CLIP model.
Expand All @@ -585,6 +568,8 @@ def index(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
Expand All @@ -600,6 +585,7 @@ def index(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Index the images or texts where their embeddings are computed by the server CLIP model.
Expand All @@ -613,6 +599,8 @@ def index(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
Expand All @@ -630,6 +618,7 @@ def index(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
Expand All @@ -649,6 +638,7 @@ def index(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
Expand All @@ -664,6 +654,7 @@ async def aindex(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...

Expand All @@ -678,6 +669,7 @@ async def aindex(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...

Expand All @@ -694,6 +686,7 @@ async def aindex(self, content, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(
Expand All @@ -713,6 +706,7 @@ async def aindex(self, content, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

Expand All @@ -730,6 +724,7 @@ def search(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Search for top k results for given query string or ``Document``.
Expand All @@ -747,6 +742,8 @@ def search(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
"""
...

Expand All @@ -762,6 +759,7 @@ def search(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
) -> 'DocumentArray':
"""Search for top k results for given query string or ``Document``.
Expand All @@ -779,6 +777,8 @@ def search(
It takes the response ``DataRequest`` as the only argument
:param on_always: the callback function executed while streaming, after each request is completed.
It takes the response ``DataRequest`` as the only argument
:param prefetch: the number of in-flight batches made by the post() method. Use a lower value for expensive
operations, and a higher value for faster response times
"""
...

Expand All @@ -795,6 +795,7 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
Expand All @@ -813,6 +814,7 @@ def search(self, content, limit: int = 10, **kwargs) -> 'DocumentArray':
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
)

return results
Expand All @@ -829,6 +831,7 @@ async def asearch(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...

Expand All @@ -844,6 +847,7 @@ async def asearch(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
prefetch: int = 100,
):
...

Expand All @@ -860,6 +864,7 @@ async def asearch(self, content, limit: int = 10, **kwargs):
on_done = kwargs.pop('on_done', None)
on_error = kwargs.pop('on_error', None)
on_always = kwargs.pop('on_always', None)
prefetch = kwargs.pop('prefetch', 100)
results = DocumentArray() if not on_done and not on_always else None
if not on_done:
on_done = partial(self._gather_result, results=results, attribute='matches')
Expand All @@ -878,6 +883,7 @@ async def asearch(self, content, limit: int = 10, **kwargs):
on_error=on_error,
on_always=partial(self._update_pbar, func=on_always),
parameters=parameters,
prefetch=prefetch,
):
continue

Expand Down
13 changes: 11 additions & 2 deletions docs/user-guides/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,14 @@ You can specify `.encode(..., batch_size=8)` to control how many `Document`s are

Intuitively, setting `batch_size=1024` should result in very high GPU utilization on each request. However, a large batch size like this also means sending each request would take longer. Given that `clip-client` is designed with request and response streaming, large batch size would not benefit from the time overlapping between request streaming and response streaming.

### Control prefetch size

To control the number of in-flight batches, you can use the `.encode(..., prefetch=100)` option.
The way this works is that when you send a large request, the outgoing request stream will usually finish before the incoming response stream due to the asynchronous design.
This is because request handling on the server is typically time-consuming; a large backlog of unanswered requests can prevent the server from sending responses back promptly, and the connection may be closed because the incoming channel appears idle.
By default, the client is set to a prefetch value of 100. However, it is recommended to use a lower value for expensive operations and a higher value for faster response times.

For more information about client prefetching, please refer to [Rate Limit](https://docs.jina.ai/concepts/client/rate-limit/) in Jina documentation.

### Show progressbar

Expand Down Expand Up @@ -459,8 +467,9 @@ Here are some suggestions when encoding a large number of `Document`s:

c.encode(iglob('**/*.png'))
```
2. Adjust `batch_size`.
3. Turn on the progressbar.
2. Adjust the `batch_size` parameter.
3. Adjust the `prefetch` parameter.
4. Turn on the progressbar.

````{danger}
In any case, avoid the following coding pattern:
Expand Down
Binary file modified docs/user-guides/images/client-pgbar.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit cce3b05

Please sign in to comment.