Skip to content

Commit

Permalink
exp save: add target arg
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum authored and BradyJ27 committed Apr 22, 2024
1 parent dee0ac1 commit 288304e
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 5 deletions.
15 changes: 14 additions & 1 deletion dvc/commands/experiments/save.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse

from dvc.cli import formatter
from dvc.cli import completion, formatter
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import DvcException
Expand All @@ -14,6 +14,7 @@ class CmdExperimentsSave(CmdBase):
def run(self):
try:
ref = self.repo.experiments.save(
targets=self.args.targets,
name=self.args.name,
force=self.args.force,
include_untracked=self.args.include_untracked,
Expand Down Expand Up @@ -41,6 +42,18 @@ def add_parser(experiments_subparsers, parent_parser):
help=EXPERIMENTS_SAVE_HELP,
formatter_class=formatter.RawDescriptionHelpFormatter,
)
save_parser.add_argument(
"targets",
nargs="*",
help="""\
Stages to save. 'dvc.yaml' by default.
The targets can be path to a dvc.yaml file or `.dvc` file,
or a stage name from dvc.yaml file from
current working directory. To save a stage from dvc.yaml
from other directories, the target must be a path followed by colon `:`
and then the stage name name.
""",
).complete = completion.DVCFILES_AND_STAGE
save_parser.add_argument(
"-f",
"--force",
Expand Down
8 changes: 7 additions & 1 deletion dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def _get_top_level_paths(cls, repo: "Repo") -> List["str"]:
def save(
cls,
info: "ExecutorInfo",
targets: Optional[Iterable[str]] = None,
force: bool = False,
include_untracked: Optional[List[str]] = None,
message: Optional[str] = None,
Expand All @@ -293,7 +294,12 @@ def save(
include_untracked.append(LOCK_FILE)

try:
stages = dvc.commit([], force=True, relink=False)
stages = []
if targets:
for target in targets:
stages.append(dvc.commit(target, force=True, relink=False))
else:
stages = dvc.commit([], force=True, relink=False)
exp_hash = cls.hash_exp(stages)
if include_untracked:
dvc.scm.add(include_untracked, force=True) # type: ignore[call-arg]
Expand Down
4 changes: 3 additions & 1 deletion dvc/repo/experiments/save.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Iterable, List, Optional

from funcy import first

Expand All @@ -14,6 +14,7 @@

def save(
repo: "Repo",
targets: Optional[Iterable[str]] = None,
name: Optional[str] = None,
force: bool = False,
include_untracked: Optional[List[str]] = None,
Expand All @@ -32,6 +33,7 @@ def save(
try:
save_result = executor.save(
executor.info,
targets=targets,
force=force,
include_untracked=include_untracked,
message=message,
Expand Down
16 changes: 16 additions & 0 deletions tests/func/experiments/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,19 @@ def test_exp_save_custom_message(tmp_dir, dvc, scm):

exp = dvc.experiments.save(message="custom commit message")
assert scm.gitpython.repo.commit(exp).message == "custom commit message"


def test_exp_save_target(tmp_dir, dvc, scm):
setup_stage(tmp_dir, dvc, scm)
orig_dvclock = (tmp_dir / "dvc.lock").read_text()
(tmp_dir / "bar").write_text("modified")

tmp_dir.dvc_gen({"file": "orig"}, commit="add files")
orig_dvcfile = (tmp_dir / "file.dvc").read_text()
(tmp_dir / "file").write_text("modified")

dvc.experiments.save(["file"])
assert (tmp_dir / "bar").read_text() == "modified"
assert (tmp_dir / "dvc.lock").read_text() == orig_dvclock
assert (tmp_dir / "file").read_text() == "modified"
assert (tmp_dir / "file.dvc").read_text() != orig_dvcfile
10 changes: 8 additions & 2 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def test_experiments_rename_invalid(dvc, scm, mocker, capsys, caplog):


def test_experiments_save(dvc, scm, mocker):
cli_args = parse_args(["exp", "save", "--name", "exp-name", "--force"])
cli_args = parse_args(["exp", "save", "target", "--name", "exp-name", "--force"])
assert cli_args.func == CmdExperimentsSave

cmd = cli_args.func(cli_args)
Expand All @@ -484,7 +484,12 @@ def test_experiments_save(dvc, scm, mocker):
assert cmd.run() == 0

m.assert_called_once_with(
cmd.repo, name="exp-name", force=True, include_untracked=[], message=None
cmd.repo,
targets=["target"],
name="exp-name",
force=True,
include_untracked=[],
message=None,
)


Expand All @@ -500,6 +505,7 @@ def test_experiments_save_message(dvc, scm, mocker, flag):

m.assert_called_once_with(
cmd.repo,
targets=[],
name=None,
force=False,
include_untracked=[],
Expand Down

0 comments on commit 288304e

Please sign in to comment.