Skip to content

Commit

Permalink
refactor(array): remove customize error (#254)
Browse files Browse the repository at this point in the history
* perf(array): use map_batch in push to improve speed

* refactor(array): remove customize error
  • Loading branch information
hanxiao authored Apr 5, 2022
1 parent ab7aebd commit 3025a42
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 56 deletions.
91 changes: 43 additions & 48 deletions docarray/array/mixins/io/pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import warnings
from functools import lru_cache
from pathlib import Path
from typing import Dict, Type, TYPE_CHECKING
from typing import Dict, Type, TYPE_CHECKING, Optional
from urllib.request import Request, urlopen

from ....exceptions import ObjectNotFoundError
from ....helper import get_request_header

if TYPE_CHECKING:
Expand All @@ -16,7 +15,7 @@


@lru_cache()
def _get_hub_config() -> Dict:
def _get_hub_config() -> Optional[Dict]:
hub_root = Path(os.environ.get('JINA_HUB_ROOT', Path.home().joinpath('.jina')))

if not hub_root.exists():
Expand All @@ -27,8 +26,6 @@ def _get_hub_config() -> Dict:
with open(config_file) as f:
return json.load(f)

return {}


@lru_cache()
def _get_cloud_api() -> str:
Expand Down Expand Up @@ -90,8 +87,9 @@ def push(self, name: str, show_progress: bool = False) -> Dict:

headers = {'Content-Type': ctype, **get_request_header()}

auth_token = _get_hub_config().get('auth_token')
if auth_token:
_hub_config = _get_hub_config()
if _hub_config:
auth_token = _hub_config.get('auth_token')
headers['Authorization'] = f'token {auth_token}'

_head, _tail = data.split(delimiter)
Expand All @@ -103,42 +101,44 @@ def push(self, name: str, show_progress: bool = False) -> Dict:
def gen():
total_size = 0

with pbar:
pbar.start_task(t)

for idx, d in enumerate(self):
chunk = b''
if idx == 0:
chunk += _head
chunk += self._stream_header
if idx < len(self):
chunk += d._to_stream_bytes(
protocol='protobuf', compress='gzip'
)
total_size += len(chunk)
if total_size > self._max_bytes:
warnings.warn(
f'DocumentArray is too big. Only first {idx} Documents are pushed'
)
break
yield chunk
pbar.update(
t, advance=1, total_size=str(filesize.decimal(total_size))
)
pbar.start_task(t)

chunk = _head + self._stream_header

yield chunk

def _get_chunk(_batch):
return b''.join(
d._to_stream_bytes(protocol='protobuf', compress='gzip')
for d in _batch
), len(_batch)

for chunk, num_doc_in_chunk in self.map_batch(_get_chunk, batch_size=32):
total_size += len(chunk)
if total_size > self._max_bytes:
warnings.warn(
f'DocumentArray is too big. The pushed DocumentArray might be chopped off.'
)
break
yield chunk
pbar.update(
t,
advance=num_doc_in_chunk,
total_size=str(filesize.decimal(total_size)),
)
yield _tail

res = requests.post(
f'{_get_cloud_api()}/v2/rpc/artifact.upload', data=gen(), headers=headers
)
json_res = res.json()

if res.status_code != 200:
raise RuntimeError(
json_res.get('message', 'Failed to push DocumentArray to Jina Cloud'),
f'Status code: {res.status_code}',
with pbar:
response = requests.post(
f'{_get_cloud_api()}/v2/rpc/artifact.upload',
data=gen(),
headers=headers,
)

return json_res.get('data')
if response.ok:
return response.json()['data']
else:
response.raise_for_status()

@classmethod
def pull(
Expand All @@ -161,8 +161,9 @@ def pull(

headers = {}

auth_token = _get_hub_config().get('auth_token')
if auth_token:
_hub_config = _get_hub_config()
if _hub_config:
auth_token = _hub_config.get('auth_token')
headers['Authorization'] = f'token {auth_token}'

url = f'{_get_cloud_api()}/v2/rpc/artifact.getDownloadUrl?name={name}'
Expand All @@ -171,13 +172,7 @@ def pull(
if response.ok:
url = response.json()['data']['download']
else:
json_res = response.json()
raise ObjectNotFoundError(
json_res.get(
'message', 'Failed to pull DocumentArray from Jina Cloud.'
),
f'Status code: {response.status_code}',
)
response.raise_for_status()

with requests.get(
url,
Expand Down
4 changes: 0 additions & 4 deletions docarray/exceptions.py

This file was deleted.

10 changes: 6 additions & 4 deletions tests/unit/array/mixins/test_pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ class PushMockResponse:
def __init__(self, status_code: int = 200):
self.status_code = status_code
self.headers = {'Content-length': 1}
self.ok = status_code == 200

def json(self):
return {'code': self.status_code}
return {'code': self.status_code, 'data': []}

def raise_for_status(self):
raise Exception


class PullMockResponse:
Expand Down Expand Up @@ -100,11 +104,9 @@ def test_push_fail(mocker, monkeypatch):
_mock_post(mock, monkeypatch, status_code=requests.codes.forbidden)

docs = random_docs(2)
with pytest.raises(RuntimeError) as exc_info:
with pytest.raises(Exception) as exc_info:
docs.push('test_name')

assert exc_info.match('Failed to push DocumentArray to Jina Cloud')
assert exc_info.match('Status code: 403')
assert mock.call_count == 1


Expand Down

0 comments on commit 3025a42

Please sign in to comment.