[AL - 1596] Dataset view saving #1380
@@ -11,9 +11,10 @@
 from hub.api.info import load_info
 from hub.client.log import logger
 from hub.constants import FIRST_COMMIT_ID
+from hub.constants import DEFAULT_MEMORY_CACHE_SIZE, DEFAULT_LOCAL_CACHE_SIZE, MB
 from hub.core.fast_forwarding import ffw_dataset_meta
 from hub.core.index import Index
-from hub.core.lock import lock_version, unlock_version
+from hub.core.lock import lock_version, unlock_version, Lock
 from hub.core.meta.dataset_meta import DatasetMeta
 from hub.core.storage import LRUCache, S3Provider, MemoryProvider, GCSProvider
 from hub.core.tensor import (
@@ -26,6 +27,7 @@
 from hub.integrations import dataset_to_tensorflow
 from hub.util.bugout_reporter import hub_reporter
 from hub.util.dataset import try_flushing
+from hub.util.cache_chain import generate_chain
 from hub.util.exceptions import (
     CouldNotCreateNewDatasetException,
     InvalidKeyTypeError,
@@ -41,13 +43,17 @@
     InvalidTensorGroupNameError,
     LockedException,
     TensorGroupAlreadyExistsError,
+    ReadOnlyModeError,
+    NotLoggedInError,
 )
 from hub.util.keys import (
     dataset_exists,
     get_dataset_info_key,
     get_dataset_meta_key,
     get_version_control_info_key,
     tensor_exists,
+    get_queries_key,
+    get_queries_lock_key,
 )
 from hub.util.path import get_path_from_storage
 from hub.util.remove_cache import get_base_storage
@@ -65,7 +71,13 @@
     load_meta,
     warn_node_checkout,
 )
+from hub.client.utils import get_user_name

+from tqdm import tqdm  # type: ignore
+from time import time
+import hashlib
+import json
+from collections import defaultdict

@@ -115,11 +127,11 @@ def __init__(
         d["path"] = path or get_path_from_storage(storage)
         d["storage"] = storage
         d["_read_only"] = read_only
-        d["_locked_out"] = False
+        d["_locked_out"] = False  # User requested write access but was denied
         d["is_iteration"] = is_iteration
         d["is_first_load"] = is_first_load = version_state is None
-        self.__dict__.update(d)
-        d.clear()
+        # self.__dict__.update(d)
+        # d.clear()
         d["index"] = index or Index()
         d["group_index"] = group_index
         d["_token"] = token
@@ -565,10 +577,18 @@ def commit(self, message: Optional[str] = None) -> str:

         Returns:
             str: the commit id of the stored commit that can be used to access the snapshot.

+        Raises:
+            Exception: if dataset is a filtered view.
         """
+        return self._commit(message)
+
+    def _commit(self, message: Optional[str] = None, hash: Optional[str] = None) -> str:
+        if getattr(self, "_is_filtered_view", False):
+            raise Exception(
+                "Cannot perform version control operations on a filtered dataset view."
+            )
+
         try_flushing(self)

         self._initial_autoflush.append(self.storage.autoflush)
@@ -595,14 +615,21 @@ def checkout(self, address: str, create: bool = False) -> Optional[str]:
             create (bool): If True, creates a new branch with name as address.

         Returns:
-            str, optional: The commit_id of the branch/commit that was checked out.
-                If there are no commits present after checking out, returns the commit_id before the branch, if there are no commits, returns None.
+            str: The commit_id of the dataset after checkout.

+        Raises:
+            Exception: if dataset is a filtered view.
         """
+        return self._checkout(address, create)
+
+    def _checkout(
+        self, address: str, create: bool = False, hash: Optional[str] = None
+    ) -> Optional[str]:
+        if getattr(self, "_is_filtered_view", False):
+            raise Exception(
+                "Cannot perform version control operations on a filtered dataset view."
+            )
+
         try_flushing(self)

         self._initial_autoflush.append(self.storage.autoflush)
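A hedged illustration of the new guard (assuming, as elsewhere in this PR, that `filter` marks the views it returns; dataset path and tensor name are hypothetical):

```python
import hub

ds = hub.dataset("./my_dataset")   # hypothetical path
view = ds.filter("labels == 0")    # hypothetical tensor; view is marked _is_filtered_view

try:
    view.commit("snapshot")        # delegates to _commit, which raises for filtered views
except Exception as e:
    print(e)  # Cannot perform version control operations on a filtered dataset view.
```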
@@ -882,18 +909,24 @@ def filter(
         num_workers: int = 0,
         scheduler: str = "threaded",
         progressbar: bool = True,
+        store_result: bool = False,
+        result_path: Optional[str] = None,
+        result_ds_args: Optional[dict] = None,
     ):
         """Filters the dataset in accordance with the filter function `f(x: sample) -> bool`

         Args:
-            function(Callable | str): filter function that takes sample as argument and returns True/False
+            function(Callable | str): Filter function that takes a sample as argument and returns True/False
                 if the sample should be included in the result. Also supports simplified expression evaluations.
                 See hub.core.query.DatasetQuery for more details.
-            num_workers(int): level of parallelization of filter evaluations.
+            num_workers(int): Level of parallelization of filter evaluations.
                 `0` indicates in-place for-loop evaluation, multiprocessing is used otherwise.
-            scheduler(str): scheduler to use for multiprocessing evaluation.
+            scheduler(str): Scheduler to use for multiprocessing evaluation.
                 `threaded` is the default.
-            progressbar(bool): display progress bar while filtering. True is default
+            progressbar(bool): Display progress bar while filtering. True is the default.
+            store_result (bool): If True, the result of the filter will be saved to a dataset asynchronously.
+            result_path (Optional, str): Path to save the filter result. Only applicable if `store_result` is True.
+            result_ds_args (Optional, dict): Additional args for the result dataset. Only applicable if `store_result` is True.

         Returns:
             View on Dataset with elements that satisfy the filter function
@@ -919,6 +952,9 @@ def filter(
             num_workers=num_workers,
             scheduler=scheduler,
             progressbar=progressbar,
+            store_result=store_result,
+            result_path=result_path,
+            result_ds_args=result_ds_args,
         )

     def _get_total_meta(self):
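For illustration, a hedged usage sketch of the extended `filter` API (the dataset path and tensor name below are hypothetical):

```python
import hub

ds = hub.dataset("./my_dataset")  # hypothetical local dataset with a "labels" tensor

view = ds.filter(
    "labels == 0",            # simplified query expression (see hub.core.query.DatasetQuery)
    num_workers=0,            # in-place for-loop evaluation
    progressbar=True,
    store_result=True,        # persist the result as a view, asynchronously
    result_path="./my_view",  # only used because store_result=True
)
```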
@@ -1261,6 +1297,9 @@ def __result__(self):
     def __args__(self):
         return None

+    def __bool__(self):
+        return True
+
     def append(self, sample: Dict[str, Any], skip_ok: bool = False):
         if not skip_ok:
             for k in self.tensors:
@@ -1303,3 +1342,187 @@ def append(self, sample: Dict[str, Any], skip_ok: bool = False):
                     "Error while attempting to rollback appends"
                 ) from e2
             raise e
+
+    def _view_hash(self):
+        """Returns a deterministic id for this view, derived from the dataset path,
+        the index, the commit id and the query (if any)."""
+        return hashlib.sha1(
+            (
+                f"{self.path}[{':'.join(str(e.value) for e in self.index.values)}]@{self.version_state['commit_id']}&{getattr(self, '_query', None)}"
+            ).encode()
+        ).hexdigest()
+
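The hash key concatenates the dataset path, the stringified index entries, the commit id, and the query, so a given view always maps to the same id. An illustrative computation (all values made up):

```python
import hashlib

path = "hub://org/ds"                                # hypothetical
index = "slice(0, 100, None)"                        # str(e.value) of one index entry
commit = "firstdbf9474d461a19e9333c2fd19b46115348f"  # hypothetical commit id
query = "labels == 0"                                # hypothetical query

key = f"{path}[{index}]@{commit}&{query}"
print(hashlib.sha1(key.encode()).hexdigest())        # deterministic view id
```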
+    def _get_view_info(self):
+        if not hasattr(self, "_view_info"):
+            tm = getattr(self, "_created_at", time())
+            hash = self._view_hash()
+            info = {
+                "id": hash,
+                "description": "Virtual Datasource",
+                "virtual-datasource": True,
+                "source-dataset": self.path,
+                "source-dataset-version": self.version_state["commit_id"],
+                "created_at": tm,
+            }
+
+            query = getattr(self, "_query", None)
+            if query:
+                info["query"] = query
+                info["source-dataset-index"] = getattr(self, "_source_ds_idx", None)
+            self._view_info = info
+        return self._view_info
+
+    @staticmethod
+    def _write_queries_json(ds, info):
+        """Appends `info` to the queries.json registry of `ds`. A lock guards the
+        read-modify-write so concurrent writers do not clobber each other."""
+        base_storage = get_base_storage(ds.storage)
+        lock = Lock(base_storage, get_queries_lock_key())
+        lock.acquire(timeout=10, force=True)
+        queries_key = get_queries_key()
+        try:
+            try:
+                queries = json.loads(base_storage[queries_key].decode("utf-8"))
+            except KeyError:
+                queries = []
+            queries.append(info)
+            base_storage[queries_key] = json.dumps(queries).encode("utf-8")
+        finally:
+            lock.release()
+
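For reference, a hedged sketch of reading the registry back; the on-disk file path is an assumption for a local dataset, while in the code above the key comes from `get_queries_key()` and the bytes live in the base storage:

```python
import json

# Hypothetical on-disk location of the registry for a local dataset.
with open("./my_dataset/queries.json") as f:
    queries = json.load(f)

for entry in queries:
    # Each entry is an info dict built by _get_view_info above.
    print(entry["id"], entry["source-dataset-version"], entry.get("query"))
```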
+    def _write_vds(self, vds):
+        """Writes the indices of this view to a vds."""
+        info = self._get_view_info()
+        with vds:
+            vds.info.update(info)
+            vds.create_tensor("VDS_INDEX", dtype="uint64").extend(
+                list(self.index.values[0].indices(len(self)))
+            )
+        # Sanity check: the indices read back must match what was written.
+        idxs = hub.dataset(vds.path)._vds.VDS_INDEX.numpy().reshape(-1).tolist()
+        exp = list(self.index.values[0].indices(len(self)))
+        assert idxs == exp, (idxs, exp, vds.path)
+
+    def _store_view_in_subdir(self):
+        """Stores this view under the ".queries" subdirectory of the same storage
+        (Type 1 VDS)."""
+        info = self._get_view_info()
+        hash = info["id"]
+        path = f".queries/{hash}"
+        self.flush()
+        get_base_storage(self.storage).subdir(path).clear()
+        vds = self._sub_ds(path, empty=True)
+        self._write_vds(vds)
+        Dataset._write_queries_json(self, info)
+        return vds
+
+    def _store_view_in_user_queries_dataset(self):
+        """Stores this view under hub://username/queries (Type 2 VDS).
+        Only applicable for views of hub datasets.
+        """
+        if len(self.index.values) > 1:
+            raise NotImplementedError("Storing sub-sample slices is not supported yet.")
+        username = get_user_name()
+        if username == "public":
+            raise NotLoggedInError("Unable to save query result. Not logged in.")
+
+        info = self._get_view_info()
+        hash = info["id"]
+
+        queries_ds = hub.dataset(f"hub://{username}/queries")  # created if it doesn't exist
+
+        path = f"hub://{username}/queries/{hash}"
+
+        vds = hub.empty(path, overwrite=True)
+
+        self._write_vds(vds)
+
+        Dataset._write_queries_json(queries_ds, info)
+
+        return vds
+
+    def _store_view_in_path(self, path, **ds_args):
+        """Stores this view at a given dataset path (Type 3 VDS)."""
+        vds = hub.dataset(path, **ds_args)
+        self._write_vds(vds)
+        return vds
+
+    def store(self, path: Optional[str] = None, _ret_ds: bool = False, **ds_args):
+        """Stores this view as a virtual dataset (VDS).
+
+        If `path` is given, the VDS is written there; otherwise it goes to the
+        source dataset's ".queries" subdirectory, or, for read-only hub cloud
+        datasets, to the user's queries dataset.
+        """
+        if len(self.index.values) > 1:
+            raise NotImplementedError("Storing sub-sample slices is not supported yet.")
+
+        if path is None and hasattr(self, "_vds"):
+            vds = self._vds
+        elif path is None:
+            if isinstance(get_base_storage(self.storage), MemoryProvider):
+                raise NotImplementedError(
+                    "Saving views inplace is not supported for in-memory datasets."
+                )
+            if self.read_only:
+                if isinstance(self, hub.core.dataset.HubCloudDataset):
+                    vds = self._store_view_in_user_queries_dataset()
+                else:
+                    raise ReadOnlyModeError(
+                        "Cannot save view in read only dataset. Specify a path to store the view in a different location."
+                    )
+            else:
+                vds = self._store_view_in_subdir()
+        else:
+            vds = self._store_view_in_path(path, **ds_args)
+        if _ret_ds:
+            return vds
+        return vds.path
+
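A hedged end-to-end sketch of the storage paths `store` dispatches to (paths hypothetical):

```python
import hub

ds = hub.dataset("./my_dataset")  # hypothetical writable, non-memory dataset
view = ds[10:20]

# Writable dataset, no path: saved under "./my_dataset/.queries/<hash>".
print(view.store())

# Explicit path: saved as a standalone VDS at that location.
print(view.store("./my_view"))

# Read-only hub cloud dataset, no path: would go to hub://<username>/queries/<hash>.
```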
+    def _get_view(self):
+        # Only applicable for virtual datasets
+        ds = hub.dataset(path=self.info["source-dataset"], verbose=False)
+        ds = ds[self.VDS_INDEX.numpy().reshape(-1).tolist()]
+        ds._vds = self
+        return ds
+
+    def _get_empty_vds(self, vds_path=None, query=None, **vds_args):
+        """Stores an empty view of this dataset as a VDS, optionally tagged with a query."""
+        view = self[:0]
+        if query:
+            view._query = query
+        return view.store(vds_path, _ret_ds=True, **vds_args)
+
+    def _get_query_history(self) -> List[dict]:
+        """
+        Internal. Returns the contents of queries.json: a list of view info dicts
+        whose "id" entries can be passed to Dataset._get_stored_vds to get a dataset view.
+        """
+        try:
+            queries = json.loads(self.storage[get_queries_key()].decode("utf-8"))
+            return queries
+        except KeyError:
+            return []
+
+    def _sub_ds(self, path, empty=False):
+        """Loads a nested dataset. Internal.
+
+        Note: Virtual datasets are returned as such, they are not converted to views.
+
+        Args:
+            path (str): Path to the sub directory, relative to this dataset's storage.
+            empty (bool): If True, all contents of the sub directory are cleared before initializing the sub dataset.
+
+        Returns:
+            Sub dataset
+        """
+        base_storage = get_base_storage(self.storage)
+        sub_storage = base_storage.subdir(path)
+        if empty:
+            sub_storage.clear()
+
+        if self.path.startswith("hub://"):
+            path = posixpath.join(self.path, path)
+            cls = hub.core.dataset.HubCloudDataset
+        else:
+            path = sub_storage.root
+            cls = hub.core.dataset.Dataset
+
+        return cls(
+            generate_chain(
+                sub_storage,
+                DEFAULT_MEMORY_CACHE_SIZE * MB,
+                DEFAULT_LOCAL_CACHE_SIZE * MB,
+            ),
+            path=path,
+            token=self._token,
+        )
+
+    def _get_stored_vds(self, hash: str):
+        """Internal. Loads the VDS saved under the ".queries" subdirectory for the given hash."""
+        return self._sub_ds(".queries/" + hash)
Review note: We do not expose the VDS (the dataset of indices) to the user; instead, a view on the source dataset is returned.
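A hedged sketch of that round trip, assuming loading a VDS path resolves it through `_get_view` as the note above describes (paths hypothetical):

```python
import hub

src = hub.dataset("./my_dataset")          # hypothetical source dataset
vds_path = src[10:20].store("./my_view")   # persist the view as a VDS

# Per the note above, loading the VDS path yields a view on the source
# dataset (with ._vds pointing at the stored VDS), not the raw index dataset.
view = hub.dataset(vds_path)
assert len(view) == 10
```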