diff --git a/hub/integrations/pytorch/dataset.py b/hub/integrations/pytorch/dataset.py index a8fb0f7368..a1db52931a 100644 --- a/hub/integrations/pytorch/dataset.py +++ b/hub/integrations/pytorch/dataset.py @@ -120,9 +120,7 @@ def _worker_loop( data = next(it) data = _process(data, transform) - data = { - k: torch.as_tensor(v).share_memory_() for k, v in data.items() - } + data = {k: torch.as_tensor(v) for k, v in data.items()} data_queue.put((wid, data)) requested -= 1 @@ -471,10 +469,7 @@ def __iter__(self): for i in range(len(next_batch[batch_keys[0]])): val = IterableOrderedDict( - { - k: next_batch[k][i].clone().detach().share_memory_() - for k in batch_keys - } + {k: next_batch[k][i].clone().detach() for k in batch_keys} ) if buffer is not None: