diff --git a/src/taskgraph/main.py b/src/taskgraph/main.py index 0c90ab927..5060f8fdf 100644 --- a/src/taskgraph/main.py +++ b/src/taskgraph/main.py @@ -477,6 +477,7 @@ def show_taskgraph(options): print(f"Generating {options['graph_attr']} @ {base_rev}", file=sys.stderr) ret |= generate_taskgraph(options, parameters, overrides, logdir) finally: + assert cur_rev repo.update(cur_rev) # Generate diff(s) diff --git a/src/taskgraph/util/hash.py b/src/taskgraph/util/hash.py index ad50dc5a1..65bf51b7e 100644 --- a/src/taskgraph/util/hash.py +++ b/src/taskgraph/util/hash.py @@ -4,9 +4,10 @@ import functools import hashlib -from pathlib import Path +import os from taskgraph.util import path as mozpath +from taskgraph.util.vcs import get_repository @functools.lru_cache(maxsize=None) @@ -52,8 +53,5 @@ def _find_matching_files(base_path, pattern): @functools.lru_cache(maxsize=None) def _get_all_files(base_path): - return [ - mozpath.normsep(str(path)) - for path in Path(base_path).rglob("*") - if path.is_file() - ] + repo = get_repository(os.getcwd()) + return repo.get_tracked_files(base_path) diff --git a/src/taskgraph/util/vcs.py b/src/taskgraph/util/vcs.py index 0b2db6a6f..80f08e4da 100644 --- a/src/taskgraph/util/vcs.py +++ b/src/taskgraph/util/vcs.py @@ -9,6 +9,7 @@ import subprocess from abc import ABC, abstractmethod from shutil import which +from typing import List, Optional from taskgraph.util.path import ancestors @@ -34,7 +35,7 @@ def __init__(self, path): self._env = os.environ.copy() - def run(self, *args: str, **kwargs): + def run(self, *args: str, **kwargs) -> str: return_codes = kwargs.pop("return_codes", []) cmd = (self.binary,) + args @@ -63,17 +64,17 @@ def head_rev(self) -> str: @property @abstractmethod - def base_rev(self): + def base_rev(self) -> str: """Hash of revision the current topic branch is based on.""" @property @abstractmethod - def branch(self): + def branch(self) -> Optional[str]: """Current branch or bookmark the checkout has active.""" @property @abstractmethod - def all_remote_names(self): + def all_remote_names(self) -> List[str]: """Name of all configured remote repositories.""" @property @@ -85,10 +86,10 @@ def default_remote_name(self) -> str: @property @abstractmethod - def remote_name(self): + def remote_name(self) -> str: """Name of the remote repository.""" - def _get_most_suitable_remote(self, remote_instructions): + def _get_most_suitable_remote(self, remote_instructions) -> str: remotes = self.all_remote_names # in case all_remote_names raised a RuntimeError @@ -113,19 +114,34 @@ def _get_most_suitable_remote(self, remote_instructions): @property @abstractmethod - def default_branch(self): + def default_branch(self) -> str: """Name of the default branch.""" @abstractmethod - def get_url(self, remote=None): + def get_url(self, remote: Optional[str]) -> str: """Get URL of the upstream repository.""" @abstractmethod - def get_commit_message(self, revision=None): + def get_commit_message(self, revision: Optional[str]) -> str: """Commit message of specified revision or current commit.""" @abstractmethod - def get_changed_files(self, diff_filter, mode="unstaged", rev=None, base_rev=None): + def get_tracked_files(self, *paths: str, rev: Optional[str] = None) -> List[str]: + """Return list of tracked files. + + ``*paths`` are path specifiers to limit results to. + ``rev`` is a revision specifier at which to retrieve the files. + Defaults to the parent of the working copy if unspecified. + """ + + @abstractmethod + def get_changed_files( + self, + diff_filter: Optional[str], + mode: Optional[str], + rev: Optional[str], + base_rev: Optional[str], + ) -> List[str]: """Return a list of files that are changed in: * either this repository's working copy, * or at a given revision (``rev``) @@ -152,7 +168,7 @@ def get_changed_files(self, diff_filter, mode="unstaged", rev=None, base_rev=Non """ @abstractmethod - def get_outgoing_files(self, diff_filter, upstream): + def get_outgoing_files(self, diff_filter: str, upstream: str) -> List[str]: """Return a list of changed files compared to upstream. ``diff_filter`` works the same as `get_changed_files`. @@ -162,7 +178,9 @@ def get_outgoing_files(self, diff_filter, upstream): """ @abstractmethod - def working_directory_clean(self, untracked=False, ignored=False): + def working_directory_clean( + self, untracked: Optional[bool] = False, ignored: Optional[bool] = False + ) -> bool: """Determine if the working directory is free of modifications. Returns True if the working directory does not have any file @@ -174,11 +192,11 @@ def working_directory_clean(self, untracked=False, ignored=False): """ @abstractmethod - def update(self, ref): + def update(self, ref: str) -> None: """Update the working directory to the specified reference.""" @abstractmethod - def find_latest_common_revision(self, base_ref_or_rev, head_rev): + def find_latest_common_revision(self, base_ref_or_rev: str, head_rev: str) -> str: """Find the latest revision that is common to both the given ``head_rev`` and ``base_ref_or_rev``. @@ -186,7 +204,7 @@ def find_latest_common_revision(self, base_ref_or_rev, head_rev): be returned.""" @abstractmethod - def does_revision_exist_locally(self, revision): + def does_revision_exist_locally(self, revision: str) -> bool: """Check whether this revision exists in the local repository. If this function returns an unexpected value, then make sure @@ -243,7 +261,8 @@ def default_branch(self): # https://www.mercurial-scm.org/wiki/StandardBranching#Don.27t_use_a_name_other_than_default_for_your_main_development_branch return "default" - def get_url(self, remote="default"): + def get_url(self, remote=None): + remote = remote or "default" return self.run("path", "-T", "{url}", remote).strip() def get_commit_message(self, revision=None): @@ -270,9 +289,12 @@ def _files_template(self, diff_filter): template += "{file_mods % '{file}\\n'}" return template - def get_changed_files( - self, diff_filter="ADM", mode="unstaged", rev=None, base_rev=None - ): + def get_tracked_files(self, *paths, rev=None): + rev = rev or "." + return self.run("files", "-r", rev, *paths).splitlines() + + def get_changed_files(self, diff_filter=None, mode=None, rev=None, base_rev=None): + diff_filter = diff_filter or "ADM" if rev is None: if base_rev is not None: raise ValueError("Cannot specify `base_rev` without `rev`") @@ -315,7 +337,7 @@ def working_directory_clean(self, untracked=False, ignored=False): return not len(self.run(*args).strip()) def update(self, ref): - return self.run("update", "--check", ref) + self.run("update", "--check", ref) def find_latest_common_revision(self, base_ref_or_rev, head_rev): ancestor = self.run( @@ -445,16 +467,21 @@ def _guess_default_branch(self): raise RuntimeError(f"Unable to find default branch. Got: {branches}") - def get_url(self, remote="origin"): + def get_url(self, remote=None): + remote = remote or "origin" return self.run("remote", "get-url", remote).strip() def get_commit_message(self, revision=None): revision = revision or "HEAD" return self.run("log", "-n1", "--format=%B", revision) - def get_changed_files( - self, diff_filter="ADM", mode="unstaged", rev=None, base_rev=None - ): + def get_tracked_files(self, *paths, rev=None): + rev = rev or "HEAD" + return self.run("ls-tree", "-r", "--name-only", rev, *paths).splitlines() + + def get_changed_files(self, diff_filter=None, mode=None, rev=None, base_rev=None): + diff_filter = diff_filter or "ADM" + mode = mode or "unstaged" assert all(f.lower() in self._valid_diff_filter for f in diff_filter) if rev is None: diff --git a/test/test_util_vcs.py b/test/test_util_vcs.py index e7ec2ff8b..7585fe367 100644 --- a/test/test_util_vcs.py +++ b/test/test_util_vcs.py @@ -4,6 +4,7 @@ import os import subprocess +from pathlib import Path from textwrap import dedent import pytest @@ -211,7 +212,31 @@ def test_default_branch_cloned_metadata(tmpdir, default_git_branch, repo): def assert_files(actual, expected): - assert set(map(os.path.basename, actual)) == set(expected) + assert set(actual) == set(expected) + + +def test_get_tracked_files(repo): + assert_files(repo.get_tracked_files(), ["first_file"]) + + second_file = Path(repo.path) / "subdir" / "second_file" + second_file.parent.mkdir() + second_file.write_text("foo") + assert_files(repo.get_tracked_files(), ["first_file"]) + + repo.run("add", str(second_file)) + assert_files(repo.get_tracked_files(), ["first_file"]) + + repo.run("commit", "-m", "Add second file") + rev = ".~1" if repo.tool == "hg" else "HEAD~1" + assert_files(repo.get_tracked_files(), ["first_file", "subdir/second_file"]) + assert_files(repo.get_tracked_files("subdir"), ["subdir/second_file"]) + assert_files(repo.get_tracked_files(rev=rev), ["first_file"]) + + if repo.tool == "git": + assert_files(repo.get_tracked_files("subdir", rev=rev), []) + elif repo.tool == "hg": + with pytest.raises(subprocess.CalledProcessError): + repo.get_tracked_files("subdir", rev=rev) def test_get_changed_files_no_changes(repo):