diff --git a/.github/actions/update-viablestrict/action.yml b/.github/actions/update-viablestrict/action.yml index c9742a363d..162058caae 100644 --- a/.github/actions/update-viablestrict/action.yml +++ b/.github/actions/update-viablestrict/action.yml @@ -55,13 +55,6 @@ runs: with: python-version: '3.11' - - name: Checkout test-infra for the fetch_latest_green_commit scripts - uses: actions/checkout@v3 - with: - repository: ${{ inputs.test-infra-repository }} - ref: ${{ inputs.test-infra-ref }} - path: test-infra - - uses: actions/checkout@v3 with: repository: ${{ inputs.repository }} @@ -85,8 +78,14 @@ runs: run: | set -ex - output=$(python ${GITHUB_WORKSPACE}/test-infra/.github/scripts/fetch_latest_green_commit.py --requires "${{ inputs.requires }}") + TEST_INFRA_PATH="${GITHUB_ACTION_PATH}/../../.." + + output=$(python ${TEST_INFRA_PATH}/tools/scripts/fetch_latest_green_commit.py \ + --required-checks "${{ inputs.requires }}" \ + --viable-strict-branch "viable/strict" \ + --main-branch "master") echo "latest_viable_sha=$output" >> "${GITHUB_OUTPUT}" + echo $output - name: Push SHA to viable/strict branch if: steps.get-latest-commit.outputs.latest_viable_sha != 'None' diff --git a/.github/scripts/fetch_latest_green_commit.py b/.github/scripts/fetch_latest_green_commit.py deleted file mode 100644 index 9b0aa2818e..0000000000 --- a/.github/scripts/fetch_latest_green_commit.py +++ /dev/null @@ -1,156 +0,0 @@ -import json -import os -import re -import sys -from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple - -import rockset # type: ignore[import] -from gitutils import _check_output - - -def eprint(msg: str) -> None: - print(msg, file=sys.stderr) - - -class WorkflowCheck(NamedTuple): - workflowName: str - name: str - jobName: str - conclusion: str - - -def get_latest_commits() -> List[str]: - latest_viable_commit = _check_output( - [ - "git", - "log", - "-n", - "1", - "--pretty=format:%H", - "origin/viable/strict", - ], - encoding="ascii", - ) - 
commits = _check_output( - [ - "git", - "rev-list", - f"{latest_viable_commit}^..HEAD", - "--remotes=*origin/main", - ], - encoding="ascii", - ).splitlines() - - return commits - - -def query_commits(commits: List[str]) -> List[Dict[str, Any]]: - rs = rockset.RocksetClient( - host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"] - ) - params = [{"name": "shas", "type": "string", "value": ",".join(commits)}] - res = rs.QueryLambdas.execute_query_lambda( - # https://console.rockset.com/lambdas/details/commons.commit_jobs_batch_query - query_lambda="commit_jobs_batch_query", - version="19c74e10819104f9", - workspace="commons", - parameters=params, - ) - - return cast(List[Dict[str, Any]], res.results) - - -def print_commit_status(commit: str, results: Dict[str, Any]) -> None: - print(commit) - for check in results["results"]: - if check["sha"] == commit: - print(f"\t{check['conclusion']:>10}: {check['name']}") - - -def get_commit_results( - commit: str, results: List[Dict[str, Any]] -) -> List[Dict[str, Any]]: - workflow_checks = [] - for check in results: - if check["sha"] == commit: - workflow_checks.append( - WorkflowCheck( - workflowName=check["workflowName"], - name=check["name"], - jobName=check["jobName"], - conclusion=check["conclusion"], - )._asdict() - ) - return workflow_checks - - -def is_green( - commit: str, requires: List[str], results: List[Dict[str, Any]] -) -> Tuple[bool, str]: - workflow_checks = get_commit_results(commit, results) - - regex = {name: False for name in requires} - - for check in workflow_checks: - jobName = check["jobName"] - # Ignore result from unstable job, be it success or failure - if "unstable" in jobName: - continue - - workflowName = check["workflowName"] - conclusion = check["conclusion"] - for required_check in regex: - if re.match(required_check, workflowName, flags=re.IGNORECASE): - if conclusion not in ["success", "skipped"]: - return (False, workflowName + " checks were not successful") - else: - 
regex[required_check] = True - - missing_workflows = [x for x in regex.keys() if not regex[x]] - if len(missing_workflows) > 0: - return (False, "missing required workflows: " + ", ".join(missing_workflows)) - - return (True, "") - - -def get_latest_green_commit( - commits: List[str], requires: List[str], results: List[Dict[str, Any]] -) -> Optional[str]: - for commit in commits: - eprint(f"Checking {commit}") - green, msg = is_green(commit, requires, results) - if green: - eprint("GREEN") - return commit - else: - eprint("RED: " + msg) - return None - - -def parse_args() -> Any: - from argparse import ArgumentParser - - parser = ArgumentParser("Return the latest green commit from a PyTorch repo") - parser.add_argument( - "--requires", - type=str, - required=True, - help="the JSON list of required jobs that need to pass for the commit to be green", - ) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - - commits = get_latest_commits() - results = query_commits(commits) - - latest_viable_commit = get_latest_green_commit( - commits, json.loads(args.requires), results - ) - print(latest_viable_commit) - - -if __name__ == "__main__": - main() diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py deleted file mode 100644 index 88230eb689..0000000000 --- a/.github/scripts/gitutils.py +++ /dev/null @@ -1,444 +0,0 @@ -import os -import re -import tempfile -from collections import defaultdict -from datetime import datetime -from functools import wraps -from typing import ( - Any, - Callable, - cast, - Dict, - Iterator, - List, - Optional, - Tuple, - TypeVar, - Union, -) - -T = TypeVar("T") - -RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$") - - -def get_git_remote_name() -> str: - return os.getenv("GIT_REMOTE_NAME", "origin") - - -def get_git_repo_dir() -> str: - from pathlib import Path - - return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent.parent)) - - -def fuzzy_list_to_dict(items: 
List[Tuple[str, str]]) -> Dict[str, List[str]]: - """ - Converts list to dict preserving elements with duplicate keys - """ - rc: Dict[str, List[str]] = defaultdict(list) - for key, val in items: - rc[key].append(val) - return dict(rc) - - -def _check_output(items: List[str], encoding: str = "utf-8") -> str: - from subprocess import CalledProcessError, check_output, STDOUT - - try: - return check_output(items, stderr=STDOUT).decode(encoding) - except CalledProcessError as e: - msg = f"Command `{' '.join(e.cmd)}` returned non-zero exit code {e.returncode}" - stdout = e.stdout.decode(encoding) if e.stdout is not None else "" - stderr = e.stderr.decode(encoding) if e.stderr is not None else "" - # These get swallowed up, so print them here for debugging - print(f"stdout: \n{stdout}") - print(f"stderr: \n{stderr}") - if len(stderr) == 0: - msg += f"\n```\n{stdout}```" - else: - msg += f"\nstdout:\n```\n{stdout}```\nstderr:\n```\n{stderr}```" - raise RuntimeError(msg) from e - - -class GitCommit: - commit_hash: str - title: str - body: str - author: str - author_date: datetime - commit_date: Optional[datetime] - - def __init__( - self, - commit_hash: str, - author: str, - author_date: datetime, - title: str, - body: str, - commit_date: Optional[datetime] = None, - ) -> None: - self.commit_hash = commit_hash - self.author = author - self.author_date = author_date - self.commit_date = commit_date - self.title = title - self.body = body - - def __repr__(self) -> str: - return f"{self.title} ({self.commit_hash})" - - def __contains__(self, item: Any) -> bool: - return item in self.body or item in self.title - - -def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit: - """ - Expect commit message generated using `--format=fuller --date=unix` format, i.e.: - commit - Author: - AuthorDate: - Commit: - CommitDate: - - - - <full commit message> - - """ - if isinstance(lines, str): - lines = lines.split("\n") - # TODO: Handle merge commits correctly - if len(lines) > 
1 and lines[1].startswith("Merge:"): - del lines[1] - assert len(lines) > 7 - assert lines[0].startswith("commit") - assert lines[1].startswith("Author: ") - assert lines[2].startswith("AuthorDate: ") - assert lines[3].startswith("Commit: ") - assert lines[4].startswith("CommitDate: ") - assert len(lines[5]) == 0 - return GitCommit( - commit_hash=lines[0].split()[1].strip(), - author=lines[1].split(":", 1)[1].strip(), - author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())), - commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())), - title=lines[6].strip(), - body="\n".join(lines[7:]), - ) - - -class GitRepo: - def __init__(self, path: str, remote: str = "origin", debug: bool = False) -> None: - self.repo_dir = path - self.remote = remote - self.debug = debug - - def _run_git(self, *args: Any) -> str: - if self.debug: - print(f"+ git -C {self.repo_dir} {' '.join(args)}") - return _check_output(["git", "-C", self.repo_dir] + list(args)) - - def revlist(self, revision_range: str) -> List[str]: - rc = self._run_git("rev-list", revision_range, "--", ".").strip() - return rc.split("\n") if len(rc) > 0 else [] - - def branches_containing_ref( - self, ref: str, *, include_remote: bool = True - ) -> List[str]: - rc = ( - self._run_git("branch", "--remote", "--contains", ref) - if include_remote - else self._run_git("branch", "--contains", ref) - ) - return [x.strip() for x in rc.split("\n") if x.strip()] if len(rc) > 0 else [] - - def current_branch(self) -> str: - return self._run_git("symbolic-ref", "--short", "HEAD").strip() - - def checkout(self, branch: str) -> None: - self._run_git("checkout", branch) - - def fetch(self, ref: Optional[str] = None, branch: Optional[str] = None) -> None: - if branch is None and ref is None: - self._run_git("fetch", self.remote) - elif branch is None: - self._run_git("fetch", self.remote, ref) - else: - self._run_git("fetch", self.remote, f"{ref}:{branch}") - - def show_ref(self, name: str) -> 
str: - refs = self._run_git("show-ref", "-s", name).strip().split("\n") - if not all(refs[i] == refs[0] for i in range(1, len(refs))): - raise RuntimeError(f"reference {name} is ambiguous") - return refs[0] - - def rev_parse(self, name: str) -> str: - return self._run_git("rev-parse", "--verify", name).strip() - - def get_merge_base(self, from_ref: str, to_ref: str) -> str: - return self._run_git("merge-base", from_ref, to_ref).strip() - - def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]: - is_list = isinstance(ref, list) - if is_list: - if len(ref) == 0: - return [] - ref = " ".join(ref) - rc = _check_output( - ["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"] - ).strip() - return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")] - - def commits_resolving_gh_pr(self, pr_num: int) -> List[str]: - owner, name = self.gh_owner_and_name() - msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}" - rc = self._run_git("log", "--format=%H", "--grep", msg).strip() - return rc.split("\n") if len(rc) > 0 else [] - - def get_commit(self, ref: str) -> GitCommit: - return parse_fuller_format( - self._run_git("show", "--format=fuller", "--date=unix", "--shortstat", ref) - ) - - def cherry_pick(self, ref: str) -> None: - self._run_git("cherry-pick", "-x", ref) - - def revert(self, ref: str) -> None: - self._run_git("revert", "--no-edit", ref) - - def compute_branch_diffs( - self, from_branch: str, to_branch: str - ) -> Tuple[List[str], List[str]]: - """ - Returns list of commmits that are missing in each other branch since their merge base - Might be slow if merge base is between two branches is pretty far off - """ - from_ref = self.rev_parse(from_branch) - to_ref = self.rev_parse(to_branch) - merge_base = self.get_merge_base(from_ref, to_ref) - from_commits = self.revlist(f"{merge_base}..{from_ref}") - to_commits = self.revlist(f"{merge_base}..{to_ref}") - from_ids = 
fuzzy_list_to_dict(self.patch_id(from_commits)) - to_ids = fuzzy_list_to_dict(self.patch_id(to_commits)) - for patch_id in set(from_ids).intersection(set(to_ids)): - from_values = from_ids[patch_id] - to_values = to_ids[patch_id] - if len(from_values) != len(to_values): - # Eliminate duplicate commits+reverts from the list - while len(from_values) > 0 and len(to_values) > 0: - frc = self.get_commit(from_values.pop()) - toc = self.get_commit(to_values.pop()) - # FRC branch might have PR number added to the title - if ( # noqa: SIM102 - frc.title != toc.title or frc.author_date != toc.author_date - ): - # HACK: Same commit were merged, reverted and landed again - # which creates a tracking problem - if ( - "pytorch/pytorch" not in self.remote_url() - or frc.commit_hash - not in { - "0a6a1b27a464ba5be5f587cce2ee12ab8c504dbf", - "6d0f4a1d545a8f161df459e8d4ccafd4b9017dbe", - "edf909e58f06150f7be41da2f98a3b9de3167bca", - "a58c6aea5a0c9f8759a4154e46f544c8b03b8db1", - "7106d216c29ca16a3504aa2bedad948ebcf4abc2", - } - ): - raise RuntimeError( - f"Unexpected differences between {frc} and {toc}" - ) - from_commits.remove(frc.commit_hash) - to_commits.remove(toc.commit_hash) - continue - for commit in from_values: - from_commits.remove(commit) - for commit in to_values: - to_commits.remove(commit) - # Another HACK: Patch-id is not stable for commits with binary files or for big changes across commits - # I.e. 
cherry-picking those from one branch into another will change patchid - if "pytorch/pytorch" in self.remote_url(): - for excluded_commit in { - "8e09e20c1dafcdbdb45c2d1574da68a32e54a3a5", - "5f37e5c2a39c3acb776756a17730b865f0953432", - "b5222584e6d6990c6585981a936defd1af14c0ba", - "84d9a2e42d5ed30ec3b8b4140c38dd83abbce88d", - "f211ec90a6cdc8a2a5795478b5b5c8d7d7896f7e", - }: - if excluded_commit in from_commits: - from_commits.remove(excluded_commit) - - return (from_commits, to_commits) - - def cherry_pick_commits(self, from_branch: str, to_branch: str) -> None: - orig_branch = self.current_branch() - self.checkout(to_branch) - from_commits, to_commits = self.compute_branch_diffs(from_branch, to_branch) - if len(from_commits) == 0: - print("Nothing to do") - self.checkout(orig_branch) - return - for commit in reversed(from_commits): - print(f"Cherry picking commit {commit}") - self.cherry_pick(commit) - self.checkout(orig_branch) - - def push(self, branch: str, dry_run: bool, retry: int = 3) -> None: - for cnt in range(retry): - try: - if dry_run: - self._run_git("push", "--dry-run", self.remote, branch) - else: - self._run_git("push", self.remote, branch) - except RuntimeError as e: - print(f"{cnt} push attempt failed with {e}") - self.fetch() - self._run_git("rebase", f"{self.remote}/{branch}") - - def head_hash(self) -> str: - return self._run_git("show-ref", "--hash", "HEAD").strip() - - def remote_url(self) -> str: - return self._run_git("remote", "get-url", self.remote) - - def gh_owner_and_name(self) -> Tuple[str, str]: - url = os.getenv("GIT_REMOTE_URL", None) - if url is None: - url = self.remote_url() - rc = RE_GITHUB_URL_MATCH.match(url) - if rc is None: - raise RuntimeError(f"Unexpected url format {url}") - return cast(Tuple[str, str], rc.groups()) - - def commit_message(self, ref: str) -> str: - return self._run_git("log", "-1", "--format=%B", ref) - - def amend_commit_message(self, msg: str) -> None: - self._run_git("commit", "--amend", "-m", msg) - - 
def diff(self, from_ref: str, to_ref: Optional[str] = None) -> str: - if to_ref is None: - return self._run_git("diff", f"{from_ref}^!") - return self._run_git("diff", f"{from_ref}..{to_ref}") - - -def clone_repo(username: str, password: str, org: str, project: str) -> GitRepo: - path = tempfile.mkdtemp() - _check_output( - [ - "git", - "clone", - f"https://{username}:{password}@github.com/{org}/{project}", - path, - ] - ).strip() - return GitRepo(path=path) - - -class PeekableIterator(Iterator[str]): - def __init__(self, val: str) -> None: - self._val = val - self._idx = -1 - - def peek(self) -> Optional[str]: - if self._idx + 1 >= len(self._val): - return None - return self._val[self._idx + 1] - - def __iter__(self) -> "PeekableIterator": - return self - - def __next__(self) -> str: - rc = self.peek() - if rc is None: - raise StopIteration - self._idx += 1 - return rc - - -def patterns_to_regex(allowed_patterns: List[str]) -> Any: - """ - pattern is glob-like, i.e. the only special sequences it has are: - - ? - matches single character - - * - matches any non-folder separator characters or no character - - ** - matches any characters or no character - Assuming that patterns are free of braces and backslashes - the only character that needs to be escaped are dot and plus - """ - rc = "(" - for idx, pattern in enumerate(allowed_patterns): - if idx > 0: - rc += "|" - pattern_ = PeekableIterator(pattern) - assert not any(c in pattern for c in "{}()[]\\") - for c in pattern_: - if c == ".": - rc += "\\." 
- elif c == "+": - rc += "\\+" - elif c == "*": - if pattern_.peek() == "*": - next(pattern_) - rc += ".*" - else: - rc += "[^/]*" - else: - rc += c - rc += ")" - return re.compile(rc) - - -def _shasum(value: str) -> str: - import hashlib - - m = hashlib.sha256() - m.update(value.encode("utf-8")) - return m.hexdigest() - - -def is_commit_hash(ref: str) -> bool: - "True if ref is hexadecimal number, else false" - try: - int(ref, 16) - except ValueError: - return False - return True - - -def are_ghstack_branches_in_sync( - repo: GitRepo, head_ref: str, base_ref: Optional[str] = None -) -> bool: - """Checks that diff between base and head is the same as diff between orig and its parent""" - orig_ref = re.sub(r"/head$", "/orig", head_ref) - if base_ref is None: - base_ref = re.sub(r"/head$", "/base", head_ref) - orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}")) - head_diff_sha = _shasum( - repo.diff( - base_ref if is_commit_hash(base_ref) else f"{repo.remote}/{base_ref}", - f"{repo.remote}/{head_ref}", - ) - ) - return orig_diff_sha == head_diff_sha - - -def retries_decorator( - rc: Any = None, num_retries: int = 3 -) -> Callable[[Callable[..., T]], Callable[..., T]]: - def decorator(f: Callable[..., T]) -> Callable[..., T]: - @wraps(f) - def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T: - for idx in range(num_retries): - try: - return f(*args, **kwargs) - except Exception as e: - print( - f'Attempt {idx} of {num_retries} to call {f.__name__} failed with "{e}"' - ) - pass - return cast(T, rc) - - return wrapper - - return decorator diff --git a/.github/scripts/test_fetch_latest_green_commit.py b/.github/scripts/test_fetch_latest_green_commit.py deleted file mode 100644 index 264334f7ec..0000000000 --- a/.github/scripts/test_fetch_latest_green_commit.py +++ /dev/null @@ -1,150 +0,0 @@ -from typing import Any, Dict, List -from unittest import main, mock, TestCase - -from fetch_latest_green_commit import is_green, WorkflowCheck - -workflowNames = 
[ - "pull", - "trunk", - "Lint", - "linux-binary-libtorch-pre-cxx11", - "android-tests", - "windows-binary-wheel", - "periodic", - "docker-release-builds", - "nightly", - "pr-labels", - "Close stale pull requests", - "Update S3 HTML indices for download.pytorch.org", - "Create Release", -] - -requires = ["pull", "trunk", "lint", "linux-binary"] - - -def set_workflow_job_status( - workflow: List[Dict[str, Any]], name: str, status: str -) -> List[Dict[str, Any]]: - for check in workflow: - if check["workflowName"] == name: - check["conclusion"] = status - return workflow - - -class TestChecks: - def make_test_checks(self) -> List[Dict[str, Any]]: - workflow_checks = [] - for i in range(len(workflowNames)): - workflow_checks.append( - WorkflowCheck( - workflowName=workflowNames[i], - name="test/job", - jobName="job", - conclusion="success", - )._asdict() - ) - return workflow_checks - - -class TestPrintCommits(TestCase): - @mock.patch( - "fetch_latest_green_commit.get_commit_results", - return_value=TestChecks().make_test_checks(), - ) - def test_all_successful(self, mock_get_commit_results: Any) -> None: - "Test with workflows are successful" - workflow_checks = mock_get_commit_results() - self.assertTrue(is_green("sha", requires, workflow_checks)[0]) - - @mock.patch( - "fetch_latest_green_commit.get_commit_results", - return_value=TestChecks().make_test_checks(), - ) - def test_necessary_successful(self, mock_get_commit_results: Any) -> None: - "Test with necessary workflows are successful" - workflow_checks = mock_get_commit_results() - workflow_checks = set_workflow_job_status( - workflow_checks, workflowNames[8], "failed" - ) - workflow_checks = set_workflow_job_status( - workflow_checks, workflowNames[9], "failed" - ) - workflow_checks = set_workflow_job_status( - workflow_checks, workflowNames[10], "failed" - ) - workflow_checks = set_workflow_job_status( - workflow_checks, workflowNames[11], "failed" - ) - workflow_checks = set_workflow_job_status( - 
workflow_checks, workflowNames[12], "failed" - ) - self.assertTrue(is_green("sha", requires, workflow_checks)[0]) - - @mock.patch( - "fetch_latest_green_commit.get_commit_results", - return_value=TestChecks().make_test_checks(), - ) - def test_necessary_skipped(self, mock_get_commit_results: Any) -> None: - "Test with necessary job (ex: pull) skipped" - workflow_checks = mock_get_commit_results() - workflow_checks = set_workflow_job_status(workflow_checks, "pull", "skipped") - result = is_green("sha", requires, workflow_checks) - self.assertTrue(result[0]) - - @mock.patch( - "fetch_latest_green_commit.get_commit_results", - return_value=TestChecks().make_test_checks(), - ) - def test_skippable_skipped(self, mock_get_commit_results: Any) -> None: - "Test with skippable jobs (periodic and docker-release-builds skipped" - workflow_checks = mock_get_commit_results() - workflow_checks = set_workflow_job_status( - workflow_checks, "periodic", "skipped" - ) - workflow_checks = set_workflow_job_status( - workflow_checks, "docker-release-builds", "skipped" - ) - self.assertTrue(is_green("sha", requires, workflow_checks)) - - @mock.patch( - "fetch_latest_green_commit.get_commit_results", - return_value=TestChecks().make_test_checks(), - ) - def test_necessary_failed(self, mock_get_commit_results: Any) -> None: - "Test with necessary job (ex: Lint) failed" - workflow_checks = mock_get_commit_results() - workflow_checks = set_workflow_job_status(workflow_checks, "Lint", "failed") - result = is_green("sha", requires, workflow_checks) - self.assertFalse(result[0]) - self.assertEqual(result[1], "Lint checks were not successful") - - @mock.patch( - "fetch_latest_green_commit.get_commit_results", - return_value=TestChecks().make_test_checks(), - ) - def test_skippable_failed(self, mock_get_commit_results: Any) -> None: - "Test with failing skippable jobs (ex: docker-release-builds) should pass" - workflow_checks = mock_get_commit_results() - workflow_checks = 
set_workflow_job_status( - workflow_checks, "periodic", "skipped" - ) - workflow_checks = set_workflow_job_status( - workflow_checks, "docker-release-builds", "failed" - ) - result = is_green("sha", requires, workflow_checks) - self.assertTrue(result[0]) - - @mock.patch("fetch_latest_green_commit.get_commit_results", return_value={}) - def test_no_workflows(self, mock_get_commit_results: Any) -> None: - "Test with missing workflows" - workflow_checks = mock_get_commit_results() - result = is_green("sha", requires, workflow_checks) - self.assertFalse(result[0]) - self.assertEqual( - result[1], - "missing required workflows: pull, trunk, lint, linux-binary", - ) - - -if __name__ == "__main__": - main() diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index 1a93acea92..8f0e0c7046 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -54,7 +54,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.10' architecture: x64 cache: pip @@ -69,8 +69,12 @@ jobs: run: | cd ${{ inputs.repository }} - export PYTHONPATH=$PYTHONPATH:../../test-infra - output=$(python3 ../../test-infra/tools/scripts/fetch_latest_green_commit.py "${{ inputs.required_checks }}" "${{ inputs.viable_strict_branch }}") + TEST_INFRA_PATH="../../test-infra" + + output=$(python3 ${TEST_INFRA_PATH}/tools/scripts/fetch_latest_green_commit.py \ + --required-checks "${{ inputs.required_checks }}" \ + --viable-strict-branch "${{ inputs.viable_strict_branch }}" \ + --main-branch "main") echo "latest_viable_sha=$output" >> "${GITHUB_OUTPUT}" id: get-latest-commit diff --git a/tools/scripts/fetch_latest_green_commit.py b/tools/scripts/fetch_latest_green_commit.py index 52d5840dc6..d29915e213 100644 --- a/tools/scripts/fetch_latest_green_commit.py +++ b/tools/scripts/fetch_latest_green_commit.py @@ -1,12 +1,15 @@ +import json import os +from pathlib import 
Path import re import sys -from argparse import ArgumentParser - -from typing import Any, cast, Dict, List, NamedTuple, Tuple +from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple import rockset # type: ignore[import] -from tools.scripts.gitutils import check_output +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) +from tools.scripts.gitutils import _check_output +sys.path.pop(0) def eprint(msg: str) -> None: @@ -20,8 +23,8 @@ class WorkflowCheck(NamedTuple): conclusion: str -def get_latest_commits(viable_strict_branch: str) -> List[str]: - latest_viable_commit = check_output( +def get_latest_commits(viable_strict_branch: str, main_branch: str) -> List[str]: + latest_viable_commit = _check_output( [ "git", "log", @@ -32,12 +35,12 @@ def get_latest_commits(viable_strict_branch: str) -> List[str]: ], encoding="ascii", ) - commits = check_output( + commits = _check_output( [ "git", "rev-list", f"{latest_viable_commit}^..HEAD", - "--remotes=*origin/master", + f"--remotes=*origin/{main_branch}", ], encoding="ascii", ).splitlines() @@ -51,8 +54,9 @@ def query_commits(commits: List[str]) -> List[Dict[str, Any]]: ) params = [{"name": "shas", "type": "string", "value": ",".join(commits)}] res = rs.QueryLambdas.execute_query_lambda( + # https://console.rockset.com/lambdas/details/commons.commit_jobs_batch_query query_lambda="commit_jobs_batch_query", - version="8003fdfd18b64696", + version="19c74e10819104f9", workspace="commons", parameters=params, ) @@ -61,6 +65,7 @@ def query_commits(commits: List[str]) -> List[Dict[str, Any]]: def print_commit_status(commit: str, results: Dict[str, Any]) -> None: + print(commit) for check in results["results"]: if check["sha"] == commit: print(f"\t{check['conclusion']:>10}: {check['name']}") @@ -84,12 +89,18 @@ def get_commit_results( def is_green( - commit: str, results: List[Dict[str, Any]], required_checks: List[str] + commit: str, requires: List[str], results: List[Dict[str, Any]] ) 
-> Tuple[bool, str]: workflow_checks = get_commit_results(commit, results) - regex = {check: False for check in required_checks} + + regex = {check: False for check in requires} for check in workflow_checks: + jobName = check["jobName"] + # Ignore result from unstable job, be it success or failure + if "unstable" in jobName: + continue + workflow_name = check["workflowName"] conclusion = check["conclusion"] for required_check in regex: @@ -107,15 +118,12 @@ def is_green( def get_latest_green_commit( - commits: List[str], results: List[Dict[str, Any]], required_checks: str -) -> Any: - required_checks = required_checks.split(",") - + commits: List[str], requires: List[str], results: List[Dict[str, Any]] +) -> Optional[str]: for commit in commits: eprint(f"Checking {commit}") - is_green_status, msg = is_green(commit, results, required_checks) - - if is_green_status: + green, msg = is_green(commit, requires, results) + if green: eprint("GREEN") return commit else: @@ -123,29 +131,28 @@ def get_latest_green_commit( return None -def _arg_parser() -> Any: - parser = ArgumentParser() - parser.add_argument("required_checks", type=str) - parser.add_argument("viable_strict_branch", type=str) +def parse_args() -> Any: + from argparse import ArgumentParser + parser = ArgumentParser("Return the latest green commit from a PyTorch repo") + parser.add_argument("--required-checks", type=str) + parser.add_argument("--viable-strict-branch", type=str, default="viable/strict") + parser.add_argument("--main-branch", type=str, default="main") return parser.parse_args() def main() -> None: - args = _arg_parser() + args = parse_args() - commits = get_latest_commits(args.viable_strict_branch) + commits = get_latest_commits(args.viable_strict_branch, args.main_branch) results = query_commits(commits) - - latest_viable_commit = get_latest_green_commit( - commits, results, args.required_checks - ) + try: + required_checks = json.loads(args.required_checks) + except json.JSONDecodeError: + 
required_checks = args.required_checks.split(",") + latest_viable_commit = get_latest_green_commit(commits, required_checks, results) print(latest_viable_commit) if __name__ == "__main__": - """ - The basic logic was taken from the pytorch/pytorch repo - - https://github.com/pytorch/pytorch/blob/master/.github/scripts/fetch_latest_green_commit.py - """ main() diff --git a/tools/scripts/gitutils.py b/tools/scripts/gitutils.py index 32040b7855..88230eb689 100644 --- a/tools/scripts/gitutils.py +++ b/tools/scripts/gitutils.py @@ -1,13 +1,48 @@ -#!/usr/bin/env python3 +import os +import re +import tempfile +from collections import defaultdict +from datetime import datetime +from functools import wraps +from typing import ( + Any, + Callable, + cast, + Dict, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, +) -from typing import List +T = TypeVar("T") +RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$") -def check_output(items: List[str], encoding: str = "utf-8") -> str: + +def get_git_remote_name() -> str: + return os.getenv("GIT_REMOTE_NAME", "origin") + + +def get_git_repo_dir() -> str: + from pathlib import Path + + return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent.parent)) + + +def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]: """ - The logic was taken from the pytorch/pytroch repo - - https://github.com/pytorch/pytorch/blob/master/.github/scripts/gitutils.py + Converts list to dict preserving elements with duplicate keys """ + rc: Dict[str, List[str]] = defaultdict(list) + for key, val in items: + rc[key].append(val) + return dict(rc) + + +def _check_output(items: List[str], encoding: str = "utf-8") -> str: from subprocess import CalledProcessError, check_output, STDOUT try: @@ -16,8 +51,394 @@ def check_output(items: List[str], encoding: str = "utf-8") -> str: msg = f"Command `{' '.join(e.cmd)}` returned non-zero exit code {e.returncode}" stdout = e.stdout.decode(encoding) if 
e.stdout is not None else "" stderr = e.stderr.decode(encoding) if e.stderr is not None else "" + # These get swallowed up, so print them here for debugging + print(f"stdout: \n{stdout}") + print(f"stderr: \n{stderr}") if len(stderr) == 0: msg += f"\n```\n{stdout}```" else: msg += f"\nstdout:\n```\n{stdout}```\nstderr:\n```\n{stderr}```" raise RuntimeError(msg) from e + + +class GitCommit: + commit_hash: str + title: str + body: str + author: str + author_date: datetime + commit_date: Optional[datetime] + + def __init__( + self, + commit_hash: str, + author: str, + author_date: datetime, + title: str, + body: str, + commit_date: Optional[datetime] = None, + ) -> None: + self.commit_hash = commit_hash + self.author = author + self.author_date = author_date + self.commit_date = commit_date + self.title = title + self.body = body + + def __repr__(self) -> str: + return f"{self.title} ({self.commit_hash})" + + def __contains__(self, item: Any) -> bool: + return item in self.body or item in self.title + + +def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit: + """ + Expect commit message generated using `--format=fuller --date=unix` format, i.e.: + commit <sha1> + Author: <author> + AuthorDate: <author date> + Commit: <committer> + CommitDate: <committer date> + + <title line> + + <full commit message> + + """ + if isinstance(lines, str): + lines = lines.split("\n") + # TODO: Handle merge commits correctly + if len(lines) > 1 and lines[1].startswith("Merge:"): + del lines[1] + assert len(lines) > 7 + assert lines[0].startswith("commit") + assert lines[1].startswith("Author: ") + assert lines[2].startswith("AuthorDate: ") + assert lines[3].startswith("Commit: ") + assert lines[4].startswith("CommitDate: ") + assert len(lines[5]) == 0 + return GitCommit( + commit_hash=lines[0].split()[1].strip(), + author=lines[1].split(":", 1)[1].strip(), + author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())), + 
commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())), + title=lines[6].strip(), + body="\n".join(lines[7:]), + ) + + +class GitRepo: + def __init__(self, path: str, remote: str = "origin", debug: bool = False) -> None: + self.repo_dir = path + self.remote = remote + self.debug = debug + + def _run_git(self, *args: Any) -> str: + if self.debug: + print(f"+ git -C {self.repo_dir} {' '.join(args)}") + return _check_output(["git", "-C", self.repo_dir] + list(args)) + + def revlist(self, revision_range: str) -> List[str]: + rc = self._run_git("rev-list", revision_range, "--", ".").strip() + return rc.split("\n") if len(rc) > 0 else [] + + def branches_containing_ref( + self, ref: str, *, include_remote: bool = True + ) -> List[str]: + rc = ( + self._run_git("branch", "--remote", "--contains", ref) + if include_remote + else self._run_git("branch", "--contains", ref) + ) + return [x.strip() for x in rc.split("\n") if x.strip()] if len(rc) > 0 else [] + + def current_branch(self) -> str: + return self._run_git("symbolic-ref", "--short", "HEAD").strip() + + def checkout(self, branch: str) -> None: + self._run_git("checkout", branch) + + def fetch(self, ref: Optional[str] = None, branch: Optional[str] = None) -> None: + if branch is None and ref is None: + self._run_git("fetch", self.remote) + elif branch is None: + self._run_git("fetch", self.remote, ref) + else: + self._run_git("fetch", self.remote, f"{ref}:{branch}") + + def show_ref(self, name: str) -> str: + refs = self._run_git("show-ref", "-s", name).strip().split("\n") + if not all(refs[i] == refs[0] for i in range(1, len(refs))): + raise RuntimeError(f"reference {name} is ambiguous") + return refs[0] + + def rev_parse(self, name: str) -> str: + return self._run_git("rev-parse", "--verify", name).strip() + + def get_merge_base(self, from_ref: str, to_ref: str) -> str: + return self._run_git("merge-base", from_ref, to_ref).strip() + + def patch_id(self, ref: Union[str, List[str]]) -> 
List[Tuple[str, str]]: + is_list = isinstance(ref, list) + if is_list: + if len(ref) == 0: + return [] + ref = " ".join(ref) + rc = _check_output( + ["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"] + ).strip() + return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")] + + def commits_resolving_gh_pr(self, pr_num: int) -> List[str]: + owner, name = self.gh_owner_and_name() + msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}" + rc = self._run_git("log", "--format=%H", "--grep", msg).strip() + return rc.split("\n") if len(rc) > 0 else [] + + def get_commit(self, ref: str) -> GitCommit: + return parse_fuller_format( + self._run_git("show", "--format=fuller", "--date=unix", "--shortstat", ref) + ) + + def cherry_pick(self, ref: str) -> None: + self._run_git("cherry-pick", "-x", ref) + + def revert(self, ref: str) -> None: + self._run_git("revert", "--no-edit", ref) + + def compute_branch_diffs( + self, from_branch: str, to_branch: str + ) -> Tuple[List[str], List[str]]: + """ + Returns list of commits that are missing in each other branch since their merge base + Might be slow if the merge base between the two branches is pretty far off + """ + from_ref = self.rev_parse(from_branch) + to_ref = self.rev_parse(to_branch) + merge_base = self.get_merge_base(from_ref, to_ref) + from_commits = self.revlist(f"{merge_base}..{from_ref}") + to_commits = self.revlist(f"{merge_base}..{to_ref}") + from_ids = fuzzy_list_to_dict(self.patch_id(from_commits)) + to_ids = fuzzy_list_to_dict(self.patch_id(to_commits)) + for patch_id in set(from_ids).intersection(set(to_ids)): + from_values = from_ids[patch_id] + to_values = to_ids[patch_id] + if len(from_values) != len(to_values): + # Eliminate duplicate commits+reverts from the list + while len(from_values) > 0 and len(to_values) > 0: + frc = self.get_commit(from_values.pop()) + toc = self.get_commit(to_values.pop()) + # FRC branch might have PR number added to the title + if 
( # noqa: SIM102 + frc.title != toc.title or frc.author_date != toc.author_date + ): + # HACK: Same commit were merged, reverted and landed again + # which creates a tracking problem + if ( + "pytorch/pytorch" not in self.remote_url() + or frc.commit_hash + not in { + "0a6a1b27a464ba5be5f587cce2ee12ab8c504dbf", + "6d0f4a1d545a8f161df459e8d4ccafd4b9017dbe", + "edf909e58f06150f7be41da2f98a3b9de3167bca", + "a58c6aea5a0c9f8759a4154e46f544c8b03b8db1", + "7106d216c29ca16a3504aa2bedad948ebcf4abc2", + } + ): + raise RuntimeError( + f"Unexpected differences between {frc} and {toc}" + ) + from_commits.remove(frc.commit_hash) + to_commits.remove(toc.commit_hash) + continue + for commit in from_values: + from_commits.remove(commit) + for commit in to_values: + to_commits.remove(commit) + # Another HACK: Patch-id is not stable for commits with binary files or for big changes across commits + # I.e. cherry-picking those from one branch into another will change patchid + if "pytorch/pytorch" in self.remote_url(): + for excluded_commit in { + "8e09e20c1dafcdbdb45c2d1574da68a32e54a3a5", + "5f37e5c2a39c3acb776756a17730b865f0953432", + "b5222584e6d6990c6585981a936defd1af14c0ba", + "84d9a2e42d5ed30ec3b8b4140c38dd83abbce88d", + "f211ec90a6cdc8a2a5795478b5b5c8d7d7896f7e", + }: + if excluded_commit in from_commits: + from_commits.remove(excluded_commit) + + return (from_commits, to_commits) + + def cherry_pick_commits(self, from_branch: str, to_branch: str) -> None: + orig_branch = self.current_branch() + self.checkout(to_branch) + from_commits, to_commits = self.compute_branch_diffs(from_branch, to_branch) + if len(from_commits) == 0: + print("Nothing to do") + self.checkout(orig_branch) + return + for commit in reversed(from_commits): + print(f"Cherry picking commit {commit}") + self.cherry_pick(commit) + self.checkout(orig_branch) + + def push(self, branch: str, dry_run: bool, retry: int = 3) -> None: + for cnt in range(retry): + try: + if dry_run: + self._run_git("push", "--dry-run", 
self.remote, branch) + else: + self._run_git("push", self.remote, branch) + except RuntimeError as e: + print(f"{cnt} push attempt failed with {e}") + self.fetch() + self._run_git("rebase", f"{self.remote}/{branch}") + + def head_hash(self) -> str: + return self._run_git("show-ref", "--hash", "HEAD").strip() + + def remote_url(self) -> str: + return self._run_git("remote", "get-url", self.remote) + + def gh_owner_and_name(self) -> Tuple[str, str]: + url = os.getenv("GIT_REMOTE_URL", None) + if url is None: + url = self.remote_url() + rc = RE_GITHUB_URL_MATCH.match(url) + if rc is None: + raise RuntimeError(f"Unexpected url format {url}") + return cast(Tuple[str, str], rc.groups()) + + def commit_message(self, ref: str) -> str: + return self._run_git("log", "-1", "--format=%B", ref) + + def amend_commit_message(self, msg: str) -> None: + self._run_git("commit", "--amend", "-m", msg) + + def diff(self, from_ref: str, to_ref: Optional[str] = None) -> str: + if to_ref is None: + return self._run_git("diff", f"{from_ref}^!") + return self._run_git("diff", f"{from_ref}..{to_ref}") + + +def clone_repo(username: str, password: str, org: str, project: str) -> GitRepo: + path = tempfile.mkdtemp() + _check_output( + [ + "git", + "clone", + f"https://{username}:{password}@github.com/{org}/{project}", + path, + ] + ).strip() + return GitRepo(path=path) + + +class PeekableIterator(Iterator[str]): + def __init__(self, val: str) -> None: + self._val = val + self._idx = -1 + + def peek(self) -> Optional[str]: + if self._idx + 1 >= len(self._val): + return None + return self._val[self._idx + 1] + + def __iter__(self) -> "PeekableIterator": + return self + + def __next__(self) -> str: + rc = self.peek() + if rc is None: + raise StopIteration + self._idx += 1 + return rc + + +def patterns_to_regex(allowed_patterns: List[str]) -> Any: + """ + pattern is glob-like, i.e. the only special sequences it has are: + - ? 
- matches single character + - * - matches any non-folder separator characters or no character + - ** - matches any characters or no character + Assuming that patterns are free of braces and backslashes + the only character that needs to be escaped are dot and plus + """ + rc = "(" + for idx, pattern in enumerate(allowed_patterns): + if idx > 0: + rc += "|" + pattern_ = PeekableIterator(pattern) + assert not any(c in pattern for c in "{}()[]\\") + for c in pattern_: + if c == ".": + rc += "\\." + elif c == "+": + rc += "\\+" + elif c == "*": + if pattern_.peek() == "*": + next(pattern_) + rc += ".*" + else: + rc += "[^/]*" + else: + rc += c + rc += ")" + return re.compile(rc) + + +def _shasum(value: str) -> str: + import hashlib + + m = hashlib.sha256() + m.update(value.encode("utf-8")) + return m.hexdigest() + + +def is_commit_hash(ref: str) -> bool: + "True if ref is hexadecimal number, else false" + try: + int(ref, 16) + except ValueError: + return False + return True + + +def are_ghstack_branches_in_sync( + repo: GitRepo, head_ref: str, base_ref: Optional[str] = None +) -> bool: + """Checks that diff between base and head is the same as diff between orig and its parent""" + orig_ref = re.sub(r"/head$", "/orig", head_ref) + if base_ref is None: + base_ref = re.sub(r"/head$", "/base", head_ref) + orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}")) + head_diff_sha = _shasum( + repo.diff( + base_ref if is_commit_hash(base_ref) else f"{repo.remote}/{base_ref}", + f"{repo.remote}/{head_ref}", + ) + ) + return orig_diff_sha == head_diff_sha + + +def retries_decorator( + rc: Any = None, num_retries: int = 3 +) -> Callable[[Callable[..., T]], Callable[..., T]]: + def decorator(f: Callable[..., T]) -> Callable[..., T]: + @wraps(f) + def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> T: + for idx in range(num_retries): + try: + return f(*args, **kwargs) + except Exception as e: + print( + f'Attempt {idx} of {num_retries} to call {f.__name__} failed 
with "{e}"' + ) + pass + return cast(T, rc) + + return wrapper + + return decorator diff --git a/tools/tests/test_fetch_latest_green_commit.py b/tools/tests/test_fetch_latest_green_commit.py index df68bd4e38..153238dfe4 100644 --- a/tools/tests/test_fetch_latest_green_commit.py +++ b/tools/tests/test_fetch_latest_green_commit.py @@ -19,7 +19,7 @@ "Create Release", ] -required_workflows = ["pull", "trunk", "lint", "linux-binary", "windows-binary"] +requires = ["pull", "trunk", "lint", "linux-binary"] def set_workflow_job_status( @@ -54,7 +54,7 @@ class TestPrintCommits(TestCase): def test_all_successful(self, mock_get_commit_results: Any) -> None: """Test with workflows are successful""" workflow_checks = mock_get_commit_results() - self.assertTrue(is_green("sha", workflow_checks, required_workflows)[0]) + self.assertTrue(is_green("sha", requires, workflow_checks)[0]) @mock.patch( "tools.scripts.fetch_latest_green_commit.get_commit_results", @@ -78,7 +78,7 @@ def test_necessary_successful(self, mock_get_commit_results: Any) -> None: workflow_checks = set_workflow_job_status( workflow_checks, workflow_names[12], "failed" ) - self.assertTrue(is_green("sha", workflow_checks, required_workflows)[0]) + self.assertTrue(is_green("sha", requires, workflow_checks)[0]) @mock.patch( "tools.scripts.fetch_latest_green_commit.get_commit_results", @@ -88,7 +88,7 @@ def test_necessary_skipped(self, mock_get_commit_results: Any) -> None: """Test with necessary job (ex: pull) skipped""" workflow_checks = mock_get_commit_results() workflow_checks = set_workflow_job_status(workflow_checks, "pull", "skipped") - result = is_green("sha", workflow_checks, required_workflows) + result = is_green("sha", requires, workflow_checks) self.assertTrue(result[0]) @mock.patch( @@ -104,7 +104,7 @@ def test_skippable_skipped(self, mock_get_commit_results: Any) -> None: workflow_checks = set_workflow_job_status( workflow_checks, "docker-release-builds", "skipped" ) - self.assertTrue(is_green("sha", 
workflow_checks, required_workflows)) + self.assertTrue(is_green("sha", requires, workflow_checks)) @mock.patch( "tools.scripts.fetch_latest_green_commit.get_commit_results", @@ -114,7 +114,7 @@ def test_necessary_failed(self, mock_get_commit_results: Any) -> None: """Test with necessary job (ex: Lint) failed""" workflow_checks = mock_get_commit_results() workflow_checks = set_workflow_job_status(workflow_checks, "Lint", "failed") - result = is_green("sha", workflow_checks, required_workflows) + result = is_green("sha", requires, workflow_checks) self.assertFalse(result[0]) self.assertEqual(result[1], "Lint checks were not successful") @@ -131,7 +131,7 @@ def test_skippable_failed(self, mock_get_commit_results: Any) -> None: workflow_checks = set_workflow_job_status( workflow_checks, "docker-release-builds", "failed" ) - result = is_green("sha", workflow_checks, required_workflows) + result = is_green("sha", requires, workflow_checks) self.assertTrue(result[0]) @mock.patch( @@ -140,17 +140,13 @@ def test_skippable_failed(self, mock_get_commit_results: Any) -> None: def test_no_workflows(self, mock_get_commit_results: Any) -> None: """Test with missing workflows""" workflow_checks = mock_get_commit_results() - result = is_green("sha", workflow_checks, required_workflows) + result = is_green("sha", requires, workflow_checks) self.assertFalse(result[0]) self.assertEqual( result[1], - "missing required workflows: pull, trunk, lint, linux-binary, windows-binary", + "missing required workflows: pull, trunk, lint, linux-binary", ) if __name__ == "__main__": - """ - The tests were migrated from the pytorch/pytorch repo - - https://github.com/pytorch/pytorch/blob/master/.github/scripts/test_fetch_latest_green_commit.py - """ main()