6
6
past year. It expects the folder setup to have test-infra and pytorch in the
7
7
same folder, and will use whatever branch is currently checked out on pytorch.
8
8
"""
9
+
10
+ import argparse
11
+ import re
12
+ from concurrent .futures import ThreadPoolExecutor
13
+ from dataclasses import dataclass
9
14
from functools import lru_cache
15
+ from typing import Optional
16
+
10
17
import requests
11
- from torchci .utils import run_command
12
18
from torchci .clickhouse import query_clickhouse
13
- import re
14
- from dataclasses import dataclass
15
- from concurrent .futures import ThreadPoolExecutor
16
- import argparse
19
+ from torchci .utils import run_command
20
+
17
21
18
22
@dataclass
19
23
class JobFailure :
20
- torchci_classification_line : str | None = None
21
- job_name : str | None = None
22
- failed_test : str | None = None
24
+ torchci_classification_line : str
25
+ job_name : str
26
+ failed_test : Optional [ str ] = None
23
27
24
28
25
29
@dataclass
26
30
class CommitInfo :
27
31
id : str
28
- last_pr_sha : str | None = None
29
- merge_commit_sha : str | None = None
30
- merge_commit_sha_prev : str | None = None
31
- revert_commit_sha : str | None = None
32
- revert_commit_sha_prev : str | None = None
32
+ last_pr_sha : Optional [ str ] = None
33
+ merge_commit_sha : str
34
+ merge_commit_sha_prev : str
35
+ revert_commit_sha : str
36
+ revert_commit_sha_prev : str
33
37
timestamp_of_revert : int = 0
34
38
timestamp_of_merge : int = 0
35
39
pr_num : int = 0
36
- run_id : str | None = None
40
+ run_id : Optional [ int ] = None
37
41
38
42
39
43
class IndentPrinter :
@@ -62,13 +66,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
62
66
p = IndentPrinter ()
63
67
64
68
# Match against things like Reverted https://github.com/pytorch/pytorch/pull/155998 on behalf of https://github.com/malfet due to
65
- REVERT_REGEX = (
66
- r"(?s)This reverts commit (.*)\..*Reverted https:\/\/github.com\/pytorch\/pytorch\/pull\/(\d+) on behalf of"
67
- )
69
+ REVERT_REGEX = r"(?s)This reverts commit (.*)\..*Reverted https:\/\/github.com\/pytorch\/pytorch\/pull\/(\d+) on behalf of"
68
70
# Matches stuff like FAILED [2.1965s] inductor/test_analysis.py::TestAnalysisCUDA::test_augment_trace_against_flop_counter_maxat0_cuda_float16 - IndexError: list index out of range
69
71
FAILED_TEST_REGEX = r"FAILED \[.*\] (.*)\.py::.*"
70
72
# Matches stuff like The following tests failed consistently: ['test/inductor/test_distributed_patterns.py::DistributedPatternTests::test_nn_param_return3']
71
- CONSISTENTLY_FAILED_TEST_REGEX = r"The following tests failed consistently: \['test/(.*).py::.*'\]"
73
+ CONSISTENTLY_FAILED_TEST_REGEX = (
74
+ r"The following tests failed consistently: \['test/(.*).py::.*'\]"
75
+ )
72
76
73
77
JOB_NAME_REGEX = r"(.*) / test \(([^,]*), .*\)"
74
78
@@ -146,7 +150,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
146
150
"""
147
151
148
152
149
- def get_git_log () -> None :
153
+ def get_git_log () -> list [ tuple [ str , int , str ]] :
150
154
"""Fetches commit sha and message for all commits"""
151
155
return [
152
156
line .split (" " , 2 )
@@ -171,7 +175,7 @@ def get_td_exclusions(run_id: int) -> dict:
171
175
return {}
172
176
173
177
174
- def get_test_file (torchci_classification_line : str ) -> str | None :
178
+ def get_test_file (torchci_classification_line : str ) -> Optional [ str ] :
175
179
"""Extracts the test file from the torchci classification line."""
176
180
match = re .search (FAILED_TEST_REGEX , torchci_classification_line )
177
181
if match :
@@ -188,7 +192,7 @@ def get_commit_info(num_to_process: int) -> list[CommitInfo]:
188
192
commits_reverted : list [CommitInfo ] = []
189
193
sha_to_idx = {sha [0 ]: i for i , sha in enumerate (shas )}
190
194
191
- def process_sha (i : int ) -> CommitInfo | None :
195
+ def process_sha (i : int ) -> Optional [ CommitInfo ] :
192
196
item = shas [i ]
193
197
sha , timestamp , message = item
194
198
if not message .startswith ('Revert "' ) or not message .endswith ('"' ):
@@ -198,7 +202,9 @@ def process_sha(i: int) -> CommitInfo | None:
198
202
reverted_sha = x .group (1 )
199
203
reverted_pr = x .group (2 )
200
204
if reverted_sha not in sha_to_idx :
201
- p .print (f"Reverted commit { reverted_sha } not found in the log, skipping revert commit { sha } " )
205
+ p .print (
206
+ f"Reverted commit { reverted_sha } not found in the log, skipping revert commit { sha } "
207
+ )
202
208
return None
203
209
return CommitInfo (
204
210
id = sha ,
@@ -210,6 +216,7 @@ def process_sha(i: int) -> CommitInfo | None:
210
216
pr_num = int (reverted_pr ),
211
217
timestamp_of_merge = int (shas [sha_to_idx [reverted_sha ]][1 ]),
212
218
)
219
+ return None
213
220
214
221
with ThreadPoolExecutor (max_workers = 8 ) as executor :
215
222
results = list (executor .map (process_sha , range (num_to_process )))
@@ -236,11 +243,15 @@ def process_sha(i: int) -> CommitInfo | None:
236
243
while commit .merge_commit_sha not in run_ids_present :
237
244
commit .merge_commit_sha = shas [sha_to_idx [commit .merge_commit_sha ] - 1 ][0 ]
238
245
while commit .merge_commit_sha_prev not in run_ids_present :
239
- commit .merge_commit_sha_prev = shas [sha_to_idx [commit .merge_commit_sha_prev ] + 1 ][0 ]
246
+ commit .merge_commit_sha_prev = shas [
247
+ sha_to_idx [commit .merge_commit_sha_prev ] + 1
248
+ ][0 ]
240
249
while commit .revert_commit_sha not in run_ids_present :
241
250
commit .revert_commit_sha = shas [sha_to_idx [commit .revert_commit_sha ] - 1 ][0 ]
242
251
while commit .revert_commit_sha_prev not in run_ids_present :
243
- commit .revert_commit_sha_prev = shas [sha_to_idx [commit .revert_commit_sha_prev ] + 1 ][0 ]
252
+ commit .revert_commit_sha_prev = shas [
253
+ sha_to_idx [commit .revert_commit_sha_prev ] + 1
254
+ ][0 ]
244
255
245
256
# For ghstacked PRs, we might not have info about which sha got merged
246
257
# because it was merged as a stack, so we query to the most recent workflow
@@ -254,7 +265,10 @@ def process_sha(i: int) -> CommitInfo | None:
254
265
alt_last_pr_sha = ("" , 0 )
255
266
for row in ghstack_last_pr_commits :
256
267
timestamp = int (row ["timestamp" ])
257
- if int (row ["pr_number" ]) == commit .pr_num and alt_last_pr_sha [1 ] < timestamp < commit .timestamp_of_merge :
268
+ if (
269
+ int (row ["pr_number" ]) == commit .pr_num
270
+ and alt_last_pr_sha [1 ] < timestamp < commit .timestamp_of_merge
271
+ ):
258
272
alt_last_pr_sha = (row ["head_sha" ], timestamp )
259
273
if alt_last_pr_sha [0 ] != commit .last_pr_sha and commit .last_pr_sha is not None :
260
274
p .print (
@@ -263,20 +277,29 @@ def process_sha(i: int) -> CommitInfo | None:
263
277
bad += 1
264
278
if commit .last_pr_sha is None :
265
279
commit .last_pr_sha = alt_last_pr_sha [0 ]
266
- p .print (f"Found { bad } , { bad / len (commits_reverted ):<.2%} where last pr sha != alt last pr sha" )
280
+ p .print (
281
+ f"Found { bad } , { bad / len (commits_reverted ):<.2%} where last pr sha != alt last pr sha"
282
+ )
267
283
268
284
# Get the run_id for the jobs on the pr
269
285
run_ids = query_clickhouse (
270
286
WORKFLOW_ID_QUERY ,
271
- {"shas" : [x .last_pr_sha for x in commits_reverted if x .last_pr_sha is not None ]},
287
+ {
288
+ "shas" : [
289
+ x .last_pr_sha for x in commits_reverted if x .last_pr_sha is not None
290
+ ]
291
+ },
272
292
)
273
293
for row in run_ids :
274
294
run_id = row ["id" ]
275
295
head_sha = row ["head_sha" ]
276
296
created_at = row ["created_at" ]
277
297
for commit in commits_reverted :
278
- if commit .last_pr_sha == head_sha and created_at < commit .timestamp_of_merge :
279
- commit .run_id = run_id
298
+ if (
299
+ commit .last_pr_sha == head_sha
300
+ and created_at < commit .timestamp_of_merge
301
+ ):
302
+ commit .run_id = int (run_id )
280
303
281
304
return commits_reverted
282
305
@@ -306,7 +329,11 @@ def get_job_failures(shas: list[str]) -> dict[str, list[JobFailure]]:
306
329
if head_sha not in failures_dict :
307
330
failures_dict [head_sha ] = []
308
331
failures_dict [head_sha ].append (
309
- JobFailure (torchci_classification_line = line , job_name = job_name , failed_test = get_test_file (line ))
332
+ JobFailure (
333
+ torchci_classification_line = line ,
334
+ job_name = job_name ,
335
+ failed_test = get_test_file (line ),
336
+ )
310
337
)
311
338
return failures_dict
312
339
@@ -315,18 +342,24 @@ def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
315
342
"""True if the commit is bad (excluded in TD)"""
316
343
x = re .search (JOB_NAME_REGEX , f .job_name )
317
344
if x is None :
318
- p .print (f"Failed to parse job name { f .job_name } for failure { f .torchci_classification_line } " )
345
+ p .print (
346
+ f"Failed to parse job name { f .job_name } for failure { f .torchci_classification_line } "
347
+ )
319
348
return False
320
349
321
350
td_exclusions = get_td_exclusions (run_id )
322
351
build_env = x .group (1 )
323
352
test_config = x .group (2 )
324
- p .print (f"Build environment: { build_env } , Test config: { test_config } , len(td_exclusions): { len (td_exclusions )} " )
353
+ p .print (
354
+ f"Build environment: { build_env } , Test config: { test_config } , len(td_exclusions): { len (td_exclusions )} "
355
+ )
325
356
if len (td_exclusions ) == 0 :
326
357
p .print (f"No TD exclusions found for run { run_id } " )
327
358
return False
328
359
if build_env not in td_exclusions :
329
- p .print (f"Build environment { build_env } not found in TD exclusions for run { run_id } " )
360
+ p .print (
361
+ f"Build environment { build_env } not found in TD exclusions for run { run_id } "
362
+ )
330
363
elif test_config not in td_exclusions [build_env ]:
331
364
p .print (f"Test { test_config } not found in TD exclusions for run { run_id } " )
332
365
elif f .failed_test in td_exclusions [build_env ][test_config ]:
@@ -337,7 +370,9 @@ def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
337
370
return False
338
371
339
372
340
- def check_on_commit (sha : str , job_name : str , test_file : str , failures : dict [str , list [JobFailure ]]) -> bool :
373
+ def check_on_commit (
374
+ sha : str , job_name : str , test_file : str , failures : dict [str , list [JobFailure ]]
375
+ ) -> bool :
341
376
"""True if the test failed on the given commit."""
342
377
for failure in failures .get (sha , []):
343
378
if failure .failed_test == test_file :
@@ -383,15 +418,26 @@ def main() -> None:
383
418
any_bad = False
384
419
for f in job_failures .get (s .merge_commit_sha , []):
385
420
with p :
386
- p .print (f"Failure: { f .job_name } , { f .torchci_classification_line } , { f .failed_test } " )
421
+ p .print (
422
+ f"Failure: { f .job_name } , { f .torchci_classification_line } , { f .failed_test } "
423
+ )
387
424
388
425
if f .failed_test is None :
389
426
continue
390
427
with p :
391
- if check_on_commit (s .revert_commit_sha , f .job_name , f .failed_test , job_failures ):
392
- p .print (f"Failure { f .failed_test } is present on the revert commit { s .revert_commit_sha } " )
428
+ if check_on_commit (
429
+ s .revert_commit_sha , f .job_name , f .failed_test , job_failures
430
+ ):
431
+ p .print (
432
+ f"Failure { f .failed_test } is present on the revert commit { s .revert_commit_sha } "
433
+ )
393
434
continue
394
- if check_on_commit (s .merge_commit_sha_prev , f .job_name , f .failed_test , job_failures ):
435
+ if check_on_commit (
436
+ s .merge_commit_sha_prev ,
437
+ f .job_name ,
438
+ f .failed_test ,
439
+ job_failures ,
440
+ ):
395
441
p .print (
396
442
f"Failure { f .failed_test } is present on commit before the merge { s .merge_commit_sha_prev } "
397
443
)
@@ -400,36 +446,47 @@ def main() -> None:
400
446
any_bad |= check_failure_in_td_exclusion (f , s .run_id )
401
447
if any_bad :
402
448
caused_by_bad_td .append (s )
403
- p .print (f"Commit { s .last_pr_sha } with run_id { s .run_id } is caused by bad TD" )
404
- p .print (f"CAUSED BY BAD TD: { len (caused_by_bad_td )} / { i + 1 } = { len (caused_by_bad_td ) / (i + 1 ):.2%} " )
405
- p .print (f"Unable to check (lack run id) on PR: { unable_to_check } / { i + 1 } = { unable_to_check / (i + 1 ):.2%} " )
449
+ p .print (
450
+ f"Commit { s .last_pr_sha } with run_id { s .run_id } is caused by bad TD"
451
+ )
452
+ p .print (
453
+ f"CAUSED BY BAD TD: { len (caused_by_bad_td )} / { i + 1 } = { len (caused_by_bad_td ) / (i + 1 ):.2%} "
454
+ )
455
+ p .print (
456
+ f"Unable to check (lack run id) on PR: { unable_to_check } / { i + 1 } = { unable_to_check / (i + 1 ):.2%} "
457
+ )
406
458
407
- p .print (f"Total caused by bad TD: { len (caused_by_bad_td )} / { len (commits_reverted )} = { len (caused_by_bad_td ) / len (commits_reverted ):.2%} " )
459
+ p .print (
460
+ f"Total caused by bad TD: { len (caused_by_bad_td )} / { len (commits_reverted )} = { len (caused_by_bad_td ) / len (commits_reverted ):.2%} "
461
+ )
408
462
# Group by month, this is a massive oversimplification, but we'll take it
409
463
month_groups = {}
410
464
for commit in caused_by_bad_td :
411
465
month = commit .timestamp_of_revert // (30 * 24 * 60 * 60 )
412
466
if month not in month_groups :
413
- month_groups [month ] = (0 ,0 )
467
+ month_groups [month ] = (0 , 0 )
414
468
month_groups [month ] = (month_groups [month ][0 ] + 1 , month_groups [month ][1 ])
415
469
for commit in commits_reverted :
416
470
month = commit .timestamp_of_merge // (30 * 24 * 60 * 60 )
417
471
if month not in month_groups :
418
- month_groups [month ] = (0 ,0 )
472
+ month_groups [month ] = (0 , 0 )
419
473
month_groups [month ] = (month_groups [month ][0 ], month_groups [month ][1 ] + 1 )
420
474
421
475
for month , (bad_td_count , total_count ) in sorted (month_groups .items ()):
422
- p .print (f"Month { month } : { bad_td_count } bad TD / { total_count } total = { bad_td_count / total_count :.2%} " )
476
+ p .print (
477
+ f"Month { month } : { bad_td_count } bad TD / { total_count } total = { bad_td_count / total_count :.2%} "
478
+ )
479
+
423
480
424
481
def parse_args () -> argparse .Namespace :
425
- parser = argparse .ArgumentParser (description = "Get reverts caused by bad TD exclusions." )
482
+ parser = argparse .ArgumentParser (
483
+ description = "Get reverts caused by bad TD exclusions."
484
+ )
426
485
parser .add_argument (
427
- "--num" ,
428
- type = int ,
429
- default = 2000 ,
430
- help = "Number of commits to examine"
486
+ "--num" , type = int , default = 2000 , help = "Number of commits to examine"
431
487
)
432
488
return parser .parse_args ()
433
489
490
+
434
491
if __name__ == "__main__" :
435
492
main ()
0 commit comments