Skip to content

Commit

Permalink
feat: add gif export (#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Jul 3, 2022
1 parent 29ded84 commit 5d935c9
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 10 deletions.
16 changes: 8 additions & 8 deletions docarray/array/mixins/io/pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from functools import lru_cache
from pathlib import Path
from typing import Dict, Type, TYPE_CHECKING, Optional
from urllib.request import Request, urlopen

from ....helper import get_request_header

from ....helper import get_request_header, __cache_path__

if TYPE_CHECKING:
from ....typing import T
Expand Down Expand Up @@ -137,7 +137,7 @@ def pull(
cls: Type['T'],
name: str,
show_progress: bool = False,
local_cache: bool = False,
local_cache: bool = True,
*args,
**kwargs,
) -> 'T':
Expand Down Expand Up @@ -177,10 +177,10 @@ def pull(
from .binary import LazyRequestReader

_source = LazyRequestReader(r)
if local_cache and os.path.exists(f'.cache/{name}'):
_cache_len = os.path.getsize(f'.cache/{name}')
if local_cache and os.path.exists(f'{__cache_path__}/{name}'):
_cache_len = os.path.getsize(f'{__cache_path__}/{name}')
if _cache_len == _da_len:
_source = f'.cache/{name}'
_source = f'{__cache_path__}/{name}'

r = cls.load_binary(
_source,
Expand All @@ -192,8 +192,8 @@ def pull(
)

if isinstance(_source, LazyRequestReader) and local_cache:
os.makedirs('.cache', exist_ok=True)
with open(f'.cache/{name}', 'wb') as fp:
Path(__cache_path__).mkdir(parents=True, exist_ok=True)
with open(f'{__cache_path__}/{name}', 'wb') as fp:
fp.write(_source.content)

return r
89 changes: 89 additions & 0 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,95 @@ def _get_fastapi_app():
t_m.join()
return path

def save_gif(
self,
output: str,
channel_axis: int = -1,
duration: int = 200,
size_ratio: float = 1.0,
inline_display: bool = False,
image_source: str = 'tensor',
skip_empty: bool = False,
show_index: bool = False,
show_progress: bool = False,
) -> None:
"""
Save a gif of the DocumentArray. Each frame corresponds to a Document.uri/.tensor in the DocumentArray.
:param output: the file path to save the gif to.
:param channel_axis: the color channel axis of the tensor.
:param duration: the duration of each frame in milliseconds.
:param size_ratio: the size ratio of each frame.
:param inline_display: if to show the gif in Jupyter notebook.
:param image_source: the source of the image in Document atribute.
:param skip_empty: if to skip empty documents.
:param show_index: if to show the index of the document in the top-right corner.
:param show_progress: if to show a progress bar.
:return:
"""

from rich.progress import track
from PIL import Image, ImageDraw

def img_iterator(channel_axis):
for _idx, d in enumerate(
track(self, description='Plotting', disable=not show_progress)
):

if not d.uri and d.tensor is None:
if skip_empty:
continue
else:
raise ValueError(
f'Document has neither `uri` nor `tensor`, can not be plotted'
)

_d = copy.deepcopy(d)

if image_source == 'uri' or (
image_source == 'tensor' and _d.content_type != 'tensor'
):
_d.load_uri_to_image_tensor()
channel_axis = -1
elif image_source not in ('uri', 'tensor'):
raise ValueError(f'image_source can be only `uri` or `tensor`')

_d.set_image_tensor_channel_axis(channel_axis, -1)
if size_ratio < 1:
img_size_h, img_size_w, _ = _d.tensor.shape
_d.set_image_tensor_shape(
shape=(
int(size_ratio * img_size_h),
int(size_ratio * img_size_w),
)
)

if show_index:
_img = Image.fromarray(_d.tensor)
draw = ImageDraw.Draw(_img)
draw.text((0, 0), str(_idx), (255, 255, 255))
_d.tensor = np.asarray(_img)

yield Image.fromarray(_d.tensor).convert('RGB')

imgs = img_iterator(channel_axis)
img = next(imgs) # extract first image from iterator

with open(output, 'wb') as fp:
img.save(
fp=fp,
format='GIF',
append_images=imgs,
save_all=True,
duration=duration,
loop=0,
)

if inline_display:
from IPython.display import Image, display

display(Image(output))

def plot_image_sprites(
self,
output: Optional[str] = None,
Expand Down
3 changes: 3 additions & 0 deletions docarray/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import uuid
import warnings
from os.path import expanduser
from typing import Any, Dict, Optional, Sequence, Tuple, Union

__resources_path__ = os.path.join(
Expand All @@ -14,6 +15,8 @@
'resources',
)

__cache_path__ = f'{expanduser("~")}/.cache/{__package__}'


def typename(obj):
"""
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/array/mixins/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig


@pytest.mark.parametrize('keep_aspect_ratio', [True, False])
@pytest.mark.parametrize('show_index', [True, False])
@pytest.mark.parametrize(
'da_cls,config',
[
Expand All @@ -29,7 +31,7 @@
],
)
def test_sprite_fail_tensor_success_uri(
pytestconfig, tmpdir, da_cls, config, start_storage
pytestconfig, tmpdir, da_cls, config, start_storage, keep_aspect_ratio, show_index
):
files = [
f'{pytestconfig.rootdir}/tests/image-data/*.jpg',
Expand All @@ -44,7 +46,13 @@ def test_sprite_fail_tensor_success_uri(
)
with pytest.raises(ValueError):
da.plot_image_sprites()
da.plot_image_sprites(tmpdir / 'sprint_da.png', image_source='uri')
da.plot_image_sprites(
tmpdir / 'sprint_da.png',
image_source='uri',
keep_aspect_ratio=keep_aspect_ratio,
show_index=show_index,
)
da.save_gif(tmpdir / 'sprint_da.gif', show_index=show_index, channel_axis=0)
assert os.path.exists(tmpdir / 'sprint_da.png')


Expand Down

0 comments on commit 5d935c9

Please sign in to comment.