Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions dvc/commands/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,25 @@ def join_exps(exps):
def run(self):
from dvc.repo.experiments.push import UploadError

if self.args.queued:
result = self.repo.experiments.push(
self.args.git_remote,
queued=True,
force=self.args.force,
)
if pushed := result.get("queued", []):
from dvc.utils import humanize

exps = humanize.join([f"[bold]{e}[/]" for e in pushed])
ui.write(
f"Pushed queued experiment {exps}"
f" to Git remote {self.args.git_remote!r}.",
styled=True,
)
else:
ui.write("No queued experiments to push.")
return 0

try:
result = self.repo.experiments.push(
self.args.git_remote,
Expand Down Expand Up @@ -93,6 +112,12 @@ def add_parser(experiments_subparsers, parent_parser):
formatter_class=formatter.RawDescriptionHelpFormatter,
)
add_rev_selection_flags(experiments_push_parser, "Push", True)
experiments_push_parser.add_argument(
"--queued",
action="store_true",
default=False,
help="Push all queued experiments to the Git remote.",
)
experiments_push_parser.add_argument(
"-f",
"--force",
Expand Down
2 changes: 2 additions & 0 deletions dvc/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,8 @@ def transfer(
callback=cb,
)

staging.clear()

self.hash_info = obj.hash_info
self.files = None
return obj
Expand Down
53 changes: 51 additions & 2 deletions dvc/repo/experiments/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dvc.utils.collections import ensure_list

from .exceptions import UnresolvedExpNamesError
from .refs import ExpRefInfo
from .refs import CELERY_QUEUE, ExpRefInfo
from .utils import exp_commits, exp_refs, exp_refs_by_baseline, resolve_name

if TYPE_CHECKING:
Expand Down Expand Up @@ -94,10 +94,15 @@ def push(
num: int = 1,
force: bool = False,
push_cache: bool = False,
queued: bool = False,
**kwargs: Any,
) -> dict[str, Any]:
exp_ref_set: set[ExpRefInfo] = set()
assert isinstance(repo.scm, Git)

if queued:
return _push_queued(repo, git_remote, force)

exp_ref_set: set[ExpRefInfo] = set()
if all_commits:
exp_ref_set.update(exp_refs(repo.scm))
if exp_names:
Expand Down Expand Up @@ -182,3 +187,47 @@ def _push_cache(
return repo.push(
jobs=jobs, remote=dvc_remote, run_cache=run_cache, revs=revs, workspace=False
)


def _push_queued(
repo: "Repo",
git_remote: str,
force: bool,
) -> dict[str, Any]:
"""Push queued experiments to the Git remote as temporary refs."""
from scmrepo.exceptions import AuthError

from dvc.scm import GitAuthError

queued_entries = list(repo.experiments.celery_queue.iter_queued())
if not queued_entries:
return {"queued": []}

# Create temporary refs for each queued experiment
temp_refs = []
for entry in queued_entries:
ref = f"{CELERY_QUEUE}/{entry.stash_rev}"
repo.scm.set_ref(ref, entry.stash_rev)
temp_refs.append(ref)

refspec_list = [f"{ref}:{ref}" for ref in temp_refs]
logger.debug("git push queued experiments %s -> '%s'", refspec_list, git_remote)

try:
with TqdmGit(desc="Pushing queued experiments") as pbar:
try:
repo.scm.push_refspecs(
git_remote,
refspec_list,
force=force,
progress=pbar.update_git,
)
except AuthError as exc:
raise GitAuthError(str(exc)) # noqa: B904
finally:
# Clean up local temporary refs
for ref in temp_refs:
repo.scm.remove_ref(ref)

pushed_names = [entry.name or entry.stash_rev[:7] for entry in queued_entries]
return {"queued": pushed_names}
1 change: 1 addition & 0 deletions dvc/repo/experiments/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
CELERY_NAMESPACE = f"{EXPS_NAMESPACE}/celery"
CELERY_STASH = f"{CELERY_NAMESPACE}/stash"
CELERY_FAILED_STASH = f"{CELERY_NAMESPACE}/failed"
CELERY_QUEUE = f"{CELERY_NAMESPACE}/queue"
EXEC_NAMESPACE = f"{EXPS_NAMESPACE}/exec"
EXEC_APPLY = f"{EXEC_NAMESPACE}/EXEC_APPLY"
EXEC_BRANCH = f"{EXEC_NAMESPACE}/EXEC_BRANCH"
Expand Down
12 changes: 12 additions & 0 deletions tests/func/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,18 @@ def test_add_with_out(tmp_dir, scm, dvc):
assert "/out_foo" in gitignore_content


def test_add_with_out_cleans_up_staging(tmp_dir, dvc, mocker):
"""Test that 'dvc add --out' cleans up the staging ODB after transfer."""
from dvc_objects.db import ObjectDB

tmp_dir.gen({"foo": "foo"})
clear_spy = mocker.spy(ObjectDB, "clear")
dvc.add("foo", out="out_foo")

assert (tmp_dir / "out_foo").read_text() == "foo"
assert clear_spy.call_count >= 1


def test_add_to_cache_different_name(tmp_dir, dvc, local_cloud):
local_cloud.gen({"data": {"foo": "foo", "bar": "bar"}})

Expand Down
44 changes: 44 additions & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,50 @@ def test_experiments_push(dvc, scm, mocker):
assert cmd.run() == 0


def test_experiments_push_queued(dvc, scm, mocker):
cli_args = parse_args(
[
"experiments",
"push",
"origin",
"--queued",
"--force",
]
)
assert cli_args.func == CmdExperimentsPush

cmd = cli_args.func(cli_args)
m = mocker.patch(
"dvc.repo.experiments.push.push", return_value={"queued": ["exp-1", "exp-2"]}
)

assert cmd.run() == 0

m.assert_called_once_with(
cmd.repo,
"origin",
queued=True,
force=True,
)


def test_experiments_push_queued_empty(dvc, scm, mocker):
cli_args = parse_args(
[
"experiments",
"push",
"origin",
"--queued",
]
)
assert cli_args.func == CmdExperimentsPush

cmd = cli_args.func(cli_args)
mocker.patch("dvc.repo.experiments.push.push", return_value={"queued": []})

assert cmd.run() == 0


def test_experiments_pull(dvc, scm, mocker):
cli_args = parse_args(
[
Expand Down
123 changes: 123 additions & 0 deletions tests/unit/repo/experiments/test_push.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Tests for dvc.repo.experiments.push (queued experiments)."""

from unittest.mock import MagicMock

import pytest
from scmrepo.exceptions import AuthError

from dvc.repo.experiments.push import _push_queued
from dvc.repo.experiments.queue.base import QueueEntry
from dvc.repo.experiments.refs import CELERY_QUEUE


def _make_queue_entry(stash_rev, name=None):
return QueueEntry(
dvc_root="/repo",
scm_root="/repo",
stash_ref="refs/exps/celery/stash",
stash_rev=stash_rev,
baseline_rev="abc123",
branch=None,
name=name,
head_rev="def456",
)


class TestPushQueued:
def test_push_queued_empty(self):
"""When no experiments are queued, returns empty list."""
repo = MagicMock()
repo.experiments.celery_queue.iter_queued.return_value = iter([])

result = _push_queued(repo, "origin", force=False)

assert result == {"queued": []}
repo.scm.set_ref.assert_not_called()
repo.scm.push_refspecs.assert_not_called()

def test_push_queued_with_named_entries(self):
"""Named queued experiments are pushed as temp refs and names returned."""
repo = MagicMock()
entries = [
_make_queue_entry("aaa111bbb222ccc333", name="exp-1"),
_make_queue_entry("ddd444eee555fff666", name="exp-2"),
]
repo.experiments.celery_queue.iter_queued.return_value = iter(entries)
repo.scm.push_refspecs.return_value = {}

result = _push_queued(repo, "origin", force=False)

assert result == {"queued": ["exp-1", "exp-2"]}

# Verify temp refs were created
assert repo.scm.set_ref.call_count == 2
repo.scm.set_ref.assert_any_call(
f"{CELERY_QUEUE}/aaa111bbb222ccc333", "aaa111bbb222ccc333"
)
repo.scm.set_ref.assert_any_call(
f"{CELERY_QUEUE}/ddd444eee555fff666", "ddd444eee555fff666"
)

# Verify push was called with correct refspecs
repo.scm.push_refspecs.assert_called_once()
call_args = repo.scm.push_refspecs.call_args
assert call_args[0][0] == "origin"
refspecs = call_args[0][1]
assert len(refspecs) == 2
assert call_args[1]["force"] is False

# Verify temp refs were cleaned up
assert repo.scm.remove_ref.call_count == 2

def test_push_queued_unnamed_entries_use_short_rev(self):
"""Unnamed experiments fall back to short stash_rev as name."""
repo = MagicMock()
entries = [_make_queue_entry("aaa111bbb222ccc333", name=None)]
repo.experiments.celery_queue.iter_queued.return_value = iter(entries)
repo.scm.push_refspecs.return_value = {}

result = _push_queued(repo, "origin", force=False)

assert result == {"queued": ["aaa111b"]}

def test_push_queued_force_flag(self):
"""Force flag is forwarded to push_refspecs."""
repo = MagicMock()
entries = [_make_queue_entry("aaa111bbb222ccc333", name="exp-1")]
repo.experiments.celery_queue.iter_queued.return_value = iter(entries)
repo.scm.push_refspecs.return_value = {}

_push_queued(repo, "origin", force=True)

call_args = repo.scm.push_refspecs.call_args
assert call_args[1]["force"] is True

def test_push_queued_cleans_up_refs_on_error(self):
"""Temp refs are cleaned up even if push_refspecs raises."""
repo = MagicMock()
entries = [_make_queue_entry("aaa111bbb222ccc333", name="exp-1")]
repo.experiments.celery_queue.iter_queued.return_value = iter(entries)
repo.scm.push_refspecs.side_effect = Exception("network error")

with pytest.raises(Exception, match="network error"):
_push_queued(repo, "origin", force=False)

# Temp refs should still be cleaned up
repo.scm.remove_ref.assert_called_once_with(
f"{CELERY_QUEUE}/aaa111bbb222ccc333"
)

def test_push_queued_auth_error(self):
"""AuthError is wrapped in GitAuthError."""
from dvc.scm import GitAuthError

repo = MagicMock()
entries = [_make_queue_entry("aaa111bbb222ccc333", name="exp-1")]
repo.experiments.celery_queue.iter_queued.return_value = iter(entries)
repo.scm.push_refspecs.side_effect = AuthError("bad credentials")

with pytest.raises(GitAuthError):
_push_queued(repo, "origin", force=False)

# Temp refs should still be cleaned up
repo.scm.remove_ref.assert_called_once()