Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions .github/workflows/bin/spack-labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,45 @@ def import_labels_config(path: str):


def main():
# Validate required environment variables
required_vars = ["LABELS_CONFIG", "GH_REPO", "GH_PR_NUMBER"]
missing_vars = [var for var in required_vars if var not in os.environ]
if missing_vars:
raise Exception(f"Missing required environment variables: {', '.join(missing_vars)}")

labels_config_path = os.environ["LABELS_CONFIG"]
repository = os.environ["GH_REPO"]
pr_number = os.environ["GH_PR_NUMBER"]
token = os.environ["GH_TOKEN"]
token = os.environ.get("GH_TOKEN", "")

headers = {"Accept": "application/vnd.github+json"}
headers = {"Accept": "application/vnd.github+json", "User-Agent": "spack-labeler"}
if token:
headers["Authorization"] = f"Bearer {token}"

label_patterns = import_labels_config(labels_config_path)

# use a requests session to attempt to retry failed requests if the GitHub API fails
session = requests.Session()
retries = requests.adapters.Retry(
total=3, backoff_factor=1, status_forcelist=[500, 502, 503, 504]
)
session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retries))

url = f"https://api.github.com/repos/{repository}/pulls/{pr_number}"
pull_request = requests.get(url, headers=headers).json()
pull_request_resp = session.get(url, headers=headers, timeout=30)
if pull_request_resp.status_code != 200:
raise Exception(
f"Failed to query GitHub API for PR info [{pull_request_resp.status_code}]: "
f"{pull_request_resp.text}"
)

pull_request_files = requests.get(url + "/files", headers=headers)
pull_request = pull_request_resp.json()
pull_request_files = session.get(url + "/files", headers=headers, timeout=30)
if pull_request_files.status_code != 200:
raise Exception(
f"Failed to query GitHub API for PR files [{pull_request_files.status_code}]: "
f"{pull_request_files.text}"
)

labels = set()

Expand All @@ -57,7 +81,7 @@ def main():
# the corresponding labels.
for label, pattern_dict in label_patterns.items():
attr_matches = []
# Pattern matches for for each attribute are or'd together
# Pattern matches for each attribute are or'd together
for attr, patterns in pattern_dict.items():
# 'patch' is an example of an attribute that is not required to
# appear in response when listing pull request files. See here:
Expand All @@ -78,6 +102,9 @@ def main():
# Maintain non-managed labels (i.e. those that are not in label_patterns)
labels.update(existing_labels.difference(label_patterns))

if not token:
print("Warning: No GH_TOKEN defined, performing a local dry run only")

if existing_labels == labels:
print(f"[PR #{pr_number}]: labels already up-to-date")
return
Expand All @@ -91,10 +118,11 @@ def main():
print(f"[PR #{pr_number}]: Removing label(s): [{'] ['.join(removed_labels)}]")

if token:
resp = requests.put(
resp = session.put(
f"https://api.github.com/repos/{repository}/issues/{pr_number}/labels",
json={"labels": list(labels)},
headers=headers,
timeout=30,
)
resp.raise_for_status()

Expand Down
Loading