Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/taskgraph/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def show_taskgraph(options):
print(f"Generating {options['graph_attr']} @ {base_rev}", file=sys.stderr)
ret |= generate_taskgraph(options, parameters, overrides, logdir)
finally:
assert cur_rev
repo.update(cur_rev)

# Generate diff(s)
Expand Down
10 changes: 4 additions & 6 deletions src/taskgraph/util/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import functools
import hashlib
from pathlib import Path
import os

from taskgraph.util import path as mozpath
from taskgraph.util.vcs import get_repository


@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -52,8 +53,5 @@ def _find_matching_files(base_path, pattern):

@functools.lru_cache(maxsize=None)
def _get_all_files(base_path):
return [
mozpath.normsep(str(path))
for path in Path(base_path).rglob("*")
if path.is_file()
]
repo = get_repository(os.getcwd())
return repo.get_tracked_files(base_path)
75 changes: 51 additions & 24 deletions src/taskgraph/util/vcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import subprocess
from abc import ABC, abstractmethod
from shutil import which
from typing import List, Optional

from taskgraph.util.path import ancestors

Expand All @@ -34,7 +35,7 @@ def __init__(self, path):

self._env = os.environ.copy()

def run(self, *args: str, **kwargs):
def run(self, *args: str, **kwargs) -> str:
return_codes = kwargs.pop("return_codes", [])
cmd = (self.binary,) + args

Expand Down Expand Up @@ -63,17 +64,17 @@ def head_rev(self) -> str:

@property
@abstractmethod
def base_rev(self):
def base_rev(self) -> str:
"""Hash of revision the current topic branch is based on."""

@property
@abstractmethod
def branch(self):
def branch(self) -> Optional[str]:
"""Current branch or bookmark the checkout has active."""

@property
@abstractmethod
def all_remote_names(self):
def all_remote_names(self) -> List[str]:
"""Name of all configured remote repositories."""

@property
Expand All @@ -85,10 +86,10 @@ def default_remote_name(self) -> str:

@property
@abstractmethod
def remote_name(self):
def remote_name(self) -> str:
"""Name of the remote repository."""

def _get_most_suitable_remote(self, remote_instructions):
def _get_most_suitable_remote(self, remote_instructions) -> str:
remotes = self.all_remote_names

# in case all_remote_names raised a RuntimeError
Expand All @@ -113,19 +114,34 @@ def _get_most_suitable_remote(self, remote_instructions):

@property
@abstractmethod
def default_branch(self):
def default_branch(self) -> str:
"""Name of the default branch."""

@abstractmethod
def get_url(self, remote=None):
def get_url(self, remote: Optional[str]) -> str:
"""Get URL of the upstream repository."""

@abstractmethod
def get_commit_message(self, revision=None):
def get_commit_message(self, revision: Optional[str]) -> str:
"""Commit message of specified revision or current commit."""

@abstractmethod
def get_changed_files(self, diff_filter, mode="unstaged", rev=None, base_rev=None):
def get_tracked_files(self, *paths: str, rev: Optional[str] = None) -> List[str]:
"""Return list of tracked files.

``*paths`` are path specifiers to limit results to.
``rev`` is a revision specifier at which to retrieve the files.
Defaults to the parent of the working copy if unspecified.
"""

@abstractmethod
def get_changed_files(
self,
diff_filter: Optional[str],
mode: Optional[str],
rev: Optional[str],
base_rev: Optional[str],
) -> List[str]:
"""Return a list of files that are changed in:
* either this repository's working copy,
* or at a given revision (``rev``)
Expand All @@ -152,7 +168,7 @@ def get_changed_files(self, diff_filter, mode="unstaged", rev=None, base_rev=Non
"""

@abstractmethod
def get_outgoing_files(self, diff_filter, upstream):
def get_outgoing_files(self, diff_filter: str, upstream: str) -> List[str]:
"""Return a list of changed files compared to upstream.

``diff_filter`` works the same as `get_changed_files`.
Expand All @@ -162,7 +178,9 @@ def get_outgoing_files(self, diff_filter, upstream):
"""

@abstractmethod
def working_directory_clean(self, untracked=False, ignored=False):
def working_directory_clean(
self, untracked: Optional[bool] = False, ignored: Optional[bool] = False
) -> bool:
"""Determine if the working directory is free of modifications.

Returns True if the working directory does not have any file
Expand All @@ -174,19 +192,19 @@ def working_directory_clean(self, untracked=False, ignored=False):
"""

@abstractmethod
def update(self, ref):
def update(self, ref: str) -> None:
"""Update the working directory to the specified reference."""

@abstractmethod
def find_latest_common_revision(self, base_ref_or_rev, head_rev):
def find_latest_common_revision(self, base_ref_or_rev: str, head_rev: str) -> str:
"""Find the latest revision that is common to both the given
``head_rev`` and ``base_ref_or_rev``.

If no common revision exists, ``Repository.NULL_REVISION`` will
be returned."""

@abstractmethod
def does_revision_exist_locally(self, revision):
def does_revision_exist_locally(self, revision: str) -> bool:
"""Check whether this revision exists in the local repository.

If this function returns an unexpected value, then make sure
Expand Down Expand Up @@ -243,7 +261,8 @@ def default_branch(self):
# https://www.mercurial-scm.org/wiki/StandardBranching#Don.27t_use_a_name_other_than_default_for_your_main_development_branch
return "default"

def get_url(self, remote="default"):
def get_url(self, remote=None):
remote = remote or "default"
return self.run("path", "-T", "{url}", remote).strip()

def get_commit_message(self, revision=None):
Expand All @@ -270,9 +289,12 @@ def _files_template(self, diff_filter):
template += "{file_mods % '{file}\\n'}"
return template

def get_changed_files(
self, diff_filter="ADM", mode="unstaged", rev=None, base_rev=None
):
def get_tracked_files(self, *paths, rev=None):
rev = rev or "."
return self.run("files", "-r", rev, *paths).splitlines()

def get_changed_files(self, diff_filter=None, mode=None, rev=None, base_rev=None):
diff_filter = diff_filter or "ADM"
if rev is None:
if base_rev is not None:
raise ValueError("Cannot specify `base_rev` without `rev`")
Expand Down Expand Up @@ -315,7 +337,7 @@ def working_directory_clean(self, untracked=False, ignored=False):
return not len(self.run(*args).strip())

def update(self, ref):
return self.run("update", "--check", ref)
self.run("update", "--check", ref)

def find_latest_common_revision(self, base_ref_or_rev, head_rev):
ancestor = self.run(
Expand Down Expand Up @@ -445,16 +467,21 @@ def _guess_default_branch(self):

raise RuntimeError(f"Unable to find default branch. Got: {branches}")

def get_url(self, remote="origin"):
def get_url(self, remote=None):
remote = remote or "origin"
return self.run("remote", "get-url", remote).strip()

def get_commit_message(self, revision=None):
revision = revision or "HEAD"
return self.run("log", "-n1", "--format=%B", revision)

def get_changed_files(
self, diff_filter="ADM", mode="unstaged", rev=None, base_rev=None
):
def get_tracked_files(self, *paths, rev=None):
rev = rev or "HEAD"
return self.run("ls-tree", "-r", "--name-only", rev, *paths).splitlines()

def get_changed_files(self, diff_filter=None, mode=None, rev=None, base_rev=None):
diff_filter = diff_filter or "ADM"
mode = mode or "unstaged"
assert all(f.lower() in self._valid_diff_filter for f in diff_filter)

if rev is None:
Expand Down
27 changes: 26 additions & 1 deletion test/test_util_vcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
import subprocess
from pathlib import Path
from textwrap import dedent

import pytest
Expand Down Expand Up @@ -211,7 +212,31 @@ def test_default_branch_cloned_metadata(tmpdir, default_git_branch, repo):


def assert_files(actual, expected):
assert set(map(os.path.basename, actual)) == set(expected)
assert set(actual) == set(expected)


def test_get_tracked_files(repo):
assert_files(repo.get_tracked_files(), ["first_file"])

second_file = Path(repo.path) / "subdir" / "second_file"
second_file.parent.mkdir()
second_file.write_text("foo")
assert_files(repo.get_tracked_files(), ["first_file"])

repo.run("add", str(second_file))
assert_files(repo.get_tracked_files(), ["first_file"])

repo.run("commit", "-m", "Add second file")
rev = ".~1" if repo.tool == "hg" else "HEAD~1"
assert_files(repo.get_tracked_files(), ["first_file", "subdir/second_file"])
assert_files(repo.get_tracked_files("subdir"), ["subdir/second_file"])
assert_files(repo.get_tracked_files(rev=rev), ["first_file"])

if repo.tool == "git":
assert_files(repo.get_tracked_files("subdir", rev=rev), [])
elif repo.tool == "hg":
with pytest.raises(subprocess.CalledProcessError):
repo.get_tracked_files("subdir", rev=rev)


def test_get_changed_files_no_changes(repo):
Expand Down
Loading