Skip to content

Commit 3c1b952

Browse files
committed
tc
1 parent af35d79 commit 3c1b952

File tree

1 file changed

+63
-4
lines changed

1 file changed

+63
-4
lines changed

tools/torchci/td/get_reverts_caused_by_td.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from concurrent.futures import ThreadPoolExecutor
1313
from dataclasses import dataclass
1414
from functools import lru_cache
15-
from typing import Optional
15+
from typing import Any, Optional
1616

1717
import requests
1818
from torchci.clickhouse import query_clickhouse
@@ -23,6 +23,7 @@
2323
class JobFailure:
2424
torchci_classification_line: str
2525
job_name: str
26+
run_id: int
2627
failed_test: Optional[str] = None
2728

2829

@@ -89,14 +90,15 @@ def __exit__(self, exc_type, exc_val, exc_tb):
8990
TORCHCI_CLASSIFICATION_QUERY = """
9091
select
9192
name as job_name,
93+
run_id as run_id,
9294
torchci_classification.line as line,
9395
head_sha
9496
from
9597
default.workflow_job
9698
where
9799
head_sha in {shas: Array(String)}
98100
and conclusion = 'failure'
99-
and workflow_name = 'pull'
101+
and workflow_name in ('pull', 'trunk', 'periodic', 'slow')
100102
"""
101103

102104
WORKFLOW_ID_QUERY = """
@@ -108,7 +110,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
108110
default .workflow_run
109111
where
110112
head_sha in {shas: Array(String) }
111-
and name = 'pull'
113+
and name in ('pull', 'trunk', 'periodic', 'slow')
112114
"""
113115

114116

@@ -175,6 +177,19 @@ def get_td_exclusions(run_id: int) -> dict:
175177
return {}
176178

177179

180+
@lru_cache
def get_failures_additional_test_info(
    run_id: int,
) -> dict[str, Any]:
    """Fetches additional test info for failures in the given run_id.

    Probes the ``additional_info/reruns/{run_id}/{n}`` objects of the
    ``ossci-raw-job-status`` S3 bucket for n = 1, 2, 3 (in that order;
    presumably n is the run attempt number -- TODO confirm) and returns the
    decoded JSON body of the first one that answers with HTTP 200.

    Args:
        run_id: Workflow run id used to build the object URL.

    Returns:
        The decoded JSON of the first object found, or an empty dict when
        none of the three URLs answers with 200. The result is memoized per
        ``run_id`` by ``lru_cache``, so the same object is handed to every
        caller and should be treated as read-only.

    Raises:
        requests.RequestException: Propagated from ``requests`` when a
            request fails or exceeds the timeout.
    """
    for attempt in range(1, 4):
        url = f"https://ossci-raw-job-status.s3.amazonaws.com/additional_info/reruns/{run_id}/{attempt}"
        # requests has no default timeout; without one a stalled connection
        # would block the calling worker thread forever.
        response = requests.get(url, timeout=60)
        if response.status_code == 200:
            return response.json()
    return {}
191+
192+
178193
def get_test_file(torchci_classification_line: str) -> Optional[str]:
179194
"""Extracts the test file from the torchci classification line."""
180195
match = re.search(FAILED_TEST_REGEX, torchci_classification_line)
@@ -272,7 +287,11 @@ def process_sha(i: int) -> Optional[CommitInfo]:
272287
alt_last_pr_sha = (row["head_sha"], timestamp)
273288
if alt_last_pr_sha[0] != commit.last_pr_sha and commit.last_pr_sha is not None:
274289
p.print(
275-
f"for commit {commit.id} with pr {commit.pr_num}, found last pr sha != alt, {commit.last_pr_sha} != {alt_last_pr_sha[0]}"
290+
f"commit={commit.id} "
291+
f"pr={commit.pr_num} "
292+
f"merge={commit.merge_commit_sha} "
293+
f"timestamp_of_merge={commit.timestamp_of_merge} "
294+
f"found last pr sha != alt, {commit.last_pr_sha} != {alt_last_pr_sha[0]}"
276295
)
277296
bad += 1
278297
if commit.last_pr_sha is None:
@@ -325,19 +344,59 @@ def get_job_failures(shas: list[str]) -> dict[str, list[JobFailure]]:
325344
for row in job_failures:
326345
head_sha = row["head_sha"]
327346
job_name = row["job_name"]
347+
run_id = row["run_id"]
328348
line = row["line"]
329349
if head_sha not in failures_dict:
330350
failures_dict[head_sha] = []
331351
failures_dict[head_sha].append(
332352
JobFailure(
333353
torchci_classification_line=line,
334354
job_name=job_name,
355+
run_id=int(run_id),
335356
failed_test=get_test_file(line),
336357
)
337358
)
359+
del futures
360+
361+
futures2 = []
362+
with ThreadPoolExecutor(max_workers=8) as executor:
363+
for sha, failures in failures_dict.items():
364+
run_ids = set(f.run_id for f in failures if f.run_id is not None)
365+
for run_id in run_ids:
366+
futures2.append((sha, executor.submit(get_failures_for_run_id, run_id)))
367+
for sha, future in futures2:
368+
additional_failures = future.result()
369+
failures_dict[sha].extend(additional_failures)
338370
return failures_dict
339371

340372

373+
@lru_cache
def get_failures_for_run_id(run_id: int) -> list[JobFailure]:
    """Builds JobFailure records from the additional test info of a run.

    Walks the nested mapping returned by get_failures_additional_test_info,
    whose levels are build -> test config -> test file -> test class ->
    test name -> iterable of entries. A test is reported only when every one
    of its entries contains "failure".

    The result is memoized per run_id by lru_cache, so callers receive the
    same list object on repeated calls.
    """
    collected = []
    additional_info = get_failures_additional_test_info(run_id)

    for build_name, by_config in additional_info.items():
        for config_name, by_file in by_config.items():
            for file_name, by_class in by_file.items():
                for class_name, by_test in by_class.items():
                    for case_name, entries in by_test.items():
                        # One entry without "failure" is enough to not count
                        # the test as failed.
                        if not all("failure" in entry for entry in entries):
                            continue
                        collected.append(
                            JobFailure(
                                torchci_classification_line=f"{file_name}::{class_name}::{case_name}",
                                # Job name is synthesized from the build and
                                # test config with fixed placeholder fields.
                                job_name=f"{build_name} / test ({config_name}, 1, 1, runner)",
                                run_id=run_id,
                                failed_test=f"{file_name}",
                            )
                        )

    return collected
398+
399+
341400
def check_failure_in_td_exclusion(f: JobFailure, run_id: int) -> bool:
342401
"""True if the commit is bad (excluded in TD)"""
343402
x = re.search(JOB_NAME_REGEX, f.job_name)

0 commit comments

Comments
 (0)