Skip to content

Commit 368c785

Browse files
Fix #10593 -- add --keep option for dvc experiments remove (#10633)
* Add keep_selected parameter, and corresponding code to keep only the selected exps (and remove all the other ones)
* test keep_selected_by_name
* test keep_selected_by_rev
* test keep_selected multiple, by name
* test keep all by name
* test keep by rev, with num=2
* added option to cli
* refactoring to meet pr needs
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fixed test_experiments to add keep_selected=False to remove tests
* rename parameter to match cli option
* follow the normal path, then invert the selection before removing
* fixed tests for list ordering + fixed test with non-existent name; it didn't make sense to delete everything if an exp name did not exist
* changed cli option comment
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fixed typing issue
* updated parameter name
* removed handling queued experiments (since --queue would remove them all)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* code simplification, added __eq__ and __hash__ to be able to compare ExpRefs, updated and parametrized tests.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fixed linting issues
* - --keep and --queue together raise an InvalidArgumentError
  - added a test to check if the error is raised
  - fixed CLI message
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* re-run gh tests. Some tests which did not involve my changes started failing while they were passing fine before.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 64ccd9c commit 368c785

File tree

6 files changed

+118
-2
lines changed

6 files changed

+118
-2
lines changed

dvc/commands/experiments/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ def add_parser(subparsers, parent_parser):
5858
hide_subparsers_from_help(experiments_subparsers)
5959

6060

61+
def add_keep_selection_flag(experiments_subcmd_parser):
    """Register the boolean ``--keep`` option on an experiments sub-command parser.

    The option is a plain switch: it is ``False`` unless ``--keep`` is given
    on the command line, in which case it becomes ``True``.
    """
    flag_options = {
        "action": "store_true",
        "default": False,
        "help": "Keep the selected experiments instead of removing them.",
    }
    experiments_subcmd_parser.add_argument("--keep", **flag_options)
68+
69+
6170
def add_rev_selection_flags(
6271
experiments_subcmd_parser, command: str, default: bool = True
6372
):

dvc/commands/experiments/remove.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def run(self):
3434
num=self.args.num,
3535
queue=self.args.queue,
3636
git_remote=self.args.git_remote,
37+
keep=self.args.keep,
3738
)
3839
if removed:
3940
ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}")
@@ -44,7 +45,7 @@ def run(self):
4445

4546

4647
def add_parser(experiments_subparsers, parent_parser):
47-
from . import add_rev_selection_flags
48+
from . import add_keep_selection_flag, add_rev_selection_flags
4849

4950
EXPERIMENTS_REMOVE_HELP = "Remove experiments."
5051
experiments_remove_parser = experiments_subparsers.add_parser(
@@ -57,6 +58,7 @@ def add_parser(experiments_subparsers, parent_parser):
5758
)
5859
remove_group = experiments_remove_parser.add_mutually_exclusive_group()
5960
add_rev_selection_flags(experiments_remove_parser, "Remove", False)
61+
add_keep_selection_flag(experiments_remove_parser)
6062
remove_group.add_argument(
6163
"--queue", action="store_true", help="Remove all queued experiments."
6264
)

dvc/repo/experiments/refs.py

+9
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,12 @@ def from_ref(cls, ref: str):
6767
baseline_sha = parts[2] + parts[3]
6868
name = parts[4] if len(parts) == 5 else None
6969
return cls(baseline_sha, name)
70+
71+
def __eq__(self, other):
    """Compare two experiment refs by baseline SHA and experiment name.

    Returns ``NotImplemented`` for objects that are not ``ExpRefInfo`` so
    that Python can try the reflected comparison on ``other``; an expression
    such as ``ref == "text"`` still evaluates to ``False`` when neither side
    knows how to compare.
    """
    if not isinstance(other, ExpRefInfo):
        return NotImplemented

    # Compare the same tuple that __hash__ hashes, so equal refs hash equally.
    return (self.baseline_sha, self.name) == (other.baseline_sha, other.name)
76+
77+
def __hash__(self):
    """Hash the ref by its ``(baseline_sha, name)`` pair."""
    identity = (self.baseline_sha, self.name)
    return hash(identity)

dvc/repo/experiments/remove.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dvc.repo.scm_context import scm_context
77
from dvc.scm import Git, iter_revs
88

9-
from .exceptions import UnresolvedExpNamesError
9+
from .exceptions import InvalidArgumentError, UnresolvedExpNamesError
1010
from .utils import exp_refs, exp_refs_by_baseline, push_refspec
1111

1212
if TYPE_CHECKING:
@@ -30,10 +30,16 @@ def remove( # noqa: C901, PLR0912
3030
num: int = 1,
3131
queue: bool = False,
3232
git_remote: Optional[str] = None,
33+
keep: bool = False,
3334
) -> list[str]:
3435
removed: list[str] = []
36+
37+
if all([keep, queue]):
38+
raise InvalidArgumentError("Cannot use both `--keep` and `--queue`.")
39+
3540
if not any([exp_names, queue, all_commits, rev]):
3641
return removed
42+
3743
celery_queue: LocalCeleryQueue = repo.experiments.celery_queue
3844

3945
if queue:
@@ -43,6 +49,7 @@ def remove( # noqa: C901, PLR0912
4349

4450
exp_ref_list: list[ExpRefInfo] = []
4551
queue_entry_list: list[QueueEntry] = []
52+
4653
if exp_names:
4754
results: dict[str, ExpRefAndQueueEntry] = (
4855
celery_queue.get_ref_and_entry_by_names(exp_names, git_remote)
@@ -70,6 +77,10 @@ def remove( # noqa: C901, PLR0912
7077
exp_ref_list.extend(exp_refs(repo.scm, git_remote))
7178
removed = [ref.name for ref in exp_ref_list]
7279

80+
if keep:
81+
exp_ref_list = list(set(exp_refs(repo.scm, git_remote)) - set(exp_ref_list))
82+
removed = [ref.name for ref in exp_ref_list]
83+
7384
if exp_ref_list:
7485
_remove_commited_exps(repo.scm, exp_ref_list, git_remote)
7586

@@ -83,6 +94,7 @@ def remove( # noqa: C901, PLR0912
8394

8495
removed_refs = [str(r) for r in exp_ref_list]
8596
notify_refs_to_studio(repo, git_remote, removed=removed_refs)
97+
8698
return removed
8799

88100

tests/func/experiments/test_remove.py

+82
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,85 @@ def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage):
179179

180180
assert scm.get_ref(str(baseline_exp_ref)) is None
181181
assert scm.get_ref(str(new_exp_ref)) is None
182+
183+
184+
@pytest.mark.parametrize(
    "keep, expected_removed",
    [
        [["exp1"], ["exp2", "exp3"]],
        [["exp1", "exp2"], ["exp3"]],
        [["exp1", "exp2", "exp3"], []],
        [[], []],  # remove does nothing if no experiments are specified
    ],
)
def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage, keep, expected_removed):
    """``remove(exp_names=..., keep=True)`` removes only the unnamed experiments."""
    # Create exp1..expN, where N covers both the kept and the removed names.
    total_exps = len(keep) + len(expected_removed)
    refs_by_name = {}
    for index in range(1, total_exps + 1):
        exp_name = f"exp{index}"
        run_result = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={index}"], name=exp_name
        )
        exp_ref = first(exp_refs_by_rev(scm, first(run_result)))
        refs_by_name[exp_name] = exp_ref
        assert scm.get_ref(str(exp_ref)) is not None

    removed = dvc.experiments.remove(exp_names=keep, keep=True)
    # The order of the returned names is not checked, only the contents.
    assert sorted(removed) == sorted(expected_removed)

    for exp_name in expected_removed:
        assert scm.get_ref(str(refs_by_name[exp_name])) is None

    for exp_name in keep:
        assert scm.get_ref(str(refs_by_name[exp_name])) is not None
211+
212+
213+
def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage):
    """An unknown experiment name raises an error even when ``keep=True``."""
    unknown_names = ["nonexistent"]
    with pytest.raises(UnresolvedExpNamesError):
        dvc.experiments.remove(exp_names=unknown_names, keep=True)
217+
218+
219+
@pytest.mark.parametrize(
    "num_exps, rev, num, expected_removed",
    [
        [2, "exp1", 1, ["exp2"]],
        [3, "exp3", 1, ["exp1", "exp2"]],
        [3, "exp3", 2, ["exp1"]],
        [3, "exp3", 3, []],
        [3, "exp2", 2, ["exp3"]],
        [4, "exp2", 2, ["exp3", "exp4"]],
        [4, "exp4", 2, ["exp1", "exp2"]],
        [1, None, 1, []],  # remove does nothing if no experiments are specified
    ],
)
def test_keep_selected_by_rev(
    tmp_dir, scm, dvc, exp_stage, num_exps, rev, num, expected_removed
):
    """``remove(rev=..., num=..., keep=True)`` removes only the unselected ones."""
    exp_ref_by_name = {}
    commit_by_name = {}
    # One fresh commit per experiment, so each experiment has its own revision.
    for index in range(1, num_exps + 1):
        exp_name = f"exp{index}"
        scm.commit(f"commit{index}")
        run_result = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={index}"], name=exp_name
        )
        exp_ref_by_name[exp_name] = first(exp_refs_by_rev(scm, first(run_result)))
        commit_by_name[exp_name] = scm.get_rev()
        assert scm.get_ref(str(exp_ref_by_name[exp_name])) is not None

    # ``rev`` is an experiment name (or None); translate it to its commit.
    selected_rev = commit_by_name.get(rev)
    removed = dvc.experiments.remove(rev=selected_rev, num=num, keep=True)
    assert sorted(removed) == sorted(expected_removed)

    # Removed experiments must be gone ...
    for exp_name in expected_removed:
        assert scm.get_ref(str(exp_ref_by_name[exp_name])) is None

    # ... and every other experiment must still exist.
    for exp_name, exp_ref in exp_ref_by_name.items():
        if exp_name in expected_removed:
            continue
        assert scm.get_ref(str(exp_ref)) is not None
258+
259+
260+
def test_remove_with_queue_and_keep(tmp_dir, scm, dvc, exp_stage):
    """Passing ``queue=True`` together with ``keep=True`` is rejected."""
    # This combination raises for now, until decided otherwise.
    conflicting_options = {"queue": True, "keep": True}
    with pytest.raises(InvalidArgumentError):
        dvc.experiments.remove(**conflicting_options)

tests/unit/command/test_experiments.py

+2
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def test_experiments_remove_flag(dvc, scm, mocker, capsys, caplog):
384384
num=2,
385385
queue=False,
386386
git_remote="myremote",
387+
keep=False,
387388
)
388389

389390

@@ -410,6 +411,7 @@ def test_experiments_remove_special(dvc, scm, mocker, capsys, caplog):
410411
num=1,
411412
queue=False,
412413
git_remote="myremote",
414+
keep=False,
413415
)
414416

415417

0 commit comments

Comments
 (0)