Skip to content

Commit cd8e25f

Browse files
authored
[autorevert] Extract test retries as separate events (#7327)
This helps filter out signals from flaky tests (those that succeed after a rerun) and avoid reverting based on them. Such tests are excluded from HUD, so this makes autorevert more consistent with HUD.

Changes:
- Extract individual test reruns as separate Signal events
- Dedup retains both outcomes
- For each attempt, emit at most one FAILURE and one SUCCESS event
- Skipped tests are excluded from "success"

Extraction result before: [2025-10-08T18-48-39.575589-00-00.html](https://github.com/user-attachments/files/22786564/2025-10-08T18-48-39.575589-00-00.html)
Extraction result after: [2025-10-08T21-30-45.676074-00-00.html](https://github.com/user-attachments/files/22786565/2025-10-08T21-30-45.676074-00-00.html)

Notice flaky signals like:
- pull:inductor/test_compile_subprocess.py::test_remove_noop_slice_scatter_cpu
- pull:inductor/test_compile_subprocess.py::test_remove_noop_slice1_cpu
1 parent ab66abc commit cd8e25f

File tree

5 files changed

+105
-57
lines changed

5 files changed

+105
-57
lines changed

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929

3030
@dataclass(frozen=True)
3131
class TestOutcome:
32-
failing: bool
33-
errored: bool
32+
failure_runs: int # count of failed test runs
33+
success_runs: int # count of successful test runs
3434
started_at: datetime
3535
job_id: int
3636

@@ -107,9 +107,11 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]:
107107
new_commits: List[SignalCommit] = []
108108
for c in s.commits:
109109
filtered: List[SignalEvent] = []
110-
prev_key: Optional[Tuple[datetime, int]] = None
110+
# Include status in the key so we can retain both a FAILURE and
111+
# a SUCCESS emitted at the same (started_at, wf_run_id)
112+
prev_key: Optional[Tuple[datetime, int, SignalStatus]] = None
111113
for e in c.events: # already sorted by (started_at, wf_run_id)
112-
key = (e.started_at, e.wf_run_id)
114+
key = (e.started_at, e.wf_run_id, e.status)
113115
if key == prev_key:
114116
continue
115117
filtered.append(e)
@@ -259,7 +261,8 @@ def _build_test_signals(
259261
which base(s) (by normalized job name) a test appears in. For each commit and (workflow, base),
260262
we compute attempt metadata (pending/completed, start time). Then, for tests that failed at least once in
261263
that base, we emit events per commit/attempt:
262-
- If test_run_s3 rows exist → FAILURE if any failing/errored else SUCCESS
264+
- If test_run_s3 rows exist → emit at most one FAILURE event if any failed runs exist,
265+
and at most one SUCCESS event if any successful runs exist (both may be present).
263266
- Else if group pending → PENDING
264267
- Else → no event (missing)
265268
@@ -290,7 +293,8 @@ def _build_test_signals(
290293
value_fn=lambda j: (j.wf_run_id, j.run_attempt),
291294
)
292295

293-
# Index test_run_s3 rows per (commit, job_base, wf_run, attempt, test_id) and collect base-scoped failing tests
296+
# Index test_run_s3 rows per (commit, job_base, wf_run, attempt, test_id)
297+
# Store aggregated failure/success counts
294298
tests_by_group_attempt: Dict[
295299
Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt, TestId],
296300
TestOutcome,
@@ -309,24 +313,14 @@ def _build_test_signals(
309313
tr.workflow_run_attempt,
310314
tr.test_id,
311315
)
312-
prev = tests_by_group_attempt.get(key)
313-
started_at = min(
314-
(prev.started_at if prev else job.started_at), job.started_at
315-
)
316-
# Use job_id from the first failing test, or current job_id
317-
use_job_id = (
318-
prev.job_id
319-
if prev and (prev.failing or prev.errored)
320-
else int(tr.job_id)
321-
)
322316
outcome = TestOutcome(
323-
failing=(prev.failing if prev else False) or bool(tr.failing),
324-
errored=(prev.errored if prev else False) or bool(tr.errored),
325-
started_at=started_at,
326-
job_id=use_job_id,
317+
failure_runs=tr.failure_runs,
318+
success_runs=tr.success_runs,
319+
started_at=job.started_at,
320+
job_id=int(tr.job_id),
327321
)
328322
tests_by_group_attempt[key] = outcome
329-
if outcome.failing or outcome.errored:
323+
if outcome.failure_runs > 0:
330324
failing_tests_by_job_base_name.add(
331325
(job.workflow_name, job_base_name, tr.test_id)
332326
)
@@ -354,7 +348,7 @@ def _build_test_signals(
354348
if meta.is_cancelled:
355349
# canceled attempts are treated as missing
356350
continue
357-
verdict = tests_by_group_attempt.get(
351+
outcome = tests_by_group_attempt.get(
358352
(
359353
commit_sha,
360354
wf_name,
@@ -378,17 +372,26 @@ def _build_test_signals(
378372
"run_attempt": int(run_attempt),
379373
}
380374

381-
if verdict:
382-
events.append(
383-
SignalEvent(
384-
status=SignalStatus.FAILURE
385-
if (verdict.failing or verdict.errored)
386-
else SignalStatus.SUCCESS,
387-
started_at=verdict.started_at,
388-
job_id=verdict.job_id,
389-
**event_common,
375+
if outcome:
376+
# Emit at most one FAILURE and one SUCCESS per attempt
377+
if outcome.failure_runs > 0:
378+
events.append(
379+
SignalEvent(
380+
status=SignalStatus.FAILURE,
381+
started_at=outcome.started_at,
382+
job_id=outcome.job_id,
383+
**event_common,
384+
)
385+
)
386+
if outcome.success_runs > 0:
387+
events.append(
388+
SignalEvent(
389+
status=SignalStatus.SUCCESS,
390+
started_at=outcome.started_at,
391+
job_id=outcome.job_id,
392+
**event_common,
393+
)
390394
)
391-
)
392395
elif meta.is_pending:
393396
events.append(
394397
SignalEvent(

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_datasource.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def fetch_tests_for_job_ids(
215215
)
216216
# One query with a CTE that enumerates failed test ids from failed_job_ids,
217217
# then filters the main selection by those ids for the current chunk.
218+
# Note: success_runs explicitly excludes skipped rows via skipped_count = 0.
218219
query = """
219220
WITH failed_test_names AS (
220221
SELECT DISTINCT concat(file, '|', classname, '|', name) AS test_id
@@ -223,8 +224,8 @@ def fetch_tests_for_job_ids(
223224
AND (failure_count > 0 OR error_count > 0)
224225
)
225226
SELECT job_id, workflow_id, workflow_run_attempt, file, classname, name,
226-
max(failure_count > 0) AS failing,
227-
max(error_count > 0) AS errored
227+
countIf(failure_count > 0 OR error_count > 0) AS failure_runs,
228+
countIf(failure_count = 0 AND error_count = 0 AND skipped_count = 0) AS success_runs
228229
FROM default.test_run_s3
229230
WHERE job_id IN {job_ids:Array(Int64)}
230231
AND concat(file, '|', classname, '|', name) IN failed_test_names
@@ -247,8 +248,8 @@ def fetch_tests_for_job_ids(
247248
file=str(r[3] or ""),
248249
classname=str(r[4] or ""),
249250
name=str(r[5] or ""),
250-
failing=int(r[6] or 0),
251-
errored=int(r[7] or 0),
251+
failure_runs=int(r[6] or 0),
252+
success_runs=int(r[7] or 0),
252253
)
253254
)
254255
dt = time.perf_counter() - t0

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ class TestRow:
130130
file: str
131131
classname: str
132132
name: str
133-
failing: int # 0/1
134-
errored: int # 0/1
133+
failure_runs: int # count of failed test runs
134+
success_runs: int # count of successful test runs
135135

136136
@property
137137
def test_id(self) -> TestId:

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal_dedup.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ class TestSignalDedup(unittest.TestCase):
1313
def setUp(self) -> None:
1414
self.t0 = datetime(2025, 8, 20, 12, 0, 0)
1515

16-
def test_dedup_removes_adjacent_duplicates(self):
17-
# Two events with identical (started_at, wf_run_id) within a single commit
16+
def test_dedup_keeps_both_statuses(self):
17+
# Two events with identical (started_at, wf_run_id) but different statuses
18+
# should both be retained after dedup (we dedup by (started_at, wf_run_id, status)).
1819
e1 = SignalEvent(
1920
name="job-a",
2021
status=SignalStatus.FAILURE,
@@ -33,10 +34,10 @@ def test_dedup_removes_adjacent_duplicates(self):
3334
ex = SignalExtractor(workflows=["wf"], lookback_hours=24)
3435
out = ex._dedup_signal_events([s])
3536
self.assertEqual(len(out), 1)
36-
self.assertEqual(len(out[0].commits[0].events), 1)
37-
# keeps the first encountered event for that pair
38-
self.assertEqual(out[0].commits[0].events[0].name, "job-a")
39-
self.assertEqual(out[0].commits[0].events[0].status, SignalStatus.FAILURE)
37+
# Both events survive because status differs
38+
self.assertEqual(len(out[0].commits[0].events), 2)
39+
statuses = {e.status for e in out[0].commits[0].events}
40+
self.assertEqual(statuses, {SignalStatus.FAILURE, SignalStatus.SUCCESS})
4041

4142
def test_dedup_keeps_non_duplicates(self):
4243
e1 = SignalEvent(
@@ -67,7 +68,7 @@ def test_dedup_keeps_non_duplicates(self):
6768
self.assertEqual(len(out[0].commits[0].events), 3)
6869

6970
def test_dedup_applies_per_commit(self):
70-
# Duplicates in different commits are not cross-deduped
71+
# Dedup applies per commit: each commit retains at most one event per status
7172
e1 = SignalEvent(
7273
name="job-a",
7374
status=SignalStatus.FAILURE,
@@ -86,8 +87,8 @@ def test_dedup_applies_per_commit(self):
8687

8788
ex = SignalExtractor(workflows=["wf"], lookback_hours=24)
8889
out = ex._dedup_signal_events([s])
89-
# Both commits should each have one event after dedup
90-
self.assertEqual([len(c.events) for c in out[0].commits], [1, 1])
90+
# Both commits should each have two events (one per status) after dedup
91+
self.assertEqual([len(c.events) for c in out[0].commits], [2, 2])
9192

9293

9394
if __name__ == "__main__":

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal_extraction.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def T(
9393
attempt: int,
9494
file: str,
9595
name: str,
96-
failing: int,
97-
errored: int = 0,
96+
failure_runs: int,
97+
success_runs: int = 0,
9898
):
9999
return TestRow(
100100
job_id=JobId(job),
@@ -103,8 +103,8 @@ def T(
103103
file=file,
104104
classname="",
105105
name=name,
106-
failing=failing,
107-
errored=errored,
106+
failure_runs=failure_runs,
107+
success_runs=success_runs,
108108
)
109109

110110

@@ -197,7 +197,17 @@ def test_keep_going_failure_test_track_failure_and_no_job_signal(self):
197197
rule="pytest failure",
198198
)
199199
]
200-
tests = [T(job=20, run=400, attempt=1, file="f.py", name="test_a", failing=1)]
200+
tests = [
201+
T(
202+
job=20,
203+
run=400,
204+
attempt=1,
205+
file="f.py",
206+
name="test_a",
207+
failure_runs=1,
208+
success_runs=0,
209+
)
210+
]
201211
signals = self._extract(jobs, tests)
202212
# test signal present with FAILURE
203213
test_sig = self._find_test_signal(signals, "trunk", "f.py::test_a")
@@ -259,8 +269,24 @@ def test_non_test_inclusion_gate(self):
259269
),
260270
]
261271
tests_a = [
262-
T(job=40, run=600, attempt=1, file="f.py", name="test_x", failing=1),
263-
T(job=41, run=610, attempt=1, file="f.py", name="test_x", failing=1),
272+
T(
273+
job=40,
274+
run=600,
275+
attempt=1,
276+
file="f.py",
277+
name="test_x",
278+
failure_runs=1,
279+
success_runs=0,
280+
),
281+
T(
282+
job=41,
283+
run=610,
284+
attempt=1,
285+
file="f.py",
286+
name="test_x",
287+
failure_runs=1,
288+
success_runs=0,
289+
),
264290
]
265291
signals_a = self._extract(jobs_a, tests_a)
266292
self.assertIsNone(
@@ -374,8 +400,24 @@ def test_test_track_mapping_failure_then_success(self):
374400
),
375401
]
376402
tests = [
377-
T(job=60, run=800, attempt=1, file="g.py", name="test_y", failing=1),
378-
T(job=61, run=810, attempt=1, file="g.py", name="test_y", failing=0),
403+
T(
404+
job=60,
405+
run=800,
406+
attempt=1,
407+
file="g.py",
408+
name="test_y",
409+
failure_runs=1,
410+
success_runs=0,
411+
),
412+
T(
413+
job=61,
414+
run=810,
415+
attempt=1,
416+
file="g.py",
417+
name="test_y",
418+
failure_runs=0,
419+
success_runs=1,
420+
),
379421
]
380422
signals = self._extract(jobs, tests)
381423
test_sig = self._find_test_signal(signals, "trunk", "g.py::test_y")
@@ -424,7 +466,8 @@ def test_inject_pending_workflow_event_when_missing_in_signal(self):
424466
attempt=1,
425467
file="m.py",
426468
name="test_synthetic_pending",
427-
failing=1,
469+
failure_runs=1,
470+
success_runs=0,
428471
)
429472
]
430473

0 commit comments

Comments
 (0)