
Commit 5f86d76

[Pytorch AutoRevert] - Improves autorevert check heuristics (#6853)
Makes several improvements to the back-analysis for the revert logic, with the goal of improving precision and recall and validating this as a viable strategy. Checked against the workflows: pull, trunk, inductor, linux-binary-manywheel.

Old code:

```
Timeframe: 720 hours
Commits checked: 6177
Auto revert patterns detected: 188
Actual reverts inside auto revert patterns detected: 24 (12.8%)
Total revert commits in period: 115
Reverts that dont match any auto revert pattern detected: 91
```

Newer code:

```
Workflow(s): pull, trunk, inductor, linux-binary-manywheel
Timeframe: 720 hours
Commits checked: 5403
Auto revert patterns detected: 442
Actual reverts inside auto revert patterns detected (precision): 48 (10.9%)
Total revert commits in period: 115
Reverts that dont match any auto revert pattern detected (recall): 67 (58.3%)
Per workflow precision:
  pull: 45 reverts out of 411 patterns (10.9%)
  trunk: 1 reverts out of 8 patterns (12.5%)
  inductor: 2 reverts out of 20 patterns (10.0%)
  linux-binary-manywheel: 0 reverts out of 3 patterns (0.0%)
```

Critical implemented changes:

* Look both forward and back for the first commit that ran the failed job, instead of trusting that it is always the commit immediately before or after.
* Job names contain parts we don't care about, such as shard indices. Since a failure can happen in any shard, we want to match any shard with the same failure.

Things I tried that didn't lead to great results:

* Ignoring error classification: precision too low, no significant increase in recall.
* Not requiring error repetition: precision too low, no significant increase in recall.

My take: a precision of 10% already justifies the cost of re-running jobs to confirm redness status. Even though it is not possible to test this directly, I suspect that requiring the same outcome twice for all 3 signals should elevate the precision to a very high standard; unfortunately, the only way to verify this is to run it in shadow mode.

With a recall of 55%, the checker appears able to capture **most** of the introduced trunk redness errors. Many reverts may not be caused by CI redness at all, especially not in the workflows we are analyzing (they could be performance degradations, GHF/internal reasons, and many others). This number seems comfortable enough to provide a substantial gain in CI quality.
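A quick sanity check of the headline percentages, as a minimal sketch using the numbers reported above:

```python
# Sanity check of the precision / unmatched-revert percentages reported above.
patterns = 442        # auto revert patterns detected
true_hits = 48        # patterns whose suspected commit was actually reverted
total_reverts = 115   # revert commits in the period
unmatched = 67        # reverts not matching any detected pattern

print(f"precision: {true_hits / patterns * 100:.1f}%")               # -> 10.9%
print(f"unmatched reverts: {unmatched / total_reverts * 100:.1f}%")  # -> 58.3%
```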
1 parent 3d3500e · commit 5f86d76

File tree

3 files changed: +167 −68 lines


aws/lambda/pytorch-auto-revert/pytorch_auto_revert/__main__.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -65,9 +65,9 @@ def get_opts() -> argparse.Namespace:
     # no subcommand runs the lambda flow
     subparsers = parser.add_subparsers(dest="subcommand")
 
-    # autorevert subcommand
+    # autorevert-checker subcommand
     workflow_parser = subparsers.add_parser(
-        "autorevert", help="Analyze workflows looking for autorevert patterns"
+        "autorevert-checker", help="Analyze workflows looking for autorevert patterns"
     )
     workflow_parser.add_argument(
         "workflows",
@@ -85,9 +85,9 @@ def get_opts() -> argparse.Namespace:
         help="Show detailed output including commit summaries",
     )
 
-    # workflow-restart-checke subcommand
+    # workflow-restart-checker subcommand
     workflow_restart_parser = subparsers.add_parser(
-        "workflow-restart-checke", help="Check for restarted workflows"
+        "workflow-restart-checker", help="Check for restarted workflows"
     )
     workflow_restart_parser.add_argument(
         "workflow",
@@ -145,7 +145,7 @@ def main(*args, **kwargs) -> None:
 
     if opts.subcommand == "lambda":
         print("TODO: run lambda flow")
-    elif opts.subcommand == "workflows":
+    elif opts.subcommand == "autorevert-checker":
        autorevert_checker(opts.workflows, hours=opts.hours, verbose=opts.verbose)
     elif opts.subcommand == "workflow-restart-checker":
         workflow_restart_checker(opts.workflow, commit=opts.commit, days=opts.days)
```
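For reference, a hypothetical invocation of the renamed subcommand; the `--hours`/`--verbose` flag names are assumed from `opts.hours` and `opts.verbose`:

```python
# Hypothetical CLI usage after the rename (flag names assumed):
#   python -m pytorch_auto_revert autorevert-checker pull trunk --hours 720 --verbose
# Equivalent direct call into the tester entry point:
from pytorch_auto_revert.testers.autorevert import autorevert_checker

autorevert_checker(["pull", "trunk"], hours=720, verbose=True)
```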

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/autorevert_checker.py

Lines changed: 129 additions & 56 deletions
```diff
@@ -7,7 +7,7 @@
 import re
 from dataclasses import dataclass
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Set
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 from .clickhouse_client_helper import CHCliFactory
 
@@ -47,15 +47,15 @@ def has_pending_jobs(self) -> bool:
     @property
     def job_base_names(self) -> Set[str]:
         if not hasattr(self, "_job_base_names"):
-            self._job_base_names = self._get_job_base_names()
+            self._job_base_names = self.get_job_base_names()
         return self._job_base_names
 
     def normalize_job_name(self, name: str) -> str:
         """Strip shard suffix from job name for matching."""
         # Remove patterns like ", 1, 1, " or ", 2, 3, " from job names
         return re.sub(r", \d+, \d+, ", ", ", name)
 
-    def _get_job_base_names(self) -> Set[str]:
+    def get_job_base_names(self) -> Set[str]:
         """Get normalized job names (without shard info)."""
         return {self.normalize_job_name(j.name) for j in self.jobs}
```
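To illustrate `normalize_job_name`, a minimal sketch with a made-up job name:

```python
import re

# Hypothetical job name: the ", 2, 5, " shard indices are stripped for matching.
name = "linux-jammy-py3.9-gcc11 / test (default, 2, 5, linux.4xlarge)"
print(re.sub(r", \d+, \d+, ", ", ", name))
# -> "linux-jammy-py3.9-gcc11 / test (default, linux.4xlarge)"
```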
```diff
@@ -107,13 +107,17 @@ def _fetch_workflow_data(self):
             name,
             conclusion,
             status,
-            torchci_classification.rule as classification_rule,
-            workflow_created_at
-        FROM workflow_job FINAL
-        WHERE workflow_name IN {workflow_names:Array(String)}
-        AND head_branch = 'main'
-        AND workflow_created_at >= {lookback_time:DateTime}
-        ORDER BY workflow_name, workflow_created_at DESC, head_sha, name
+            torchci_classification.rule AS classification_rule,
+            created_at AS workflow_created_at
+        FROM
+            workflow_job FINAL
+        WHERE
+            workflow_name IN {workflow_names:Array(String)}
+            AND head_branch = 'main'
+            AND created_at >= {lookback_time:DateTime}
+            AND dynamoKey LIKE 'pytorch/pytorch/%'
+        ORDER BY
+            workflow_name, workflow_created_at DESC, head_sha, name
         """
 
         result = CHCliFactory().client.query(
```
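The query uses ClickHouse server-side parameter binding (the `{name:Type}` placeholders). A minimal sketch of how it might be executed, clickhouse-connect style; the `query` variable and parameter values here are hypothetical:

```python
from datetime import datetime, timedelta

# Hypothetical execution of the parameterized query above; the parameter names
# mirror the {workflow_names:Array(String)} / {lookback_time:DateTime} placeholders.
result = CHCliFactory().client.query(
    query,
    parameters={
        "workflow_names": ["pull", "trunk"],
        "lookback_time": datetime.now() - timedelta(hours=720),
    },
)
```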
```diff
@@ -194,6 +198,31 @@ def _fetch_commit_history(self):
             for row in result.result_rows
         ]
 
+    def _find_last_commit_with_job(
+        self, commits: Iterable[CommitJobs], job_name: str
+    ) -> Tuple[Optional[CommitJobs], Optional[List[JobResult]]]:
+        """
+        Find the first commit in the iterable that has a job with the specified name.
+        Callers order the iterable moving away from the suspected commit, so the
+        match returned is the nearest commit that ran the job.
+
+        Args:
+            commits: Iterable of CommitJobs to search
+            job_name: The base job name to look for
+
+        Returns:
+            A (commit, matching jobs) tuple, or (None, None) if no commit ran the job
+        """
+        job_results = []
+        for commit in commits:
+            for job in commit.jobs:
+                if job.name.split("(")[0] == job_name:  # normalized base name
+                    job_results.append(job)
+            if job_results:
+                return (
+                    commit,
+                    job_results,
+                )
+        return None, None
+
     def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
         """
         Detect all autorevert patterns in commit job data for a specific workflow.
```
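The helper matches on `j.name.split("(")[0]`, a coarser normalization than the shard-stripping regex above; with a made-up job name:

```python
# Everything before the parenthesised shard/runner info is the base name.
name = "linux-jammy-py3.9-gcc11 / test (default, 2, 5, linux.4xlarge)"
print(name.split("(")[0])  # -> "linux-jammy-py3.9-gcc11 / test "
```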
```diff
@@ -215,60 +244,90 @@ def detect_autorevert_pattern_workflow(self, workflow_name: str) -> List[Dict]:
 
         patterns = []
 
-        for i in range(len(commits) - 2):
-            newer_commit1 = commits[i]  # Most recent
-            newer_commit2 = commits[i + 1]  # Second most recent
-            older_commit = commits[i + 2]  # Third most recent
-
-            # All commits must have jobs (signal)
-            if not all(c.jobs for c in [newer_commit1, newer_commit2, older_commit]):
-                continue
+        for i in range(1, len(commits) - 1):
+            suspected_commit1 = commits[i]  # The commit we want to check for failure
 
-            # Oldest commit cannot have pending jobs
-            if older_commit.has_pending_jobs:
+            if suspected_commit1.has_pending_jobs:
                 continue
 
-            # Find common failure classifications between the 2 newer commits
-            newer1_failures = {j.classification_rule for j in newer_commit1.failed_jobs}
-            newer2_failures = {j.classification_rule for j in newer_commit2.failed_jobs}
-            common_failures = newer1_failures & newer2_failures
+            suspected_failures = {
+                (
+                    j.classification_rule,
+                    j.name.split("(")[0],
+                )
+                for j in suspected_commit1.failed_jobs
+            }
+
+            common_failures = set()
+            for (
+                suspected_failure_class_rule,
+                suspected_failure_job_name,
+            ) in suspected_failures:
+                newer_commit_same_job, newer_same_jobs = (
+                    self._find_last_commit_with_job(
+                        (commits[j] for j in range(i - 1, -1, -1)),
+                        suspected_failure_job_name,
+                    )
+                )
+                if not newer_commit_same_job or not newer_same_jobs:
+                    # No newer commit with the same job found
+                    continue
+
+                if any(
+                    j.classification_rule == suspected_failure_class_rule
+                    for j in newer_same_jobs
+                ):
+                    # A newer commit has the same job failing with the same rule
+                    common_failures.add(
+                        (
+                            suspected_failure_class_rule,
+                            suspected_failure_job_name,
+                        )
+                    )
 
             if not common_failures:
                 continue
 
-            # Check if older commit lacks these failures but has overlapping job coverage
-            older_failures = {j.classification_rule for j in older_commit.failed_jobs}
-            older_job_names = older_commit.get_job_base_names()
-
-            for failure_rule in common_failures:
-                if failure_rule in older_failures:
-                    continue  # Older commit also has this failure
-
-                # Get job names that had this failure in newer commits
-                failed_job_names = set()
-                for commit in [newer_commit1, newer_commit2]:
-                    for job in commit.failed_jobs:
-                        if job.classification_rule == failure_rule:
-                            failed_job_names.add(commit.normalize_job_name(job.name))
-
-                # Check if older commit has overlapping job coverage
-                if failed_job_names & older_job_names:
-                    patterns.append(
-                        {
-                            "pattern_detected": True,
-                            "workflow_name": workflow_name,
-                            "failure_rule": failure_rule,
-                            "newer_commits": [
-                                newer_commit1.head_sha,
-                                newer_commit2.head_sha,
-                            ],
-                            "older_commit": older_commit.head_sha,
-                            "failed_job_names": list(failed_job_names),
-                            "older_job_coverage": list(
-                                older_job_names & failed_job_names
-                            ),
-                        }
-                    )
+            for failure_rule, job_name in common_failures:
+                last_commit_with_same_job, last_same_jobs = (
+                    self._find_last_commit_with_job(
+                        (commits[j] for j in range(i + 1, len(commits))), job_name
+                    )
+                )
+
+                if not last_commit_with_same_job or not last_same_jobs:
+                    # No older commit with the same job found
+                    continue
+
+                if any(
+                    j.name.split("(")[0] != job_name
+                    for j in last_commit_with_same_job.failed_jobs
+                ):
+                    # The older commit has unrelated failing jobs; skip it as a noisy baseline
+                    continue
+
+                if any(
+                    j.classification_rule == failure_rule
+                    for j in last_same_jobs
+                ):
+                    # The older commit already shows the same failure classification
+                    continue
+
+                patterns.append(
+                    {
+                        "pattern_detected": True,
+                        "workflow_name": workflow_name,
+                        "failure_rule": failure_rule,
+                        "newer_commits": [
+                            newer_commit_same_job.head_sha,
+                            suspected_commit1.head_sha,
+                        ],
+                        "older_commit": last_commit_with_same_job.head_sha,
+                        "failed_job_names": [j.name for j in last_same_jobs],
+                        "older_job_coverage": [],
+                    }
+                )
+                break
 
         return patterns
```
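A toy illustration of the three-commit signal the loop above encodes; all data here is hypothetical:

```python
# A pattern requires: the suspected commit and its nearest newer commit fail the
# same normalized job with the same classification rule, while the nearest older
# commit ran that job without the failure.
suspected = {("GHA timeout", "pull / test ")}  # (classification rule, base job name)
newer = {("GHA timeout", "pull / test ")}      # nearest newer commit's failures
older = set()                                  # nearest older commit ran the job cleanly

for rule, job in suspected:
    if (rule, job) in newer and (rule, job) not in older:
        print(f"pattern: '{rule}' on '{job}' first appears at the suspected commit")
```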
```diff
@@ -314,6 +373,20 @@ def detect_autorevert_pattern(self) -> List[Dict]:
 
         return all_patterns
 
+    def get_commits_reverted(self) -> Set[str]:
+        """
+        Get all commits that were reverted within the lookback window.
+
+        Returns:
+            Set of SHAs of the commits that were reverted
+        """
+        reverted_commits = set()
+        for commit in self.commit_history:
+            revert_info = self.is_commit_reverted(commit["sha"])
+            if revert_info:
+                reverted_commits.add(commit["sha"])
+        return reverted_commits
+
     def is_commit_reverted(self, target_commit_sha: str) -> Optional[Dict]:
         """
         Check if a commit was reverted within the lookback window.
```
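`get_commits_reverted` pairs with the pattern detection to compute precision; a minimal sketch mirroring the tester below (import path, workflow names, and window are assumptions):

```python
from pytorch_auto_revert.autorevert_checker import AutorevertPatternChecker  # path assumed

checker = AutorevertPatternChecker(workflow_names=["pull", "trunk"], lookback_hours=720)
patterns = checker.detect_autorevert_pattern()
reverts = checker.get_commits_reverted()

# A pattern counts as a hit when its suspected commit was actually reverted.
hits = [p for p in patterns if p["newer_commits"][1] in reverts]
print(f"precision: {len(hits) / len(patterns) * 100:.1f}%")
```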

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/testers/autorevert.py

Lines changed: 33 additions & 7 deletions
```diff
@@ -1,5 +1,6 @@
+from collections import defaultdict
+
 from ..autorevert_checker import AutorevertPatternChecker
-from ..clickhouse_client_helper import CHCliFactory
 
 
 def autorevert_checker(
@@ -44,6 +45,8 @@ def autorevert_checker(
 
     # Detect patterns
     patterns = checker.detect_autorevert_pattern()
+    reverts = checker.get_commits_reverted()
+    not_found_reverts = reverts.copy()
 
     if patterns:
         print(
@@ -52,15 +55,14 @@ def autorevert_checker(
 
         # Create a revert checker (with extended lookback for finding reverts)
         revert_checker = AutorevertPatternChecker(
-            CHCliFactory().client, workflow_names=[], lookback_hours=hours * 2
+            workflow_names=[], lookback_hours=hours * 2
         )
 
         # Track reverts
         reverted_patterns = []
 
         for i, pattern in enumerate(patterns, 1):
-            if len(patterns) > 1:
-                print(f"\nPattern #{i}:")
+            print(f"\nPattern #{i}:")
 
             print(f"Failure rule: '{pattern['failure_rule']}'")
             print(
@@ -83,13 +85,14 @@ def autorevert_checker(
             revert_result = revert_checker.is_commit_reverted(second_commit)
 
             if revert_result:
+                not_found_reverts.discard(second_commit)
                 print(
                     f"✓ REVERTED: {second_commit[:8]} was reverted by {revert_result['revert_sha'][:8]} "
                     f"after {revert_result['hours_after_target']:.1f} hours"
                 )
                 reverted_patterns.append(pattern)
             else:
-                print(f"✗ NOT REVERTED: {second_commit[:8]} was not reverted")
+                print(f"✗ NOT REVERTED: {second_commit} was not reverted")
 
             if verbose:
                 print(f"Failed jobs ({len(pattern['failed_job_names'])}):")
@@ -121,11 +124,34 @@ def autorevert_checker(
         )
         print(f"Commits checked: {total_commits}")
 
-        print(f"Patterns detected: {len(patterns)}")
+        print(f"Auto revert patterns detected: {len(patterns)}")
+        print(
+            "Actual reverts inside auto revert patterns detected (precision): "
+            + f"{len(reverted_patterns)} ({len(reverted_patterns)/len(patterns)*100:.1f}%)"
+        )
+        print(f"Total revert commits in period: {len(reverts)}")
         print(
-            f"Actual reverts: {len(reverted_patterns)} ({len(reverted_patterns)/len(patterns)*100:.1f}%)"
+            "Reverts that dont match any auto revert pattern detected (recall): "
+            + f"{len(not_found_reverts)} ({len(not_found_reverts)/len(reverts)*100:.1f}%)"
         )
 
+        workflow_statistics = defaultdict(lambda: {"match_pattern": 0, "reverts": 0})
+        for pattern in patterns:
+            workflow_statistics[pattern["workflow_name"]]["match_pattern"] += 1
+            if pattern["newer_commits"][1] in reverts:
+                workflow_statistics[pattern["workflow_name"]]["reverts"] += 1
+
+        print("Per workflow precision:")
+        for workflow, stats in workflow_statistics.items():
+            precision = (
+                stats["reverts"] / stats["match_pattern"] * 100
+                if stats["match_pattern"] > 0
+                else 0.0
+            )
+            print(
+                f"  {workflow}: {stats['reverts']} reverts out of {stats['match_pattern']} patterns ({precision:.1f}%)"
+            )
+
         if reverted_patterns:
             print("\nReverted patterns:")
             for pattern in reverted_patterns:
```