Skip to content

Commit

Permalink
Fix to ensure agreement isn't reloaded for each slice of dataset (#1347)
Browse files Browse the repository at this point in the history
* fix to ensure agreement isn't reloaded for each slice

* fix

* improved first load behaviour

* added another test
  • Loading branch information
AbhinavTuli authored Nov 22, 2021
1 parent 47a79f6 commit 960945f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
7 changes: 5 additions & 2 deletions hub/api/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import pytest
from hub.util.exceptions import MemoryDatasetCanNotBePickledError
import pickle
from hub.tests.dataset_fixtures import enabled_datasets


@enabled_datasets
@pytest.mark.parametrize(
"ds",
["memory_ds", "local_ds", "s3_ds", "gcs_ds", "hub_cloud_ds"],
indirect=True,
)
def test_dataset(ds):
if ds.path.startswith("mem://"):
with pytest.raises(MemoryDatasetCanNotBePickledError):
Expand Down
2 changes: 2 additions & 0 deletions hub/core/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
self._token = token
self.public = public
self.verbose = verbose
self.is_first_load = version_state is None
self.version_state: Dict[str, Any] = version_state or {}
self._info = None
self._set_derived_attributes()
Expand Down Expand Up @@ -172,6 +173,7 @@ def __setstate__(self, state: Dict[str, Any]):
state (dict): The pickled state used to restore the dataset.
"""
self.__dict__.update(state)
self.is_first_load = True
self._info = None
self._set_derived_attributes()

Expand Down
34 changes: 25 additions & 9 deletions hub/core/dataset/hub_cloud_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, Optional
from hub.constants import AGREEMENT_FILENAME, HUB_CLOUD_DEV_USERNAME
from hub.core.dataset import Dataset
from hub.client.client import HubBackendClient
Expand All @@ -16,14 +16,19 @@ def __init__(self, path, *args, **kwargs):
self._set_org_and_name()

super().__init__(*args, **kwargs)

if self.is_actually_cloud:
handle_dataset_agreement(self.agreement, path, self.ds_name, self.org_id)
else:
# NOTE: this can happen if you override `hub.core.dataset.FORCE_CLASS`
warn(
f'Created a hub cloud dataset @ "{self.path}" which does not have the "hub://" prefix. Note: this dataset should only be used for testing!'
)
self.first_load_init()

def first_load_init(self):
    """Run the checks that should happen only on the first load of the dataset.

    Does nothing when ``self.is_first_load`` is falsy. Otherwise, a dataset
    flagged by ``self.is_actually_cloud`` is passed to
    ``handle_dataset_agreement`` together with its agreement, path, name and
    organization id; any other dataset only triggers a warning that it lacks
    the "hub://" prefix.
    """
    if not self.is_first_load:
        # Already initialized once; skip the agreement handling and the warning.
        return

    if not self.is_actually_cloud:
        # NOTE: this can happen if you override `hub.core.dataset.FORCE_CLASS`
        warn(
            f'Created a hub cloud dataset @ "{self.path}" which does not have the "hub://" prefix. Note: this dataset should only be used for testing!'
        )
        return

    handle_dataset_agreement(self.agreement, self.path, self.ds_name, self.org_id)

@property
def client(self):
Expand Down Expand Up @@ -94,3 +99,14 @@ def agreement(self) -> Optional[str]:
def add_agreeement(self, agreement: str):
    """Store ``agreement`` in the dataset's storage under ``AGREEMENT_FILENAME``.

    Args:
        agreement (str): Agreement text; it is written UTF-8 encoded.
    """
    # Called before writing; presumably raises when the storage is
    # read-only (confirm against the storage provider).
    self.storage.check_readonly()
    encoded_agreement = agreement.encode("utf-8")
    self.storage[AGREEMENT_FILENAME] = encoded_agreement

def __getstate__(self) -> Dict[str, Any]:
    """Return the state used for pickling, extended with cloud identifiers.

    Starts from the parent class's state and adds ``org_id`` and ``ds_name``.
    ``first_load_init`` reads both attributes, and ``__setstate__`` calls it
    after restoring the state.

    Returns:
        Dict[str, Any]: The parent state plus the ``"org_id"`` and
        ``"ds_name"`` entries.
    """
    pickled_state = super().__getstate__()
    pickled_state.update(org_id=self.org_id, ds_name=self.ds_name)
    return pickled_state

def __setstate__(self, state: Dict[str, Any]):
    """Restore the dataset from a pickled state and redo the first-load checks.

    Args:
        state (dict): The pickled state used to restore the dataset.
    """
    # The parent restore has to run first so that the attributes read by
    # `first_load_init` exist on the instance again.
    super().__setstate__(state)
    # Reset the client reference; presumably the `client` property
    # recreates it on demand (confirm against the property definition).
    self._client = None
    self.first_load_init()

0 comments on commit 960945f

Please sign in to comment.