diff --git a/dvc/commands/experiments/push.py b/dvc/commands/experiments/push.py index d48a738c7f..a2d5469423 100644 --- a/dvc/commands/experiments/push.py +++ b/dvc/commands/experiments/push.py @@ -51,6 +51,25 @@ def join_exps(exps): def run(self): from dvc.repo.experiments.push import UploadError + if self.args.queued: + result = self.repo.experiments.push( + self.args.git_remote, + queued=True, + force=self.args.force, + ) + if pushed := result.get("queued", []): + from dvc.utils import humanize + + exps = humanize.join([f"[bold]{e}[/]" for e in pushed]) + ui.write( + f"Pushed queued experiment {exps}" + f" to Git remote {self.args.git_remote!r}.", + styled=True, + ) + else: + ui.write("No queued experiments to push.") + return 0 + try: result = self.repo.experiments.push( self.args.git_remote, @@ -93,6 +112,12 @@ def add_parser(experiments_subparsers, parent_parser): formatter_class=formatter.RawDescriptionHelpFormatter, ) add_rev_selection_flags(experiments_push_parser, "Push", True) + experiments_push_parser.add_argument( + "--queued", + action="store_true", + default=False, + help="Push all queued experiments to the Git remote.", + ) experiments_push_parser.add_argument( "-f", "--force", diff --git a/dvc/output.py b/dvc/output.py index 384ba77cd8..f797d58867 100644 --- a/dvc/output.py +++ b/dvc/output.py @@ -1059,6 +1059,8 @@ def transfer( callback=cb, ) + staging.clear() + self.hash_info = obj.hash_info self.files = None return obj diff --git a/dvc/repo/experiments/push.py b/dvc/repo/experiments/push.py index 83614381aa..89144ad83b 100644 --- a/dvc/repo/experiments/push.py +++ b/dvc/repo/experiments/push.py @@ -14,7 +14,7 @@ from dvc.utils.collections import ensure_list from .exceptions import UnresolvedExpNamesError -from .refs import ExpRefInfo +from .refs import CELERY_QUEUE, ExpRefInfo from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name if TYPE_CHECKING: @@ -94,10 +94,15 @@ def push( num: int = 1, force: bool = False, push_cache: bool = False, + queued: bool = False, **kwargs: Any, ) -> dict[str, Any]: - exp_ref_set: set[ExpRefInfo] = set() assert isinstance(repo.scm, Git) + + if queued: + return _push_queued(repo, git_remote, force) + + exp_ref_set: set[ExpRefInfo] = set() if all_commits: exp_ref_set.update(exp_refs(repo.scm)) if exp_names: @@ -182,3 +187,47 @@ def _push_cache( return repo.push( jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, workspace=False ) + + +def _push_queued( + repo: "Repo", + git_remote: str, + force: bool, +) -> dict[str, Any]: + """Push queued experiments to the Git remote as temporary refs.""" + from scmrepo.exceptions import AuthError + + from dvc.scm import GitAuthError + + queued_entries = list(repo.experiments.celery_queue.iter_queued()) + if not queued_entries: + return {"queued": []} + + # Create temporary refs for each queued experiment + temp_refs = [] + for entry in queued_entries: + ref = f"{CELERY_QUEUE}/{entry.stash_rev}" + repo.scm.set_ref(ref, entry.stash_rev) + temp_refs.append(ref) + + refspec_list = [f"{ref}:{ref}" for ref in temp_refs] + logger.debug("git push queued experiments %s -> '%s'", refspec_list, git_remote) + + try: + with TqdmGit(desc="Pushing queued experiments") as pbar: + try: + repo.scm.push_refspecs( + git_remote, + refspec_list, + force=force, + progress=pbar.update_git, + ) + except AuthError as exc: + raise GitAuthError(str(exc)) # noqa: B904 + finally: + # Clean up local temporary refs + for ref in temp_refs: + repo.scm.remove_ref(ref) + + pushed_names = [entry.name or entry.stash_rev[:7] for entry in queued_entries] + return {"queued": pushed_names} diff --git a/dvc/repo/experiments/refs.py b/dvc/repo/experiments/refs.py index 3a34ff35a0..d0c8991b5b 100644 --- a/dvc/repo/experiments/refs.py +++ b/dvc/repo/experiments/refs.py @@ -13,6 +13,7 @@ CELERY_NAMESPACE = f"{EXPS_NAMESPACE}/celery" CELERY_STASH = f"{CELERY_NAMESPACE}/stash" CELERY_FAILED_STASH = f"{CELERY_NAMESPACE}/failed" +CELERY_QUEUE = f"{CELERY_NAMESPACE}/queue" EXEC_NAMESPACE = f"{EXPS_NAMESPACE}/exec" EXEC_APPLY = f"{EXEC_NAMESPACE}/EXEC_APPLY" EXEC_BRANCH = f"{EXEC_NAMESPACE}/EXEC_BRANCH" diff --git a/tests/func/test_add.py b/tests/func/test_add.py index 1bed9628a5..c250be376b 100644 --- a/tests/func/test_add.py +++ b/tests/func/test_add.py @@ -893,6 +893,18 @@ def test_add_with_out(tmp_dir, scm, dvc): assert "/out_foo" in gitignore_content +def test_add_with_out_cleans_up_staging(tmp_dir, dvc, mocker): + """Test that 'dvc add --out' cleans up the staging ODB after transfer.""" + from dvc_objects.db import ObjectDB + + tmp_dir.gen({"foo": "foo"}) + clear_spy = mocker.spy(ObjectDB, "clear") + dvc.add("foo", out="out_foo") + + assert (tmp_dir / "out_foo").read_text() == "foo" + assert clear_spy.call_count >= 1 + + def test_add_to_cache_different_name(tmp_dir, dvc, local_cloud): local_cloud.gen({"data": {"foo": "foo", "bar": "bar"}}) diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 3537d39934..828b70af90 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -316,6 +316,50 @@ def test_experiments_push(dvc, scm, mocker): assert cmd.run() == 0 +def test_experiments_push_queued(dvc, scm, mocker): + cli_args = parse_args( + [ + "experiments", + "push", + "origin", + "--queued", + "--force", + ] + ) + assert cli_args.func == CmdExperimentsPush + + cmd = cli_args.func(cli_args) + m = mocker.patch( + "dvc.repo.experiments.push.push", return_value={"queued": ["exp-1", "exp-2"]} + ) + + assert cmd.run() == 0 + + m.assert_called_once_with( + cmd.repo, + "origin", + queued=True, + force=True, + ) + + +def test_experiments_push_queued_empty(dvc, scm, mocker): + cli_args = parse_args( + [ + "experiments", + "push", + "origin", + "--queued", + ] + ) + assert cli_args.func == CmdExperimentsPush + + cmd = cli_args.func(cli_args) + mocker.patch("dvc.repo.experiments.push.push", return_value={"queued": []}) + + assert cmd.run() == 0 + + def test_experiments_pull(dvc, scm, mocker): cli_args = parse_args( [ diff --git a/tests/unit/repo/experiments/test_push.py b/tests/unit/repo/experiments/test_push.py new file mode 100644 index 0000000000..d5abc25755 --- /dev/null +++ b/tests/unit/repo/experiments/test_push.py @@ -0,0 +1,123 @@ +"""Tests for dvc.repo.experiments.push (queued experiments).""" + +from unittest.mock import MagicMock + +import pytest +from scmrepo.exceptions import AuthError + +from dvc.repo.experiments.push import _push_queued +from dvc.repo.experiments.queue.base import QueueEntry +from dvc.repo.experiments.refs import CELERY_QUEUE + + +def _make_queue_entry(stash_rev, name=None): + return QueueEntry( + dvc_root="/repo", + scm_root="/repo", + stash_ref="refs/exps/celery/stash", + stash_rev=stash_rev, + baseline_rev="abc123", + branch=None, + name=name, + head_rev="def456", + ) + + +class TestPushQueued: + def test_push_queued_empty(self): + """When no experiments are queued, returns empty list.""" + repo = MagicMock() + repo.experiments.celery_queue.iter_queued.return_value = iter([]) + + result = _push_queued(repo, "origin", force=False) + + assert result == {"queued": []} + repo.scm.set_ref.assert_not_called() + repo.scm.push_refspecs.assert_not_called() + + def test_push_queued_with_named_entries(self): + """Named queued experiments are pushed as temp refs and names returned.""" + repo = MagicMock() + entries = [ + _make_queue_entry("aaa111bbb222ccc333", name="exp-1"), + _make_queue_entry("ddd444eee555fff666", name="exp-2"), + ] + repo.experiments.celery_queue.iter_queued.return_value = iter(entries) + repo.scm.push_refspecs.return_value = {} + + result = _push_queued(repo, "origin", force=False) + + assert result == {"queued": ["exp-1", "exp-2"]} + + # Verify temp refs were created + assert repo.scm.set_ref.call_count == 2 + repo.scm.set_ref.assert_any_call( + f"{CELERY_QUEUE}/aaa111bbb222ccc333", "aaa111bbb222ccc333" + ) + repo.scm.set_ref.assert_any_call( + f"{CELERY_QUEUE}/ddd444eee555fff666", "ddd444eee555fff666" + ) + + # Verify push was called with correct refspecs + repo.scm.push_refspecs.assert_called_once() + call_args = repo.scm.push_refspecs.call_args + assert call_args[0][0] == "origin" + refspecs = call_args[0][1] + assert len(refspecs) == 2 + assert call_args[1]["force"] is False + + # Verify temp refs were cleaned up + assert repo.scm.remove_ref.call_count == 2 + + def test_push_queued_unnamed_entries_use_short_rev(self): + """Unnamed experiments fall back to short stash_rev as name.""" + repo = MagicMock() + entries = [_make_queue_entry("aaa111bbb222ccc333", name=None)] + repo.experiments.celery_queue.iter_queued.return_value = iter(entries) + repo.scm.push_refspecs.return_value = {} + + result = _push_queued(repo, "origin", force=False) + + assert result == {"queued": ["aaa111b"]} + + def test_push_queued_force_flag(self): + """Force flag is forwarded to push_refspecs.""" + repo = MagicMock() + entries = [_make_queue_entry("aaa111bbb222ccc333", name="exp-1")] + repo.experiments.celery_queue.iter_queued.return_value = iter(entries) + repo.scm.push_refspecs.return_value = {} + + _push_queued(repo, "origin", force=True) + + call_args = repo.scm.push_refspecs.call_args + assert call_args[1]["force"] is True + + def test_push_queued_cleans_up_refs_on_error(self): + """Temp refs are cleaned up even if push_refspecs raises.""" + repo = MagicMock() + entries = [_make_queue_entry("aaa111bbb222ccc333", name="exp-1")] + repo.experiments.celery_queue.iter_queued.return_value = iter(entries) + repo.scm.push_refspecs.side_effect = Exception("network error") + + with pytest.raises(Exception, match="network error"): + _push_queued(repo, "origin", force=False) + + # Temp refs should still be cleaned up + repo.scm.remove_ref.assert_called_once_with( + f"{CELERY_QUEUE}/aaa111bbb222ccc333" + ) + + def test_push_queued_auth_error(self): + """AuthError is wrapped in GitAuthError.""" + from dvc.scm import GitAuthError + + repo = MagicMock() + entries = [_make_queue_entry("aaa111bbb222ccc333", name="exp-1")] + repo.experiments.celery_queue.iter_queued.return_value = iter(entries) + repo.scm.push_refspecs.side_effect = AuthError("bad credentials") + + with pytest.raises(GitAuthError): + _push_queued(repo, "origin", force=False) + + # Temp refs should still be cleaned up + repo.scm.remove_ref.assert_called_once()