@@ -90,6 +90,11 @@ def get_opts() -> argparse.Namespace:
action="store_true",
help="Actually restart workflows for detected autorevert patterns",
)
+ workflow_parser.add_argument(
+ "--do-revert",
+ action="store_true",
+ help="When restarts complete and secondary pattern matches, log REVERT",
+ )
workflow_parser.add_argument(
"--dry-run",
action="store_true",
@@ -185,6 +190,7 @@ def main(*args, **kwargs) -> None:
hours=opts.hours,
verbose=opts.verbose,
do_restart=opts.do_restart,
+ do_revert=opts.do_revert,
dry_run=opts.dry_run,
ignore_common_errors=opts.ignore_common_errors,
)
aws/lambda/pytorch-auto-revert/pytorch_auto_revert/autorevert_checker.py (241 changes: 219 additions & 22 deletions)
@@ -51,9 +51,19 @@ def job_base_names(self) -> Set[str]:
return self._job_base_names

def normalize_job_name(self, name: str) -> str:
"""Strip shard suffix from job name for matching."""
"""Normalize job name to a stable base for matching across commits.

- Drop any trailing parenthetical qualifiers (e.g., "(rocm)", shard notes)
- Strip common shard suffixes like ", 1, 1, " used in CI naming
- Collapse redundant whitespace
"""
# Drop any trailing parenthetical qualifier
base = re.sub(r"\s*\(.*\)$", "", name)
# Remove patterns like ", 1, 1, " or ", 2, 3, " from job names
return re.sub(r", \d+, \d+, ", ", ", name)
base = re.sub(r", \d+, \d+, ", ", ", base)
# Collapse multiple spaces
base = re.sub(r"\s+", " ", base).strip()
return base
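For illustration only (not part of the PR): a self-contained sketch of what this normalization does, with hypothetical job names.

import re

def normalize_job_name(name: str) -> str:
    # Drop any trailing parenthetical qualifier, e.g. "(default, 1, 3, runner)"
    base = re.sub(r"\s*\(.*\)$", "", name)
    # Strip shard infixes like ", 2, 3, "
    base = re.sub(r", \d+, \d+, ", ", ", base)
    # Collapse redundant whitespace
    return re.sub(r"\s+", " ", base).strip()

# Hypothetical CI job names:
print(normalize_job_name("linux-jammy / test (default, 1, 3, linux.2xlarge)"))  # "linux-jammy / test"
print(normalize_job_name("win-vs2019 / build, 2, 3, cuda"))                     # "win-vs2019 / build, cuda"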

def get_job_base_names(self) -> Set[str]:
"""Get normalized job names (without shard info)."""
@@ -74,13 +84,24 @@ def __init__(
self._workflow_commits_cache: Dict[str, List[CommitJobs]] = {}
self._commit_history = None
self._ignore_classification_rules = ignore_classification_rules or set()
+ # Controls whether queries target restarted runs only (workflow_dispatch/tagged trunk/<sha>)
+ self._use_restarted_runs_only = False

Review comment (Contributor): Why is this a variable, if it never changes?

def get_workflow_commits(self, workflow_name: str) -> List[CommitJobs]:
"""Get workflow commits for a specific workflow, fetching if needed. From newer to older"""
if workflow_name not in self._workflow_commits_cache:
self._fetch_workflow_data()
return self._workflow_commits_cache.get(workflow_name, [])

+ def get_workflow_commit_by_sha(
+ self, workflow_name: str, sha: str
+ ) -> Optional[CommitJobs]:
+ """Return CommitJobs for a workflow and head_sha if present in cache."""
+ for cj in self.get_workflow_commits(workflow_name):

Review comment (Contributor): This loop runs inside another loop, so I suspect that when running it over a long period (like 2 years, for evaluation) it would be a significant bottleneck due to its quadratic nature. Maybe the search here could leverage a map?

Review reply (Contributor Author): A trie would be even better for prefix search, but I don't think this would be a bottleneck.

Review reply (Contributor): I'll remove this function, it is not being used anywhere.

+ if cj.head_sha.startswith(sha):
+ return cj
+ return None
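Along the lines of the reviewer's map suggestion, a hedged sketch (not in the PR) of an index giving O(1) exact-sha lookups, with a linear fallback for short-sha prefixes:

from typing import Dict, List, Optional

def build_sha_index(commits: List["CommitJobs"]) -> Dict[str, "CommitJobs"]:
    # One O(n) pass; exact full-sha lookups become O(1) afterwards.
    return {cj.head_sha: cj for cj in commits}

def get_by_sha(index: Dict[str, "CommitJobs"], sha: str) -> Optional["CommitJobs"]:
    # Fast path for full shas; fall back to a prefix scan for short shas.
    if sha in index:
        return index[sha]
    return next((cj for full, cj in index.items() if full.startswith(sha)), None)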

@property
def workflow_commits(self) -> List[CommitJobs]:
"""Get workflow commits for the first workflow (backward compatibility)."""
@@ -106,7 +127,13 @@ def _fetch_workflow_data(self):
f"Fetching workflow data for {len(self.workflow_names)} workflows since {lookback_time.isoformat()}..."
)

query = """
base_where = (
"workflow_event = 'workflow_dispatch' AND head_branch LIKE 'trunk/%'"
if self._use_restarted_runs_only

Review comment (Contributor): condition is never true

else "workflow_event != 'workflow_dispatch' AND head_branch = 'main'"
)

query = f"""
SELECT
workflow_name,
head_sha,
Expand All @@ -118,11 +145,10 @@ def _fetch_workflow_data(self):
FROM
workflow_job FINAL
WHERE
- workflow_name IN {workflow_names:Array(String)}
- AND head_branch = 'main'
- AND created_at >= {lookback_time:DateTime}
+ workflow_name IN {{workflow_names:Array(String)}}
+ AND {base_where}
+ AND created_at >= {{lookback_time:DateTime}}
AND dynamoKey LIKE 'pytorch/pytorch/%'
- AND workflow_event != 'workflow_dispatch' -- Exclude restart jobs
ORDER BY
workflow_name, workflow_created_at DESC, head_sha, name
"""
@@ -221,7 +247,7 @@ def _find_last_commit_with_job(
job_results = []
for commit in commits:
for job in commit.jobs:
- if job.name.split("(")[0] == job_name: # Normalize job name
+ if commit.normalize_job_name(job.name) == job_name:
job_results.append(job)
if job_results:
return (
@@ -230,6 +256,25 @@
)
return None, None

+ def _find_last_commit_with_rule(

Review comment (Contributor): this function is not called anywhere, what is its purpose?

+ self, commits: Iterable[CommitJobs], rule: str, failures_only: bool = True
+ ) -> Optional[Tuple[CommitJobs, List[JobResult]]]:
+ """Find the first commit (per iteration order) that has any job with the given rule.
+
+ If failures_only is True, only consider jobs with conclusion == 'failure'.
+ """
+ job_results = []
+ for commit in commits:
+ job_results = [
+ j
+ for j in commit.jobs
+ if j.classification_rule == rule
+ and (j.conclusion == "failure" if failures_only else True)
+ ]
+ if job_results:
+ return commit, job_results
+ return None, None
+
def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
"""
Detect all autorevert patterns in commit job data for a specific workflow.
@@ -257,10 +302,11 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
if suspected_commit1.has_pending_jobs:
continue

+ # Primary path: use classified failures (highest precision)
suspected_failures = {
(
j.classification_rule,
- j.name.split("(")[0],
+ suspected_commit1.normalize_job_name(j.name),
)
for j in suspected_commit1.failed_jobs
}
@@ -282,15 +328,17 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
suspected_failure_job_name,
)
)
- if not newer_commit_same_job or not newer_same_jobs:
- # No newer commit with the same job found
- continue

- if any(
- j.classification_rule == suspected_failure_class_rule
- for j in newer_same_jobs
+ if (
+ newer_commit_same_job
+ and newer_same_jobs
+ and any(
+ j.classification_rule == suspected_failure_class_rule
+ and j.conclusion == "failure"
+ for j in newer_same_jobs
+ )
):
- # The newer commit has the same job failing
+ # The newer commit has the same failure (job may differ)
failure_key = (
suspected_failure_class_rule,
suspected_failure_job_name,
@@ -311,25 +359,61 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
)

if not last_commit_with_same_job or not last_same_jobs:
- # No older commit with the same job found
+ # No older commit with any jobs found

Review comment (Contributor): The function _find_last_commit_with_job is expected to return only the commit with the given job, not any job. If a commit ran some jobs, but not the specific one being checked (job_name), it should be skipped in favor of finding the next one. So I believe the fix in this comment is misplaced. If there is a bug, maybe we should fix it in the function itself? But I re-read it and could not pinpoint a problem.

https://github.com/pytorch/test-infra/pull/6983/files#diff-7174b1a731f38e3efb2765fac87a65ce7835d26a68cccd9c1e329e0b2070f1e2R234

continue

- if any(j.classification_rule == failure_rule for j in last_same_jobs):
- # The older commit has the same job failing with same rule
+ if any(
+ j.classification_rule == failure_rule and j.conclusion == "failure"
+ for j in last_same_jobs
+ ):
+ # The older commit already has the same failure (regardless of job)
continue

+ # Ensure there is some overlap in job coverage between suspected and older commit
+ older_coverage = list(

Review comment (Contributor): This can only happen because you removed the check in https://github.com/pytorch/test-infra/pull/6983/files#diff-7174b1a731f38e3efb2765fac87a65ce7835d26a68cccd9c1e329e0b2070f1e2L285. If you keep that check, it is guaranteed that the job_name is present in both commits, due to the checks in _find_last_commit_with_job.

Review reply (Contributor Author): the check is different (older commit vs newer commit), but your point is correct

+ set(suspected_commit1.job_base_names)
+ & set(last_commit_with_same_job.job_base_names)
+ )
+ if not older_coverage:
+ # No overlapping jobs -> insufficient comparable signal
+ continue

+ # Cross-workflow baseline check: if multiple workflows were provided,

Review comment (Contributor): Can you post stats with and without this? I suspect this might be overkill and might remove lots of possible commits from the list. But other fixes are needed before running it.

Review reply (Contributor Author, @izaitsevfb, Aug 11, 2025): no point, it's slop, removed

+ # ensure the older_commit does NOT have the same failure in any sibling workflow.
+ older_sha = last_commit_with_same_job.head_sha
+ conflict_in_other_wf = False
+ if len(self.workflow_names) > 1:
+ for other_wf in self.workflow_names:
+ if other_wf == workflow_name:
+ continue
+ other_cj = self.get_workflow_commit_by_sha(other_wf, older_sha)
+ if other_cj and any(
+ j.classification_rule == failure_rule for j in other_cj.jobs
+ ):
+ conflict_in_other_wf = True
+ break
+
+ if conflict_in_other_wf:
+ # Skip this pattern; baseline is not clean across workflows
+ continue

patterns.append(
{
"pattern_detected": True,
"workflow_name": workflow_name,
"failure_rule": failure_rule,
"job_name_base": job_name,
"newer_commits": [
newer_commit.head_sha,
suspected_commit1.head_sha,
],
"older_commit": last_commit_with_same_job.head_sha,
"failed_job_names": [j.name for j in last_same_jobs],
"older_job_coverage": [],
"older_commit": older_sha,
"failed_job_names": [
j.name
for j in suspected_commit1.failed_jobs
if j.classification_rule == failure_rule
][:10],
"older_job_coverage": older_coverage[:10],
}
)
break
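For orientation, a hypothetical example of the pattern dict appended above (shas, workflow, and job names are invented):

example_pattern = {
    "pattern_detected": True,
    "workflow_name": "pull",                   # hypothetical workflow
    "failure_rule": "pytest failure",          # hypothetical classification rule
    "job_name_base": "linux-jammy / test",     # normalized job name
    "newer_commits": ["bbbb111", "aaaa000"],   # [commit above, suspected first-failing commit]
    "older_commit": "9999fff",                 # clean baseline below the suspect
    "failed_job_names": ["linux-jammy / test (default, 1, 3)"],
    "older_job_coverage": ["linux-jammy / test"],
}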
@@ -378,6 +462,119 @@ def detect_autorevert_pattern(self) -> List[Dict]:

return all_patterns

+ def _fetch_single_commit_jobs(
+ self,
+ workflow_name: str,
+ head_sha: str,
+ restarted_only: bool = False,
+ ) -> Optional[CommitJobs]:
+ """Fetch jobs for a single workflow+commit, optionally only restarted runs.
+
+ Groups all jobs by head_sha (assumes at most one restart dispatch of interest).
+ Returns CommitJobs or None if no jobs found in lookback window.
+ """
+ lookback_time = datetime.now() - timedelta(hours=self.lookback_hours)
+
+ where_event = (
+ "workflow_event = {we:String} AND head_branch LIKE {hb:String}"
+ if restarted_only
+ else "workflow_event != {we:String} AND head_branch = {hb:String}"
+ )
+
+ query = f"""
+ SELECT
+ head_sha,
+ name,
+ conclusion,
+ status,
+ torchci_classification.rule AS classification_rule,
+ created_at AS workflow_created_at
+ FROM workflow_job FINAL
+ WHERE workflow_name = {{workflow_name:String}}
+ AND head_sha = {{head_sha:String}}
+ AND {where_event}
+ AND created_at >= {{lookback_time:DateTime}}
+ AND dynamoKey LIKE 'pytorch/pytorch/%'
+ ORDER BY workflow_created_at DESC, name
+ """
+
+ hb = "trunk/%" if restarted_only else "main"
+ we = "workflow_dispatch" if restarted_only else "workflow_dispatch"
+ # Note: for non-restarted we exclude workflow_dispatch via != in WHERE above
+
+ result = CHCliFactory().client.query(
+ query,
+ parameters={
+ "workflow_name": workflow_name,
+ "head_sha": head_sha,
+ "we": we,
+ "hb": hb,
+ "lookback_time": lookback_time,
+ },
+ )
+
+ rows = list(result.result_rows)
+ if not rows:
+ return None
+
+ # Use the newest created_at among returned rows as the commit's created_at marker
+ latest_created = max(r[5] for r in rows)
+ cj = CommitJobs(head_sha=head_sha, created_at=latest_created, jobs=[])
+ for row in rows:
+ _, name, conclusion, status, classification_rule, created_at = row
+ cj.jobs.append(
+ JobResult(
+ head_sha=head_sha,
+ name=name,
+ conclusion=conclusion,
+ status=status,
+ classification_rule=classification_rule or "",
+ workflow_created_at=created_at,
+ )
+ )
+ return cj
+
+ def confirm_commit_caused_failure_on_restarted(self, pattern: Dict) -> bool:
+ """Confirm commit-caused failure using restarted runs.
+
+ Requires that:
+ - first failing commit's restarted run has the same failure classification for the job
+ - previous commit's restarted run does NOT have that failure classification for the job
+ - both restarted runs have no pending jobs
+ """
+ workflow_name = pattern["workflow_name"]
+ job_base = pattern.get("job_name_base")
+ failure_rule = pattern["failure_rule"]
+ first_failing = pattern["newer_commits"][1]
+ previous_commit = pattern["older_commit"]
+
+ # Fetch restarted jobs for first failing and previous commits
+ failing_jobs = self._fetch_single_commit_jobs(
+ workflow_name, first_failing, restarted_only=True
+ )
+ prev_jobs = self._fetch_single_commit_jobs(
+ workflow_name, previous_commit, restarted_only=True
+ )
+ if not failing_jobs or not prev_jobs:
+ return False
+
+ # Pending check
+ if failing_jobs.has_pending_jobs or prev_jobs.has_pending_jobs:
+ return False
+
+ def has_rule(cj: CommitJobs, rule: str) -> bool:
+ return any(
+ cj.normalize_job_name(j.name) == job_base
+ and j.classification_rule == rule
+ and j.conclusion == "failure"
+ for j in cj.jobs
+ )
+
+ # Commit-caused if failing commit reproduces, previous does not
+ return has_rule(failing_jobs, failure_rule) and not has_rule(
+ prev_jobs, failure_rule
+ )
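A hedged sketch of how a caller might wire these pieces together; the class name and constructor are assumptions, not shown in this diff:

# Hypothetical driver: detect patterns on main, then use restarted runs
# to decide whether to log a REVERT (cf. the --do-revert flag above).
checker = AutorevertChecker(workflow_names=["pull"], lookback_hours=48)  # assumed ctor
for pattern in checker.detect_autorevert_pattern():
    if checker.confirm_commit_caused_failure_on_restarted(pattern):
        print(f"REVERT: {pattern['newer_commits'][1]} ({pattern['failure_rule']})")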

def get_commits_reverted(self) -> Set[str]:
"""
Get all commits that were reverted within the lookback window.
Expand Down