Skip to content

Commit

Permalink
index: data_tree: handle partial imports
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop committed Jan 4, 2024
1 parent bf5fe95 commit f1de712
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 33 deletions.
57 changes: 29 additions & 28 deletions dvc/repo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,33 @@ def _load_storage_from_out(storage_map, key, out):
_load_storage_from_import(storage_map, key, out)


def _build_tree_from_outs(outs):
from dvc_data.hashfile.tree import Tree

tree = Tree()
for out in outs:
if not out.use_cache:
continue

ws, key = out.index_key

if not out.stage.is_partial_import:
tree.add((ws, *key), out.meta, out.hash_info)
continue

dep = out.stage.deps[0]
if not dep.files:
tree.add((ws, *key), dep.meta, dep.hash_info)
continue

for okey, ometa, ohi in dep.get_obj():
tree.add((ws, *key, *okey), ometa, ohi)

tree.digest()

return tree


class Index:
def __init__(
self,
Expand Down Expand Up @@ -504,20 +531,7 @@ def plot_keys(self) -> Dict[str, Set["DataIndexKey"]]:

@cached_property
def data_tree(self):
from dvc_data.hashfile.tree import Tree

tree = Tree()
for out in self.outs:
if not out.use_cache:
continue

ws, key = out.index_key

tree.add((ws, *key), out.meta, out.hash_info)

tree.digest()

return tree
return _build_tree_from_outs(self.outs)

@cached_property
def data(self) -> "Dict[str, DataIndex]":
Expand Down Expand Up @@ -772,20 +786,7 @@ def data_keys(self) -> Dict[str, Set["DataIndexKey"]]:

@cached_property
def data_tree(self):
from dvc_data.hashfile.tree import Tree

tree = Tree()
for out in self.outs:
if not out.use_cache:
continue

ws, key = out.index_key

tree.add((ws, *key), out.meta, out.hash_info)

tree.digest()

return tree
return _build_tree_from_outs(self.outs)

@cached_property
def data(self) -> Dict[str, Union["DataIndex", "DataIndexView"]]:
Expand Down
25 changes: 22 additions & 3 deletions dvc/testing/workspace_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,33 @@ def test_import_dir(self, tmp_dir, dvc, remote_version_aware):
assert (tmp_dir / "data_dir" / "subdir" / "file").read_text() == "modified"
assert (tmp_dir / "data_dir" / "new_file").read_text() == "new"

def test_import_no_download(self, tmp_dir, dvc, remote_version_aware):
def test_import_no_download(self, tmp_dir, dvc, remote_version_aware, scm):
remote_version_aware.gen({"data_dir": {"subdir": {"file": "file"}}})
dvc.imp_url("remote://upstream/data_dir", version_aware=True, no_download=True)
scm.add(["data_dir.dvc", ".gitignore"])
scm.commit("v1")
scm.tag("v1")

stage = first(dvc.index.stages)
assert not stage.outs[0].can_push

dvc.pull()
assert (tmp_dir / "data_dir" / "subdir" / "file").read_text() == "file"
(remote_version_aware / "data_dir" / "foo").write_text("foo")
dvc.update(no_download=True)
assert dvc.pull()["fetched"] == 2
assert (tmp_dir / "data_dir").read_text() == {
"foo": "foo",
"subdir": {"file": "file"},
}
scm.add(["data_dir.dvc", ".gitignore"])
scm.commit("update")

scm.checkout("v1")
dvc.cache.local.clear()
remove(tmp_dir / "data_dir")
assert dvc.pull()["fetched"] == 1
assert (tmp_dir / "data_dir").read_text() == {
"subdir": {"file": "file"},
}

dvc.commit(force=True)
assert dvc.status() == {}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"configobj>=5.0.6",
"distro>=1.3",
"dpath<3,>=2.1.0",
"dvc-data>=3.6,<3.7",
"dvc-data>=3.7,<3.8",
"dvc-http>=2.29.0",
"dvc-render>=1.0.0,<2",
"dvc-studio-client>=0.17.1,<1",
Expand Down
2 changes: 1 addition & 1 deletion tests/func/test_data_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def test_pull_partial_import(tmp_dir, dvc, local_workspace):
stage = dvc.imp_url("remote://workspace/file", os.fspath(dst), no_download=True)

result = dvc.pull("file")
assert result["fetched"] == 0
assert result["fetched"] == 1
assert dst.exists()

assert stage.outs[0].get_hash().value == "d10b4c3ff123b26dc068d43a8bef2d23"
Expand Down

0 comments on commit f1de712

Please sign in to comment.