Skip to content

Commit

Permalink
Fix tiling + pytorch (#1427)
Browse files Browse the repository at this point in the history
* fix tiling + pytorch

* black fix

* mypy fix

* added pt + tiling test

* mypy fix

* black fix

* fix empty sample edge case

* reduce test size
  • Loading branch information
AbhinavTuli authored Jan 12, 2022
1 parent 4a01cb8 commit 6cf888b
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 26 deletions.
17 changes: 17 additions & 0 deletions hub/api/tests/test_api_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,20 @@ def test_cachable_overflow(memory_ds):
assert len(ds) == 3
assert len(ds.x) == 3
assert len(ds.y) == 3


@compressions_paremetrized
def test_empty_array(memory_ds, compression):
    """A zero-sized sample followed by a normal one must both round-trip
    through extend() and numpy() without error (empty-sample edge case)."""
    ds = memory_ds
    samples = [
        np.random.randint(0, 255, (3894, 4279, 0), dtype=np.uint8),
        np.random.randint(0, 255, (3894, 4279, 3), dtype=np.uint8),
    ]
    with ds:
        ds.create_tensor("x", **compression)
        ds.x.extend(samples)
    assert len(ds) == 2
    assert len(ds.x) == 2

    # Verify each stored sample, including the empty one, reads back exactly.
    for idx, expected in enumerate(samples):
        np.testing.assert_array_equal(ds.x[idx].numpy(), expected)
6 changes: 5 additions & 1 deletion hub/core/chunk/base_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ def memoryview_data(self):

@property
def is_empty(self):
return self.num_data_bytes == 0
return (
self.num_data_bytes == 0
and len(self.shapes_encoder.array) == 0
and len(self.byte_positions_encoder.array) == 0
)

def tobytes(self) -> memoryview:
return serialize_chunk(
Expand Down
2 changes: 2 additions & 0 deletions hub/core/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ def decompress_array(
if compression == "apng":
return _decompress_apng(buffer) # type: ignore
try:
if shape is not None and 0 in shape:
return np.zeros(shape, dtype=dtype)
if not isinstance(buffer, str):
buffer = BytesIO(buffer) # type: ignore
img = Image.open(buffer) # type: ignore
Expand Down
63 changes: 39 additions & 24 deletions hub/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from hub.core.meta.encode.base_encoder import LAST_SEEN_INDEX_COLUMN
from hub.core.meta.encode.chunk_id import CHUNK_ID_COLUMN, ChunkIdEncoder
from hub.core.storage import LRUCache, MemoryProvider, StorageProvider, LocalProvider
from hub.core.tiling.deserialize import combine_chunks
from hub.util.exceptions import (
DatasetUnsupportedPytorch,
SampleDecompressionError,
Expand All @@ -32,8 +33,8 @@ class IOBlock:
Represents ordered sequential read of samples from corresponding tensor chunks.
"""

def __init__(self, chunks: List[str], indexes: List[int]) -> None:
self._chunks: List[str] = chunks
def __init__(self, chunks: List[List[str]], indexes: List[int]) -> None:
self._chunks: List[List[str]] = chunks
self._ind: List[int] = indexes

def shuffle(self):
Expand All @@ -42,13 +43,13 @@ def shuffle(self):
"""
shuffle(self._ind)

def chunk_name(self, tensor_index: int) -> str:
def chunk_names(self, tensor_index: int) -> List[str]:
return self._chunks[tensor_index]

def indices(self) -> List[int]:
    """Return the sample indices covered by this IOBlock.

    NOTE(review): shuffle() appears to permute this same list in place,
    so the order is not guaranteed stable — confirm against callers.
    """
    return self._ind

def chunks(self) -> List[str]:
def chunks(self) -> List[List[str]]:
return self._chunks

def split(self, n) -> List["IOBlock"]:
Expand Down Expand Up @@ -289,25 +290,31 @@ def stream(self, block: IOBlock):
for keyid, (key, engine) in enumerate(self.chunk_engines.items()):
chunk_class = engine.chunk_class
try:
commit_id = engine.get_chunk_commit(block.chunk_name(keyid))
c_key = get_chunk_key(key, block.chunk_name(keyid), commit_id)
chunk: BaseChunk

if self.local_caches is not None:
local_cache = self.local_caches[key]
chunks: List[BaseChunk] = []
c_names = block.chunk_names(keyid)

if c_key in local_cache:
chunk = local_cache.get_cachable(c_key, chunk_class, meta=engine.chunk_args) # type: ignore
for c_name in c_names:
commit_id = engine.get_chunk_commit(c_name)
c_key = get_chunk_key(key, c_name, commit_id)
if self.local_caches is not None:
local_cache = self.local_caches[key]

if c_key in local_cache:
chunk = local_cache.get_cachable(c_key, chunk_class, meta=engine.chunk_args) # type: ignore
else:
chunk = engine.get_chunk(c_key)
local_cache[c_key] = chunk

# send data to actual storage
local_cache._forward(c_key, True)
else:
chunk = engine.get_chunk(c_key)
local_cache[c_key] = chunk

# send data to actual storage
local_cache._forward(c_key, True)
chunks.append(chunk)
if len(chunks) == 1:
data = engine.read_sample_from_chunk(idx, chunk)
else:
chunk = engine.get_chunk(c_key)

data = engine.read_sample_from_chunk(idx, chunk)
data = combine_chunks(chunks, idx, engine.tile_encoder)

if data is not None:
sample[key] = data
Expand Down Expand Up @@ -345,10 +352,20 @@ def list_blocks(self) -> List[IOBlock]:
next_it_value = int(next_it.value[0])

if next_it_value >= last_idx:
chunks = [
ChunkIdEncoder.name_from_id(cid) # type: ignore
for cid in [int(it.value[1]) for it in iterators]
]
chunks = []
for it in iterators:
cur_ids = []
if it.value[0] == next_it_value:
while not it.finished and it.value[0] == next_it_value:
cur_ids.append(it.value[1])
it.iternext()
else:
cur_ids.append(it.value[1])
cur_chunks = [
ChunkIdEncoder.name_from_id(cid) # type: ignore
for cid in cur_ids
]
chunks.append(cur_chunks)

streamable_ids = list(
ds_indicies_set.intersection(range(last_idx, next_it_value + 1))
Expand All @@ -361,8 +378,6 @@ def list_blocks(self) -> List[IOBlock]:

last_idx = next_it_value + 1

next(next_it, None)

return blocks

def _use_cache(self, storage: Union[StorageProvider, LRUCache]) -> LRUCache:
Expand Down
2 changes: 1 addition & 1 deletion hub/core/tiling/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def coalesce_tiles(
sample_shape = sample_shape or tuple( # Infer sample shape from tile shapes
sum(
tile.shape[i]
for tile in tiles[tuple(slice(None) if j == i else 0 for j in range(ndim))]
for tile in tiles[tuple(slice(None) if j == i else 0 for j in range(ndim))] # type: ignore
)
for i in range(ndim)
)
Expand Down
21 changes: 21 additions & 0 deletions hub/integrations/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,27 @@ def test_string_tensors(local_ds):
np.testing.assert_array_equal(batch["strings"], f"string{idx}")


@requires_torch
def test_pytorch_large(local_ds):
    """Large (tiled) samples — including zero-sized ones at i == 0 — must
    survive a round trip through the pytorch dataloader."""
    img1_samples = [np.random.randn(1500, 1500, i) for i in range(5)]
    img2_samples = [np.random.randn(400, 1500, 4, i) for i in range(5)]
    labels = list(range(5))

    with local_ds as ds:
        ds.create_tensor("img1")
        ds.create_tensor("img2")
        ds.create_tensor("label")
        ds.img1.extend(img1_samples)
        ds.img2.extend(img2_samples)
        ds.label.extend(labels)

    # Each dataloader batch has batch size 1; compare element 0 of every key.
    ptds = local_ds.pytorch()
    for idx, batch in enumerate(ptds):
        np.testing.assert_array_equal(batch["img1"][0], img1_samples[idx])
        np.testing.assert_array_equal(batch["img2"][0], img2_samples[idx])
        np.testing.assert_array_equal(batch["label"][0], idx)


def run_ddp(rank, size, ds, q, backend="gloo"):
import torch.distributed as dist
import os
Expand Down

0 comments on commit 6cf888b

Please sign in to comment.