Skip to content

Commit b4268b0

Browse files
committed
tc
1 parent 6cd5ea9 commit b4268b0

File tree

1 file changed

+107
-50
lines changed

1 file changed

+107
-50
lines changed

tools/torchci/td/get_reverts_caused_by_td.py

Lines changed: 107 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,38 @@
66
past year. It expects the folder setup to have test-infra and pytorch in the
77
same folder, and will use whatever branch is currently checked out on pytorch.
88
"""
9+
10+
import argparse
11+
import re
12+
from concurrent.futures import ThreadPoolExecutor
13+
from dataclasses import dataclass
914
from functools import lru_cache
15+
from typing import Optional
16+
1017
import requests
11-
from torchci.utils import run_command
1218
from torchci.clickhouse import query_clickhouse
13-
import re
14-
from dataclasses import dataclass
15-
from concurrent.futures import ThreadPoolExecutor
16-
import argparse
19+
from torchci.utils import run_command
20+
1721

1822
@dataclass
1923
class JobFailure:
20-
torchci_classification_line: str | None = None
21-
job_name: str | None = None
22-
failed_test: str | None = None
24+
torchci_classification_line: str
25+
job_name: str
26+
failed_test: Optional[str] = None
2327

2428

2529
@dataclass
2630
class CommitInfo:
2731
id: str
28-
last_pr_sha: str | None = None
29-
merge_commit_sha: str | None = None
30-
merge_commit_sha_prev: str | None = None
31-
revert_commit_sha: str | None = None
32-
revert_commit_sha_prev: str | None = None
32+
last_pr_sha: Optional[str] = None
33+
merge_commit_sha: str
34+
merge_commit_sha_prev: str
35+
revert_commit_sha: str
36+
revert_commit_sha_prev: str
3337
timestamp_of_revert: int = 0
3438
timestamp_of_merge: int = 0
3539
pr_num: int = 0
36-
run_id: str | None = None
40+
run_id: Optional[int] = None
3741

3842

3943
class IndentPrinter:
@@ -62,13 +66,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
6266
p = IndentPrinter()
6367

6468
# Match against things like Reverted https://github.com/pytorch/pytorch/pull/155998 on behalf of https://github.com/malfet due to
65-
REVERT_REGEX = (
66-
r"(?s)This reverts commit (.*)\..*Reverted https:\/\/github.com\/pytorch\/pytorch\/pull\/(\d+) on behalf of"
67-
)
69+
REVERT_REGEX = r"(?s)This reverts commit (.*)\..*Reverted https:\/\/github.com\/pytorch\/pytorch\/pull\/(\d+) on behalf of"
6870
# Matches stuff like FAILED [2.1965s] inductor/test_analysis.py::TestAnalysisCUDA::test_augment_trace_against_flop_counter_maxat0_cuda_float16 - IndexError: list index out of range
6971
FAILED_TEST_REGEX = r"FAILED \[.*\] (.*)\.py::.*"
7072
# Matches stuff like The following tests failed consistently: ['test/inductor/test_distributed_patterns.py::DistributedPatternTests::test_nn_param_return3']
71-
CONSISTENTLY_FAILED_TEST_REGEX = r"The following tests failed consistently: \['test/(.*).py::.*'\]"
73+
CONSISTENTLY_FAILED_TEST_REGEX = (
74+
r"The following tests failed consistently: \['test/(.*).py::.*'\]"
75+
)
7276

7377
JOB_NAME_REGEX = r"(.*) / test \(([^,]*), .*\)"
7478

@@ -146,7 +150,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
146150
"""
147151

148152

149-
def get_git_log() -> None:
153+
def get_git_log() -> list[tuple[str, int, str]]:
150154
"""Fetches commit sha and message for all commits"""
151155
return [
152156
line.split(" ", 2)
@@ -171,7 +175,7 @@ def get_td_exclusions(run_id: int) -> dict:
171175
return {}
172176

173177

174-
def get_test_file(torchci_classification_line: str) -> str | None:
178+
def get_test_file(torchci_classification_line: str) -> Optional[str]:
175179
"""Extracts the test file from the torchci classification line."""
176180
match = re.search(FAILED_TEST_REGEX, torchci_classification_line)
177181
if match:
@@ -188,7 +192,7 @@ def get_commit_info(num_to_process: int) -> list[CommitInfo]:
188192
commits_reverted: list[CommitInfo] = []
189193
sha_to_idx = {sha[0]: i for i, sha in enumerate(shas)}
190194

191-
def process_sha(i: int) -> CommitInfo | None:
195+
def process_sha(i: int) -> Optional[CommitInfo]:
192196
item = shas[i]
193197
sha, timestamp, message = item
194198
if not message.startswith('Revert "') or not message.endswith('"'):
@@ -198,7 +202,9 @@ def process_sha(i: int) -> CommitInfo | None:
198202
reverted_sha = x.group(1)
199203
reverted_pr = x.group(2)
200204
if reverted_sha not in sha_to_idx:
201-
p.print(f"Reverted commit {reverted_sha} not found in the log, skipping revert commit {sha}")
205+
p.print(
206+
f"Reverted commit {reverted_sha} not found in the log, skipping revert commit {sha}"
207+
)
202208
return None
203209
return CommitInfo(
204210
id=sha,
@@ -210,6 +216,7 @@ def process_sha(i: int) -> CommitInfo | None:
210216
pr_num=int(reverted_pr),
211217
timestamp_of_merge=int(shas[sha_to_idx[reverted_sha]][1]),
212218
)
219+
return None
213220

214221
with ThreadPoolExecutor(max_workers=8) as executor:
215222
results = list(executor.map(process_sha, range(num_to_process)))
@@ -236,11 +243,15 @@ def process_sha(i: int) -> CommitInfo | None:
236243
while commit.merge_commit_sha not in run_ids_present:
237244
commit.merge_commit_sha = shas[sha_to_idx[commit.merge_commit_sha] - 1][0]
238245
while commit.merge_commit_sha_prev not in run_ids_present:
239-
commit.merge_commit_sha_prev = shas[sha_to_idx[commit.merge_commit_sha_prev] + 1][0]
246+
commit.merge_commit_sha_prev = shas[
247+
sha_to_idx[commit.merge_commit_sha_prev] + 1
248+
][0]
240249
while commit.revert_commit_sha not in run_ids_present:
241250
commit.revert_commit_sha = shas[sha_to_idx[commit.revert_commit_sha] - 1][0]
242251
while commit.revert_commit_sha_prev not in run_ids_present:
243-
commit.revert_commit_sha_prev = shas[sha_to_idx[commit.revert_commit_sha_prev] + 1][0]
252+
commit.revert_commit_sha_prev = shas[
253+
sha_to_idx[commit.revert_commit_sha_prev] + 1
254+
][0]
244255

245256
# For ghstacked PRs, we might not have info about which sha got merged
246257
# because it was merged as a stack, so we query to the most recent workflow
@@ -254,7 +265,10 @@ def process_sha(i: int) -> CommitInfo | None:
254265
alt_last_pr_sha = ("", 0)
255266
for row in ghstack_last_pr_commits:
256267
timestamp = int(row["timestamp"])
257-
if int(row["pr_number"]) == commit.pr_num and alt_last_pr_sha[1] < timestamp < commit.timestamp_of_merge:
268+
if (
269+
int(row["pr_number"]) == commit.pr_num
270+
and alt_last_pr_sha[1] < timestamp < commit.timestamp_of_merge
271+
):
258272
alt_last_pr_sha = (row["head_sha"], timestamp)
259273
if alt_last_pr_sha[0] != commit.last_pr_sha and commit.last_pr_sha is not None:
260274
p.print(
@@ -263,20 +277,29 @@ def process_sha(i: int) -> CommitInfo | None:
263277
bad += 1
264278
if commit.last_pr_sha is None:
265279
commit.last_pr_sha = alt_last_pr_sha[0]
266-
p.print(f"Found {bad}, {bad / len(commits_reverted):<.2%} where last pr sha != alt last pr sha")
280+
p.print(
281+
f"Found {bad}, {bad / len(commits_reverted):<.2%} where last pr sha != alt last pr sha"
282+
)
267283

268284
# Get the run_id for the jobs on the pr
269285
run_ids = query_clickhouse(
270286
WORKFLOW_ID_QUERY,
271-
{"shas": [x.last_pr_sha for x in commits_reverted if x.last_pr_sha is not None]},
287+
{
288+
"shas": [
289+
x.last_pr_sha for x in commits_reverted if x.last_pr_sha is not None
290+
]
291+
},
272292
)
273293
for row in run_ids:
274294
run_id = row["id"]
275295
head_sha = row["head_sha"]
276296
created_at = row["created_at"]
277297
for commit in commits_reverted:
278-
if commit.last_pr_sha == head_sha and created_at < commit.timestamp_of_merge:
279-
commit.run_id = run_id
298+
if (
299+
commit.last_pr_sha == head_sha
300+
and created_at < commit.timestamp_of_merge
301+
):
302+
commit.run_id = int(run_id)
280303

281304
return commits_reverted
282305

@@ -306,7 +329,11 @@ def get_job_failures(shas: list[str]) -> dict[str, list[JobFailure]]:
306329
if head_sha not in failures_dict:
307330
failures_dict[head_sha] = []
308331
failures_dict[head_sha].append(
309-
JobFailure(torchci_classification_line=line, job_name=job_name, failed_test=get_test_file(line))
332+
JobFailure(
333+
torchci_classification_line=line,
334+
job_name=job_name,
335+
failed_test=get_test_file(line),
336+
)
310337
)
311338
return failures_dict
312339

@@ -315,18 +342,24 @@ def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
315342
"""True if the commit is bad (excluded in TD)"""
316343
x = re.search(JOB_NAME_REGEX, f.job_name)
317344
if x is None:
318-
p.print(f"Failed to parse job name {f.job_name} for failure {f.torchci_classification_line}")
345+
p.print(
346+
f"Failed to parse job name {f.job_name} for failure {f.torchci_classification_line}"
347+
)
319348
return False
320349

321350
td_exclusions = get_td_exclusions(run_id)
322351
build_env = x.group(1)
323352
test_config = x.group(2)
324-
p.print(f"Build environment: {build_env}, Test config: {test_config}, len(td_exclusions): {len(td_exclusions)}")
353+
p.print(
354+
f"Build environment: {build_env}, Test config: {test_config}, len(td_exclusions): {len(td_exclusions)}"
355+
)
325356
if len(td_exclusions) == 0:
326357
p.print(f"No TD exclusions found for run {run_id}")
327358
return False
328359
if build_env not in td_exclusions:
329-
p.print(f"Build environment {build_env} not found in TD exclusions for run {run_id}")
360+
p.print(
361+
f"Build environment {build_env} not found in TD exclusions for run {run_id}"
362+
)
330363
elif test_config not in td_exclusions[build_env]:
331364
p.print(f"Test {test_config} not found in TD exclusions for run {run_id}")
332365
elif f.failed_test in td_exclusions[build_env][test_config]:
@@ -337,7 +370,9 @@ def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
337370
return False
338371

339372

340-
def check_on_commit(sha: str, job_name: str, test_file: str, failures: dict[str, list[JobFailure]]) -> bool:
373+
def check_on_commit(
374+
sha: str, job_name: str, test_file: str, failures: dict[str, list[JobFailure]]
375+
) -> bool:
341376
"""True if the test failed on the given commit."""
342377
for failure in failures.get(sha, []):
343378
if failure.failed_test == test_file:
@@ -383,15 +418,26 @@ def main() -> None:
383418
any_bad = False
384419
for f in job_failures.get(s.merge_commit_sha, []):
385420
with p:
386-
p.print(f"Failure: {f.job_name}, {f.torchci_classification_line}, {f.failed_test}")
421+
p.print(
422+
f"Failure: {f.job_name}, {f.torchci_classification_line}, {f.failed_test}"
423+
)
387424

388425
if f.failed_test is None:
389426
continue
390427
with p:
391-
if check_on_commit(s.revert_commit_sha, f.job_name, f.failed_test, job_failures):
392-
p.print(f"Failure {f.failed_test} is present on the revert commit {s.revert_commit_sha}")
428+
if check_on_commit(
429+
s.revert_commit_sha, f.job_name, f.failed_test, job_failures
430+
):
431+
p.print(
432+
f"Failure {f.failed_test} is present on the revert commit {s.revert_commit_sha}"
433+
)
393434
continue
394-
if check_on_commit(s.merge_commit_sha_prev, f.job_name, f.failed_test, job_failures):
435+
if check_on_commit(
436+
s.merge_commit_sha_prev,
437+
f.job_name,
438+
f.failed_test,
439+
job_failures,
440+
):
395441
p.print(
396442
f"Failure {f.failed_test} is present on commit before the merge {s.merge_commit_sha_prev}"
397443
)
@@ -400,36 +446,47 @@ def main() -> None:
400446
any_bad |= check_failure_in_td_exclusion(f, s.run_id)
401447
if any_bad:
402448
caused_by_bad_td.append(s)
403-
p.print(f"Commit {s.last_pr_sha} with run_id {s.run_id} is caused by bad TD")
404-
p.print(f"CAUSED BY BAD TD: {len(caused_by_bad_td)} / {i + 1} = {len(caused_by_bad_td) / (i + 1):.2%}")
405-
p.print(f"Unable to check (lack run id) on PR: {unable_to_check} / {i + 1} = {unable_to_check / (i + 1):.2%}")
449+
p.print(
450+
f"Commit {s.last_pr_sha} with run_id {s.run_id} is caused by bad TD"
451+
)
452+
p.print(
453+
f"CAUSED BY BAD TD: {len(caused_by_bad_td)} / {i + 1} = {len(caused_by_bad_td) / (i + 1):.2%}"
454+
)
455+
p.print(
456+
f"Unable to check (lack run id) on PR: {unable_to_check} / {i + 1} = {unable_to_check / (i + 1):.2%}"
457+
)
406458

407-
p.print(f"Total caused by bad TD: {len(caused_by_bad_td)} / {len(commits_reverted)} = {len(caused_by_bad_td) / len(commits_reverted):.2%}")
459+
p.print(
460+
f"Total caused by bad TD: {len(caused_by_bad_td)} / {len(commits_reverted)} = {len(caused_by_bad_td) / len(commits_reverted):.2%}"
461+
)
408462
# Group by month, this is a massive oversimplification, but we'll take it
409463
month_groups = {}
410464
for commit in caused_by_bad_td:
411465
month = commit.timestamp_of_revert // (30 * 24 * 60 * 60)
412466
if month not in month_groups:
413-
month_groups[month] = (0,0)
467+
month_groups[month] = (0, 0)
414468
month_groups[month] = (month_groups[month][0] + 1, month_groups[month][1])
415469
for commit in commits_reverted:
416470
month = commit.timestamp_of_merge // (30 * 24 * 60 * 60)
417471
if month not in month_groups:
418-
month_groups[month] = (0,0)
472+
month_groups[month] = (0, 0)
419473
month_groups[month] = (month_groups[month][0], month_groups[month][1] + 1)
420474

421475
for month, (bad_td_count, total_count) in sorted(month_groups.items()):
422-
p.print(f"Month {month}: {bad_td_count} bad TD / {total_count} total = {bad_td_count / total_count:.2%}")
476+
p.print(
477+
f"Month {month}: {bad_td_count} bad TD / {total_count} total = {bad_td_count / total_count:.2%}"
478+
)
479+
423480

424481
def parse_args() -> argparse.Namespace:
425-
parser = argparse.ArgumentParser(description="Get reverts caused by bad TD exclusions.")
482+
parser = argparse.ArgumentParser(
483+
description="Get reverts caused by bad TD exclusions."
484+
)
426485
parser.add_argument(
427-
"--num",
428-
type=int,
429-
default=2000,
430-
help="Number of commits to examine"
486+
"--num", type=int, default=2000, help="Number of commits to examine"
431487
)
432488
return parser.parse_args()
433489

490+
434491
if __name__ == "__main__":
435492
main()

0 commit comments

Comments
 (0)