12
12
from concurrent .futures import ThreadPoolExecutor
13
13
from dataclasses import dataclass
14
14
from functools import lru_cache
15
- from typing import Optional
15
+ from typing import Any , Optional
16
16
17
17
import requests
18
18
from torchci .clickhouse import query_clickhouse
23
23
class JobFailure:
    """A single failing CI job (or individual failing test) seen on a commit."""

    # Raw torchci failure-classification line for the job.
    torchci_classification_line: str
    # Full workflow job name as reported by the workflow_job table.
    job_name: str
    # GitHub Actions workflow run id this failure belongs to.
    run_id: int
    # Test file parsed out of the classification line, when one could be extracted.
    failed_test: Optional[str] = None
27
28
28
29
@@ -89,14 +90,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
89
90
TORCHCI_CLASSIFICATION_QUERY = """
90
91
select
91
92
name as job_name,
93
+ run_id as run_id,
92
94
torchci_classification.line as line,
93
95
head_sha
94
96
from
95
97
default.workflow_job
96
98
where
97
99
head_sha in {shas: Array(String)}
98
100
and conclusion = 'failure'
99
- and workflow_name = 'pull'
101
+ and workflow_name in ( 'pull', 'trunk', 'periodic', 'slow')
100
102
"""
101
103
102
104
WORKFLOW_ID_QUERY = """
@@ -108,7 +110,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
108
110
default .workflow_run
109
111
where
110
112
head_sha in {shas: Array(String) }
111
- and name = 'pull'
113
+ and name in ( 'pull', 'trunk', 'periodic', 'slow')
112
114
"""
113
115
114
116
@@ -175,6 +177,19 @@ def get_td_exclusions(run_id: int) -> dict:
175
177
return {}
176
178
177
179
180
@lru_cache
def get_failures_additional_test_info(run_id: int) -> dict[str, Any]:
    """Fetch the uploaded additional rerun/test info for a workflow run.

    Probes the attempt-numbered S3 keys 1..3 in order and returns the JSON
    body of the first one that exists; returns an empty dict when none do.
    """
    base = "https://ossci-raw-job-status.s3.amazonaws.com/additional_info/reruns"
    for attempt in (1, 2, 3):
        response = requests.get(f"{base}/{run_id}/{attempt}")
        if response.status_code == 200:
            return response.json()
    return {}
191
+
192
+
178
193
def get_test_file (torchci_classification_line : str ) -> Optional [str ]:
179
194
"""Extracts the test file from the torchci classification line."""
180
195
match = re .search (FAILED_TEST_REGEX , torchci_classification_line )
@@ -272,7 +287,11 @@ def process_sha(i: int) -> Optional[CommitInfo]:
272
287
alt_last_pr_sha = (row ["head_sha" ], timestamp )
273
288
if alt_last_pr_sha [0 ] != commit .last_pr_sha and commit .last_pr_sha is not None :
274
289
p .print (
275
- f"for commit { commit .id } with pr { commit .pr_num } , found last pr sha != alt, { commit .last_pr_sha } != { alt_last_pr_sha [0 ]} "
290
+ f"commit={ commit .id } "
291
+ f"pr={ commit .pr_num } "
292
+ f"merge={ commit .merge_commit_sha } "
293
+ f"timestamp_of_merge={ commit .timestamp_of_merge } "
294
+ f"found last pr sha != alt, { commit .last_pr_sha } != { alt_last_pr_sha [0 ]} "
276
295
)
277
296
bad += 1
278
297
if commit .last_pr_sha is None :
@@ -325,19 +344,59 @@ def get_job_failures(shas: list[str]) -> dict[str, list[JobFailure]]:
325
344
for row in job_failures :
326
345
head_sha = row ["head_sha" ]
327
346
job_name = row ["job_name" ]
347
+ run_id = row ["run_id" ]
328
348
line = row ["line" ]
329
349
if head_sha not in failures_dict :
330
350
failures_dict [head_sha ] = []
331
351
failures_dict [head_sha ].append (
332
352
JobFailure (
333
353
torchci_classification_line = line ,
334
354
job_name = job_name ,
355
+ run_id = int (run_id ),
335
356
failed_test = get_test_file (line ),
336
357
)
337
358
)
359
+ del futures
360
+
361
+ futures2 = []
362
+ with ThreadPoolExecutor (max_workers = 8 ) as executor :
363
+ for sha , failures in failures_dict .items ():
364
+ run_ids = set (f .run_id for f in failures if f .run_id is not None )
365
+ for run_id in run_ids :
366
+ futures2 .append ((sha , executor .submit (get_failures_for_run_id , run_id )))
367
+ for sha , future in futures2 :
368
+ additional_failures = future .result ()
369
+ failures_dict [sha ].extend (additional_failures )
338
370
return failures_dict
339
371
340
372
373
@lru_cache
def get_failures_for_run_id(run_id: int) -> list[JobFailure]:
    """Build JobFailure entries from the rerun info uploaded for a run.

    Walks the nested build -> test config -> test file -> test class ->
    test name mapping and keeps a test only when every recorded rerun
    attempt contains a "failure" entry (i.e. it never passed on rerun).
    """
    info_by_build = get_failures_additional_test_info(run_id)
    results: list[JobFailure] = []
    for build, configs in info_by_build.items():
        for test_config, files in configs.items():
            for test_file, classes in files.items():
                for test_class, tests in classes.items():
                    for test_name, attempts in tests.items():
                        # Skip unless every attempt has a "failure" record
                        # (an empty attempt list also counts as failed,
                        # matching the original flag-based check).
                        if not all("failure" in attempt for attempt in attempts):
                            continue
                        results.append(
                            JobFailure(
                                torchci_classification_line=f"{test_file}::{test_class}::{test_name}",
                                job_name=f"{build} / test ({test_config}, 1, 1, runner)",
                                run_id=run_id,
                                failed_test=f"{test_file}",
                            )
                        )
    return results
398
+
399
+
341
400
def check_failure_in_td_exclusion (f : JobFailure , run_id : int ) -> bool :
342
401
"""True if the commit is bad (excluded in TD)"""
343
402
x = re .search (JOB_NAME_REGEX , f .job_name )
0 commit comments