[tiny] tqdm progress bar for pytorch (#1243)
verbose-void authored Oct 11, 2021
1 parent c9ef99b commit 5b49c96
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions hub/core/dataset.py
@@ -1,4 +1,5 @@
 import hub
+from tqdm import tqdm  # type: ignore
 import pickle
 import warnings
 import posixpath
@@ -476,13 +477,14 @@ def pytorch(
         transform: Optional[Callable] = None,
         tensors: Optional[Sequence[str]] = None,
         num_workers: int = 1,
-        batch_size: Optional[int] = 1,
+        batch_size: int = 1,
         drop_last: bool = False,
         collate_fn: Optional[Callable] = None,
         pin_memory: bool = False,
         shuffle: bool = False,
         buffer_size: int = 10 * 1000,
         use_local_cache: bool = False,
+        use_progress_bar: bool = False,
     ):
         """Converts the dataset into a pytorch Dataloader.
@@ -494,7 +496,7 @@ def pytorch(
             transform (Callable, optional) : Transformation function to be applied to each sample.
             tensors (List, optional): Optionally provide a list of tensor names in the ordering that your training script expects. For example, if you have a dataset that has "image" and "label" tensors, if `tensors=["image", "label"]`, your training script should expect each batch will be provided as a tuple of (image, label).
             num_workers (int): The number of workers to use for fetching data in parallel.
-            batch_size (int, optional): Number of samples per batch to load. Default value is 1.
+            batch_size (int): Number of samples per batch to load. Default value is 1.
             drop_last (bool): Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
                 If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. Default value is False.
                 Read torch.utils.data.DataLoader docs for more details.
@@ -505,13 +507,14 @@ def pytorch(
             shuffle (bool): If True, the data loader will shuffle the data indices. Default value is False.
             buffer_size (int): The size of the buffer used to prefetch/shuffle in MB. The buffer uses shared memory under the hood. Default value is 10 GB. Increasing the buffer_size will increase the extent of shuffling.
             use_local_cache (bool): If True, the data loader will use a local cache to store data. This is useful when the dataset can fit on the machine and we don't want to fetch the data multiple times for each iteration. Default value is False.
+            use_progress_bar (bool): If True, tqdm will be wrapped around the returned dataloader. Default value is False.

         Returns:
             A torch.utils.data.DataLoader object.
         """
         from hub.integrations import dataset_to_pytorch

-        return dataset_to_pytorch(
+        dataloader = dataset_to_pytorch(
             self,
             transform,
             tensors,
@@ -525,6 +528,11 @@
             use_local_cache=use_local_cache,
         )

+        if use_progress_bar:
+            dataloader = tqdm(dataloader, desc=self.path, total=len(self) // batch_size)
+
+        return dataloader
+
     def _get_total_meta(self):
         """Returns tensor metas all together"""
         return {
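For context, a minimal usage sketch of the new flag. The dataset path and tensor names below are hypothetical, and unpacking each batch as a tuple assumes tensors are passed in the order described in the docstring above:

    import hub

    # Hypothetical dataset path, for illustration only.
    ds = hub.load("hub://activeloop/mnist-train")

    # With use_progress_bar=True the returned DataLoader is wrapped in tqdm,
    # so iterating shows progress over roughly len(ds) // batch_size batches.
    dataloader = ds.pytorch(
        tensors=["images", "labels"],  # hypothetical tensor names
        batch_size=32,
        use_progress_bar=True,
    )

    for images, labels in dataloader:
        pass  # training step goes here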
