diff --git a/.github/workflows/bin/spack-labeler.py b/.github/workflows/bin/spack-labeler.py index b68124c3085..209fb0d1293 100755 --- a/.github/workflows/bin/spack-labeler.py +++ b/.github/workflows/bin/spack-labeler.py @@ -28,21 +28,45 @@ def import_labels_config(path: str): def main(): + # Validate required environment variables + required_vars = ["LABELS_CONFIG", "GH_REPO", "GH_PR_NUMBER"] + missing_vars = [var for var in required_vars if var not in os.environ] + if missing_vars: + raise Exception(f"Missing required environment variables: {', '.join(missing_vars)}") + labels_config_path = os.environ["LABELS_CONFIG"] repository = os.environ["GH_REPO"] pr_number = os.environ["GH_PR_NUMBER"] - token = os.environ["GH_TOKEN"] + token = os.environ.get("GH_TOKEN", "") - headers = {"Accept": "application/vnd.github+json"} + headers = {"Accept": "application/vnd.github+json", "User-Agent": "spack-labeler"} if token: headers["Authorization"] = f"Bearer {token}" label_patterns = import_labels_config(labels_config_path) + # use a requests session to attempt to retry failed requests if the GitHub API fails + session = requests.Session() + retries = requests.adapters.Retry( + total=3, backoff_factor=1, status_forcelist=[500, 502, 503, 504] + ) + session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retries)) + url = f"https://api.github.com/repos/{repository}/pulls/{pr_number}" - pull_request = requests.get(url, headers=headers).json() + pull_request_resp = session.get(url, headers=headers, timeout=30) + if pull_request_resp.status_code != 200: + raise Exception( + f"Failed to query GitHub API for PR info [{pull_request_resp.status_code}]: " + f"{pull_request_resp.text}" + ) - pull_request_files = requests.get(url + "/files", headers=headers) + pull_request = pull_request_resp.json() + pull_request_files = session.get(url + "/files", headers=headers, timeout=30) + if pull_request_files.status_code != 200: + raise Exception( + f"Failed to query GitHub API for PR files [{pull_request_files.status_code}]: " + f"{pull_request_files.text}" + ) labels = set() @@ -57,7 +81,7 @@ def main(): # the corresponding labels. for label, pattern_dict in label_patterns.items(): attr_matches = [] - # Pattern matches for for each attribute are or'd together + # Pattern matches for each attribute are or'd together for attr, patterns in pattern_dict.items(): # 'patch' is an example of an attribute that is not required to # appear in response when listing pull request files. See here: @@ -78,6 +102,9 @@ def main(): # Maintain non-managed labels (i.e. those that are not in label_patterns) labels.update(existing_labels.difference(label_patterns)) + if not token: + print("Warning: No GH_TOKEN defined, performing a local dry run only") + if existing_labels == labels: print(f"[PR #{pr_number}]: labels already up-to-date") return @@ -91,10 +118,11 @@ def main(): print(f"[PR #{pr_number}]: Removing label(s): [{'] ['.join(removed_labels)}]") if token: - resp = requests.put( + resp = session.put( f"https://api.github.com/repos/{repository}/issues/{pr_number}/labels", json={"labels": list(labels)}, headers=headers, + timeout=30, ) resp.raise_for_status()