diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py index 553914e8d9..a3cea633af 100644 --- a/pr_agent/algo/git_patch_processing.py +++ b/pr_agent/algo/git_patch_processing.py @@ -9,12 +9,17 @@ # Optimized: Pre-compile the hunk header regex at the module level to avoid redundant compilation # in performance-critical patch processing functions. -RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") +RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") -def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, - patch_extra_lines_after=0, filename: str = "", new_file_str="") -> str: +def extend_patch( + original_file_str, + patch_str, + patch_extra_lines_before=0, + patch_extra_lines_after=0, + filename: str = "", + new_file_str="", +) -> str: if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str: return patch_str @@ -27,8 +32,9 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, return patch_str try: - extended_patch_str = process_patch_lines(patch_str, original_file_str, - patch_extra_lines_before, patch_extra_lines_after, new_file_str) + extended_patch_str = process_patch_lines( + patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after, new_file_str + ) except Exception as e: get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()}) return patch_str @@ -39,9 +45,9 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, def decode_if_bytes(original_file_str): if isinstance(original_file_str, (bytes, bytearray)): try: - return original_file_str.decode('utf-8') + return original_file_str.decode("utf-8") except UnicodeDecodeError: - encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16'] + encodings_to_try = ["iso-8859-1", "latin-1", "ascii", "utf-16"] for encoding in 
encodings_to_try: try: return original_file_str.decode(encoding) @@ -58,7 +64,9 @@ def should_skip_patch(filename): return False -def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after, new_file_str=""): +def process_patch_lines( + patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after, new_file_str="" +): allow_dynamic_context = get_settings().config.allow_dynamic_context patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context @@ -71,14 +79,19 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, is_valid_hunk = True start1, size1, start2, size2 = -1, -1, -1, -1 try: - for i,line in enumerate(patch_lines): - if line.startswith('@@'): + for i, line in enumerate(patch_lines): + if line.startswith("@@"): match = RE_HUNK_HEADER.match(line) # identify hunk header if match: # finish processing previous hunk if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0): - delta_lines_original = [f' {line}' for line in file_original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]] + delta_lines_original = [ + f" {line}" + for line in file_original_lines[ + start1 + size1 - 1 : start1 + size1 - 1 + patch_extra_lines_after + ] + ] extended_patch_lines.extend(delta_lines_original) section_header, size1, size2, start1, start2 = extract_hunk_headers(match) @@ -86,6 +99,7 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, is_valid_hunk = check_if_hunk_lines_matches_to_file(i, file_original_lines, patch_lines, start1) if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0): + def _calc_context_limits(patch_lines_before): extended_start1 = max(1, start1 - patch_lines_before) extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after @@ -99,11 +113,12 @@ def _calc_context_limits(patch_lines_before): return extended_start1, 
extended_size1, extended_start2, extended_size2 if allow_dynamic_context and file_new_lines: - extended_start1, extended_size1, extended_start2, extended_size2 = \ - _calc_context_limits(patch_extra_lines_before_dynamic) + extended_start1, extended_size1, extended_start2, extended_size2 = _calc_context_limits( + patch_extra_lines_before_dynamic + ) - lines_before_original = file_original_lines[extended_start1 - 1:start1 - 1] - lines_before_new = file_new_lines[extended_start2 - 1:start2 - 1] + lines_before_original = file_original_lines[extended_start1 - 1 : start1 - 1] + lines_before_new = file_new_lines[extended_start2 - 1 : start2 - 1] found_header = False for i, line in enumerate(lines_before_original): if section_header in line: @@ -115,23 +130,27 @@ def _calc_context_limits(patch_lines_before): if lines_before_original_dynamic_context == lines_before_new_dynamic_context: # get_logger().debug(f"found dynamic context match for section header: {section_header}") found_header = True - section_header = '' + section_header = "" else: pass # its ok to be here. 
We can't apply dynamic context if the lines are different if 'old' and 'new' hunks break if not found_header: # get_logger().debug(f"Section header not found in the extra lines before the hunk") - extended_start1, extended_size1, extended_start2, extended_size2 = \ - _calc_context_limits(patch_extra_lines_before) + extended_start1, extended_size1, extended_start2, extended_size2 = _calc_context_limits( + patch_extra_lines_before + ) else: - extended_start1, extended_size1, extended_start2, extended_size2 = \ - _calc_context_limits(patch_extra_lines_before) + extended_start1, extended_size1, extended_start2, extended_size2 = _calc_context_limits( + patch_extra_lines_before + ) # check if extra lines before hunk are different in original and new file - delta_lines_original = [f' {line}' for line in file_original_lines[extended_start1 - 1:start1 - 1]] + delta_lines_original = [ + f" {line}" for line in file_original_lines[extended_start1 - 1 : start1 - 1] + ] if file_new_lines: - delta_lines_new = [f' {line}' for line in file_new_lines[extended_start2 - 1:start2 - 1]] + delta_lines_new = [f" {line}" for line in file_new_lines[extended_start2 - 1 : start2 - 1]] if delta_lines_original != delta_lines_new: found_mini_match = False for i in range(len(delta_lines_original)): @@ -158,7 +177,7 @@ def _calc_context_limits(patch_lines_before): if section_header and not allow_dynamic_context: for line in delta_lines_original: if section_header in line: - section_header = '' # remove section header if it is in the extra delta lines + section_header = "" # remove section header if it is in the extra delta lines break else: extended_start1 = start1 @@ -166,10 +185,11 @@ def _calc_context_limits(patch_lines_before): extended_start2 = start2 extended_size2 = size2 delta_lines_original = [] - extended_patch_lines.append('') + extended_patch_lines.append("") extended_patch_lines.append( - f'@@ -{extended_start1},{extended_size1} ' - f'+{extended_start2},{extended_size2} @@ 
{section_header}') + f"@@ -{extended_start1},{extended_size1} " + f"+{extended_start2},{extended_size2} @@ {section_header}" + ) extended_patch_lines.extend(delta_lines_original) # one to zero based continue extended_patch_lines.append(line) @@ -179,14 +199,15 @@ def _calc_context_limits(patch_lines_before): # finish processing last hunk if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk: - delta_lines_original = file_original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after] + delta_lines_original = file_original_lines[start1 + size1 - 1 : start1 + size1 - 1 + patch_extra_lines_after] # add space at the beginning of each extra line - delta_lines_original = [f' {line}' for line in delta_lines_original] + delta_lines_original = [f" {line}" for line in delta_lines_original] extended_patch_lines.extend(delta_lines_original) - extended_patch_str = '\n'.join(extended_patch_lines) + extended_patch_str = "\n".join(extended_patch_lines) return extended_patch_str + def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1): """ Check if the hunk lines match the original file content. 
We saw cases where the hunk header line doesn't match the original file content, and then @@ -194,21 +215,24 @@ def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1): """ is_valid_hunk = True try: - if i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' ': # an existing line in the file + if i + 1 < len(patch_lines) and patch_lines[i + 1][0] == " ": # an existing line in the file if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip(): # check if different encoding is needed original_line = original_lines[start1 - 1].strip() - for encoding in ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']: + for encoding in ["iso-8859-1", "latin-1", "ascii", "utf-16"]: try: if original_line.encode(encoding).decode().strip() == patch_lines[i + 1].strip(): - get_logger().info(f"Detected different encoding in hunk header line {start1}, needed encoding: {encoding}") - return False # we still want to avoid extending the hunk. But we don't want to log an error + get_logger().info( + f"Detected different encoding in hunk header line {start1}, needed encoding: {encoding}" + ) + return False # we still want to avoid extending the hunk. 
But we don't want to log an error except: pass is_valid_hunk = False get_logger().info( - f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content") + f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content" + ) except: pass return is_valid_hunk @@ -243,7 +267,7 @@ def omit_deletion_hunks(patch_lines) -> str: inside_hunk = False for line in patch_lines: - if line.startswith('@@'): + if line.startswith("@@"): match = RE_HUNK_HEADER.match(line) if match: # finish previous hunk @@ -257,16 +281,21 @@ def omit_deletion_hunks(patch_lines) -> str: temp_hunk.append(line) if line: edit_type = line[0] - if edit_type == '+': + if edit_type == "+": add_hunk = True if inside_hunk and add_hunk: added_patched.extend(temp_hunk) - return '\n'.join(added_patched) + return "\n".join(added_patched) -def handle_patch_deletions(patch: str, original_file_content_str: str, - new_file_content_str: str, file_name: str, edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN) -> str: +def handle_patch_deletions( + patch: str, + original_file_content_str: str, + new_file_content_str: str, + file_name: str, + edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN, +) -> str: """ Handle entire file or deletion patches. @@ -284,10 +313,13 @@ def handle_patch_deletions(patch: str, original_file_content_str: str, """ if not new_file_content_str and (edit_type == EDIT_TYPE.DELETED or edit_type == EDIT_TYPE.UNKNOWN): - # logic for handling deleted files - don't show patch, just show that the file was deleted - if get_settings().config.verbosity_level > 0: - get_logger().info(f"Processing file: {file_name}, minimizing deletion file") - patch = None # file was deleted + # When minimize_api_calls is active, empty content doesn't mean deleted — + # content was intentionally skipped. Only treat as deleted when edit_type + # explicitly says so, or when content was actually fetched (flag is off). 
+ if edit_type == EDIT_TYPE.DELETED or not get_settings().get("github.minimize_api_calls", False): + if get_settings().config.verbosity_level > 0: + get_logger().info(f"Processing file: {file_name}, minimizing deletion file") + patch = None # file was deleted else: patch_lines = patch.splitlines() patch_new = omit_deletion_hunks(patch_lines) @@ -300,41 +332,41 @@ def handle_patch_deletions(patch: str, original_file_content_str: str, def decouple_and_convert_to_hunks_with_lines_numbers(patch: str, file) -> str: """ - Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of - the file. - - Args: - patch (str): The patch string to be converted. - file: An object containing the filename of the file being patched. - - Returns: - str: A string with line numbers for each hunk, indicating the new and old content of the file. - - example output: -## src/file.ts -__new hunk__ -881 line1 -882 line2 -883 line3 -887 + line4 -888 + line5 -889 line6 -890 line7 -... -__old hunk__ - line1 - line2 -- line3 -- line4 - line5 - line6 - ... + Convert a given patch string into a string with line numbers for each hunk, indicating the new and old content of + the file. + + Args: + patch (str): The patch string to be converted. + file: An object containing the filename of the file being patched. + + Returns: + str: A string with line numbers for each hunk, indicating the new and old content of the file. + + example output: + ## src/file.ts + __new hunk__ + 881 line1 + 882 line2 + 883 line3 + 887 + line4 + 888 + line5 + 889 line6 + 890 line7 + ... + __old hunk__ + line1 + line2 + - line3 + - line4 + line5 + line6 + ... 
""" # Add a header for the file if file: # if the file was deleted, return a message indicating that the file was deleted - if hasattr(file, 'edit_type') and file.edit_type == EDIT_TYPE.DELETED: + if hasattr(file, "edit_type") and file.edit_type == EDIT_TYPE.DELETED: return f"\n\n## File '{file.filename.strip()}' was deleted\n" patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n" @@ -349,26 +381,28 @@ def decouple_and_convert_to_hunks_with_lines_numbers(patch: str, file) -> str: prev_header_line = [] header_line = [] for line_i, line in enumerate(patch_lines): - if 'no newline at end of file' in line.lower(): + if "no newline at end of file" in line.lower(): continue - if line.startswith('@@'): + if line.startswith("@@"): header_line = line match = RE_HUNK_HEADER.match(line) if match and (new_content_lines or old_content_lines): # found a new hunk, split the previous lines if prev_header_line: - patch_with_lines_str += f'\n{prev_header_line}\n' + patch_with_lines_str += f"\n{prev_header_line}\n" is_plus_lines = is_minus_lines = False if new_content_lines: - is_plus_lines = any([line.startswith('+') for line in new_content_lines]) + is_plus_lines = any([line.startswith("+") for line in new_content_lines]) if old_content_lines: - is_minus_lines = any([line.startswith('-') for line in old_content_lines]) - if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused - patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n' + is_minus_lines = any([line.startswith("-") for line in old_content_lines]) + if ( + is_plus_lines or is_minus_lines + ): # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused + patch_with_lines_str = patch_with_lines_str.rstrip() + "\n__new hunk__\n" for i, line_new in enumerate(new_content_lines): patch_with_lines_str += f"{start2 + i} {line_new}\n" if is_minus_lines: - patch_with_lines_str = 
patch_with_lines_str.rstrip() + '\n__old hunk__\n' + patch_with_lines_str = patch_with_lines_str.rstrip() + "\n__old hunk__\n" for line_old in old_content_lines: patch_with_lines_str += f"{line_old}\n" new_content_lines = [] @@ -378,13 +412,13 @@ def decouple_and_convert_to_hunks_with_lines_numbers(patch: str, file) -> str: section_header, size1, size2, start1, start2 = extract_hunk_headers(match) - elif line.startswith('+'): + elif line.startswith("+"): new_content_lines.append(line) - elif line.startswith('-'): + elif line.startswith("-"): old_content_lines.append(line) else: - if not line and line_i: # if this line is empty and the next line is a hunk header, skip it - if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'): + if not line and line_i: # if this line is empty and the next line is a hunk header, skip it + if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith("@@"): continue elif line_i + 1 == len(patch_lines): continue @@ -393,25 +427,29 @@ def decouple_and_convert_to_hunks_with_lines_numbers(patch: str, file) -> str: # finishing last hunk if match and new_content_lines: - patch_with_lines_str += f'\n{header_line}\n' + patch_with_lines_str += f"\n{header_line}\n" is_plus_lines = is_minus_lines = False if new_content_lines: - is_plus_lines = any([line.startswith('+') for line in new_content_lines]) + is_plus_lines = any([line.startswith("+") for line in new_content_lines]) if old_content_lines: - is_minus_lines = any([line.startswith('-') for line in old_content_lines]) - if is_plus_lines or is_minus_lines: # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused - patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__new hunk__\n' + is_minus_lines = any([line.startswith("-") for line in old_content_lines]) + if ( + is_plus_lines or is_minus_lines + ): # notice 'True' here - we always present __new hunk__ for section, otherwise LLM gets confused + 
patch_with_lines_str = patch_with_lines_str.rstrip() + "\n__new hunk__\n" for i, line_new in enumerate(new_content_lines): patch_with_lines_str += f"{start2 + i} {line_new}\n" if is_minus_lines: - patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n' + patch_with_lines_str = patch_with_lines_str.rstrip() + "\n__old hunk__\n" for line_old in old_content_lines: patch_with_lines_str += f"{line_old}\n" return patch_with_lines_str.rstrip() -def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side, remove_trailing_chars: bool = True) -> tuple[str, str]: +def extract_hunk_lines_from_patch( + patch: str, file_name, line_start, line_end, side, remove_trailing_chars: bool = True +) -> tuple[str, str]: try: patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n" selected_lines = "" @@ -421,10 +459,10 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s skip_hunk = False selected_lines_num = 0 for line in patch_lines: - if 'no newline at end of file' in line.lower(): + if "no newline at end of file" in line.lower(): continue - if line.startswith('@@'): + if line.startswith("@@"): skip_hunk = False selected_lines_num = 0 header_line = line @@ -434,27 +472,29 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s section_header, size1, size2, start1, start2 = extract_hunk_headers(match) # check if line range is in this hunk - if side.lower() == 'left': + if side.lower() == "left": # check if line range is in this hunk if not (start1 <= line_start <= start1 + size1): skip_hunk = True continue - elif side.lower() == 'right': + elif side.lower() == "right": if not (start2 <= line_start <= start2 + size2): skip_hunk = True continue - patch_with_lines_str += f'\n{header_line}\n' + patch_with_lines_str += f"\n{header_line}\n" elif not skip_hunk: - if side.lower() == 'right' and line_start <= start2 + selected_lines_num <= line_end: - selected_lines += line + '\n' - if 
side.lower() == 'left' and start1 <= selected_lines_num + start1 <= line_end: - selected_lines += line + '\n' - patch_with_lines_str += line + '\n' - if not line.startswith('-'): # currently we don't support /ask line for deleted lines + if side.lower() == "right" and line_start <= start2 + selected_lines_num <= line_end: + selected_lines += line + "\n" + if side.lower() == "left" and start1 <= selected_lines_num + start1 <= line_end: + selected_lines += line + "\n" + patch_with_lines_str += line + "\n" + if not line.startswith("-"): # currently we don't support /ask line for deleted lines selected_lines_num += 1 except Exception as e: - get_logger().error(f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()}) + get_logger().error( + f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()} + ) return "", "" if remove_trailing_chars: diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index fa52b7dc05..13d164d94f 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -2,16 +2,16 @@ import difflib import hashlib import itertools +import json import re import time import traceback -import json from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse -from github.Issue import Issue from github import AppAuthentication, Auth, Github, GithubException +from github.Issue import Issue from retry import retry from starlette_context import context @@ -19,14 +19,18 @@ from ..algo.git_patch_processing import extract_hunk_headers from ..algo.language_handler import is_valid_file from ..algo.types import EDIT_TYPE -from ..algo.utils import (PRReviewHeader, Range, clip_tokens, - find_line_number_of_relevant_line_in_file, - load_large_diff, set_file_languages) +from ..algo.utils import ( + PRReviewHeader, + Range, + clip_tokens, + 
find_line_number_of_relevant_line_in_file, + load_large_diff, + set_file_languages, +) from ..config_loader import get_settings from ..log import get_logger from ..servers.utils import RateLimitExceeded -from .git_provider import (MAX_FILES_ALLOWED_FULL, FilePatchInfo, GitProvider, - IncrementalPR) +from .git_provider import MAX_FILES_ALLOWED_FULL, FilePatchInfo, GitProvider, IncrementalPR class GithubProvider(GitProvider): @@ -37,8 +41,12 @@ def __init__(self, pr_url: Optional[str] = None): except Exception: self.installation_id = None self.max_comment_chars = 65000 - self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com" - self.base_url_html = self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com" + self.base_url = ( + get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") + ) # "https://api.github.com" + self.base_url_html = ( + self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com" + ) self.github_client = self._get_github_client() self.repo = None self.pr_num = None @@ -47,15 +55,18 @@ def __init__(self, pr_url: Optional[str] = None): self.github_user_id = None self.diff_files = None self.git_files = None + self._languages = None self.incremental = IncrementalPR(False) - if pr_url and 'pull' in pr_url: + if pr_url and "pull" in pr_url: self.set_pr(pr_url) self.pr_commits = list(self.pr.get_commits()) self.last_commit_id = self.pr_commits[-1] - self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object - elif pr_url and 'issue' in pr_url: #url is an issue + self.pr_url = ( + self.get_pr_url() + ) # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object + elif pr_url and "issue" in pr_url: # url is an issue self.issue_main = self._get_issue_handle(pr_url) - else: #Instantiated the 
provider without a PR / Issue + else: # Instantiated the provider without a PR / Issue self.pr_commits = None def _get_issue_handle(self, issue_url) -> Optional[Issue]: @@ -67,13 +78,17 @@ def _get_issue_handle(self, issue_url) -> Optional[Issue]: try: repo_obj = self.github_client.get_repo(repo_name) if not repo_obj: - get_logger().error(f"Given url: {issue_url}, belonging to owner/repo: {repo_name} does " - f"not have a valid repository: {self.get_git_repo_url(issue_url)}") + get_logger().error( + f"Given url: {issue_url}, belonging to owner/repo: {repo_name} does " + f"not have a valid repository: {self.get_git_repo_url(issue_url)}" + ) return None # else: Valid repo handle: return repo_obj.get_issue(issue_number) except Exception as e: - get_logger().exception(f"Failed to get an issue object for issue: {issue_url}, belonging to owner/repo: {repo_name}") + get_logger().exception( + f"Failed to get an issue object for issue: {issue_url}, belonging to owner/repo: {repo_name}" + ) return None def get_incremental_commits(self, incremental=IncrementalPR(False)): @@ -88,15 +103,17 @@ def is_supported(self, capability: str) -> bool: def _get_owner_and_repo_path(self, given_url: str) -> str: try: repo_path = None - if 'issues' in given_url: + if "issues" in given_url: repo_path, _ = self._parse_issue_url(given_url) - elif 'pull' in given_url: + elif "pull" in given_url: repo_path, _ = self._parse_pr_url(given_url) - elif given_url.endswith('.git'): + elif given_url.endswith(".git"): parsed_url = urlparse(given_url) - repo_path = (parsed_url.path.split('.git')[0])[1:] # //.git -> / + repo_path = (parsed_url.path.split(".git")[0])[1:] # //.git -> / if not repo_path: - get_logger().error(f"url is neither an issues url nor a PR url nor a valid git url: {given_url}. Returning empty result.") + get_logger().error( + f"url is neither an issues url nor a PR url nor a valid git url: {given_url}. Returning empty result." 
+ ) return "" return repo_path except Exception as e: @@ -104,37 +121,43 @@ def _get_owner_and_repo_path(self, given_url: str) -> str: return "" def get_git_repo_url(self, issues_or_pr_url: str) -> str: - repo_path = self._get_owner_and_repo_path(issues_or_pr_url) #Return: / + repo_path = self._get_owner_and_repo_path(issues_or_pr_url) # Return: / if not repo_path or repo_path not in issues_or_pr_url: get_logger().error(f"Unable to retrieve owner/path from url: {issues_or_pr_url}") return "" - return f"{self.base_url_html}/{repo_path}.git" #https://github.com / /.git + return f"{self.base_url_html}/{repo_path}.git" # https://github.com / /.git # Given a git repo url, return prefix and suffix of the provider in order to view a given file belonging to that repo. # Example: https://github.com/qodo-ai/pr-agent.git and branch: v0.8 -> prefix: "https://github.com/qodo-ai/pr-agent/blob/v0.8", suffix: "" # In case git url is not provided, provider will use PR context (which includes branch) to determine the prefix and suffix. 
- def get_canonical_url_parts(self, repo_git_url:str, desired_branch:str) -> Tuple[str, str]: + def get_canonical_url_parts(self, repo_git_url: str, desired_branch: str) -> Tuple[str, str]: owner = None repo = None scheme_and_netloc = None - if repo_git_url or self.issue_main: #Either user provided an external git url, which may be different than what this provider was initialized with, or an issue: + if ( + repo_git_url or self.issue_main + ): # Either user provided an external git url, which may be different than what this provider was initialized with, or an issue: desired_branch = desired_branch if repo_git_url else self.issue_main.repository.default_branch html_url = repo_git_url if repo_git_url else self.issue_main.html_url parsed_git_url = urlparse(html_url) scheme_and_netloc = parsed_git_url.scheme + "://" + parsed_git_url.netloc repo_path = self._get_owner_and_repo_path(html_url) - if repo_path.count('/') == 1: #Has to have the form / - owner, repo = repo_path.split('/') + if repo_path.count("/") == 1: # Has to have the form / + owner, repo = repo_path.split("/") else: get_logger().error(f"Invalid repo_path: {repo_path} from url: {html_url}") return ("", "") - if (not owner or not repo) and self.repo: #"else" - User did not provide an external git url, or not an issue, use self.repo object - owner, repo = self.repo.split('/') + if ( + not owner or not repo + ) and self.repo: # "else" - User did not provide an external git url, or not an issue, use self.repo object + owner, repo = self.repo.split("/") scheme_and_netloc = self.base_url_html desired_branch = self.repo_obj.default_branch - if not all([scheme_and_netloc, owner, repo]): #"else": Not invoked from a PR context,but no provided git url for context + if not all( + [scheme_and_netloc, owner, repo] + ): # "else": Not invoked from a PR context,but no provided git url for context get_logger().error(f"Unable to get canonical url parts since missing context (PR or explicit git url)") return ("", "") @@ 
-200,7 +223,7 @@ def get_files(self): git_files = context.get("git_files", None) if git_files: return git_files - self.git_files = list(self.pr.get_files()) # 'list' to handle pagination + self.git_files = list(self.pr.get_files()) # 'list' to handle pagination context["git_files"] = self.git_files return self.git_files except Exception: @@ -217,8 +240,9 @@ def get_num_of_files(self): except Exception as e: return -1 - @retry(exceptions=RateLimitExceeded, - tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) + @retry( + exceptions=RateLimitExceeded, tries=get_settings().github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3) + ) def get_diff_files(self) -> list[FilePatchInfo]: """ Retrieves the list of files that have been modified, added, deleted, or renamed in a pull request in GitHub, @@ -246,9 +270,10 @@ def get_diff_files(self) -> list[FilePatchInfo]: try: names_original = [file.filename for file in files_original] names_new = [file.filename for file in files] - get_logger().info(f"Filtered out [ignore] files for pull request:", extra= - {"files": names_original, - "filtered_files": names_new}) + get_logger().info( + f"Filtered out [ignore] files for pull request:", + extra={"files": names_original, "filtered_files": names_new}, + ) except Exception: pass @@ -263,14 +288,19 @@ def get_diff_files(self) -> list[FilePatchInfo]: repo = self.repo_obj pr = self.pr try: - compare = repo.compare(pr.base.sha, pr.head.sha) # communication with GitHub + compare = repo.compare(pr.base.sha, pr.head.sha) # communication with GitHub merge_base_commit = compare.merge_base_commit except Exception as e: get_logger().error(f"Failed to get merge base commit: {e}") merge_base_commit = pr.base if merge_base_commit.sha != pr.base.sha: - get_logger().info( - f"Using merge base commit {merge_base_commit.sha} instead of base commit ") + get_logger().info(f"Using merge base commit {merge_base_commit.sha} instead of base commit ") + + # When 
minimize_api_calls is enabled and not in incremental mode, + # skip file content fetches for files that already have a patch from + # the PR files endpoint. This eliminates up to 2N API calls (N = files). + minimize = get_settings().get("github.minimize_api_calls", False) + skip_content = minimize and not self.incremental.is_incremental counter_valid = 0 for file in files: @@ -282,6 +312,12 @@ def get_diff_files(self) -> list[FilePatchInfo]: if is_close_to_rate_limit: new_file_content_str = "" original_file_content_str = "" + elif skip_content and patch: + # Patch exists from PR files endpoint — no content fetch needed. + # extend_patch() is bypassed via disable_extra_lines in get_pr_diff(). + # handle_patch_deletions() uses edit_type instead of empty content check. + new_file_content_str = "" + original_file_content_str = "" else: # allow only a limited number of files to be fully loaded. We can manage the rest with diffs only counter_valid += 1 @@ -289,15 +325,21 @@ def get_diff_files(self) -> list[FilePatchInfo]: if counter_valid >= MAX_FILES_ALLOWED_FULL and patch and not self.incremental.is_incremental: avoid_load = True if counter_valid == MAX_FILES_ALLOWED_FULL: - get_logger().info(f"Too many files in PR, will avoid loading full content for rest of files") + get_logger().info( + f"Too many files in PR, will avoid loading full content for rest of files" + ) if avoid_load: new_file_content_str = "" else: - new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) # communication with GitHub + new_file_content_str = self._get_pr_file_content( + file, self.pr.head.sha + ) # communication with GitHub if self.incremental.is_incremental and self.unreviewed_files_set: - original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha) + original_file_content_str = self._get_pr_file_content( + file, self.incremental.last_seen_commit_sha + ) patch = load_large_diff(file.filename, new_file_content_str, 
original_file_content_str) self.unreviewed_files_set[file.filename] = patch else: @@ -309,32 +351,36 @@ def get_diff_files(self) -> list[FilePatchInfo]: if not patch: patch = load_large_diff(file.filename, new_file_content_str, original_file_content_str) - - if file.status == 'added': + if file.status == "added": edit_type = EDIT_TYPE.ADDED - elif file.status == 'removed': + elif file.status == "removed": edit_type = EDIT_TYPE.DELETED - elif file.status == 'renamed': + elif file.status == "renamed": edit_type = EDIT_TYPE.RENAMED - elif file.status == 'modified': + elif file.status == "modified": edit_type = EDIT_TYPE.MODIFIED else: get_logger().error(f"Unknown edit type: {file.status}") edit_type = EDIT_TYPE.UNKNOWN # count number of lines added and removed - if hasattr(file, 'additions') and hasattr(file, 'deletions'): + if hasattr(file, "additions") and hasattr(file, "deletions"): num_plus_lines = file.additions num_minus_lines = file.deletions else: patch_lines = patch.splitlines(keepends=True) - num_plus_lines = len([line for line in patch_lines if line.startswith('+')]) - num_minus_lines = len([line for line in patch_lines if line.startswith('-')]) - - file_patch_canonical_structure = FilePatchInfo(original_file_content_str, new_file_content_str, patch, - file.filename, edit_type=edit_type, - num_plus_lines=num_plus_lines, - num_minus_lines=num_minus_lines,) + num_plus_lines = len([line for line in patch_lines if line.startswith("+")]) + num_minus_lines = len([line for line in patch_lines if line.startswith("-")]) + + file_patch_canonical_structure = FilePatchInfo( + original_file_content_str, + new_file_content_str, + patch, + file.filename, + edit_type=edit_type, + num_plus_lines=num_plus_lines, + num_minus_lines=num_minus_lines, + ) diff_files.append(file_patch_canonical_structure) if invalid_files_names: get_logger().info(f"Filtered out files with invalid extensions: {invalid_files_names}") @@ -348,8 +394,7 @@ def get_diff_files(self) -> 
list[FilePatchInfo]: return diff_files except Exception as e: - get_logger().error(f"Failing to get diff files: {e}", - artifact={"traceback": traceback.format_exc()}) + get_logger().error(f"Failing to get diff files: {e}", artifact={"traceback": traceback.format_exc()}) raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e def publish_description(self, pr_title: str, pr_body: str): @@ -361,11 +406,9 @@ def get_latest_commit_url(self) -> str: def get_comment_url(self, comment) -> str: return comment.html_url - def publish_persistent_comment(self, pr_comment: str, - initial_header: str, - update_header: bool = True, - name='review', - final_update_message=True): + def publish_persistent_comment( + self, pr_comment: str, initial_header: str, update_header: bool = True, name="review", final_update_message=True + ): self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message) def publish_comment(self, pr_comment: str, is_temporary: bool = False): @@ -386,23 +429,24 @@ def publish_comment(self, pr_comment: str, is_temporary: bool = False): if hasattr(response, "user") and hasattr(response.user, "login"): self.github_user_id = response.user.login response.is_temporary = is_temporary - if not hasattr(self.pr, 'comments_list'): + if not hasattr(self.pr, "comments_list"): self.pr.comments_list = [] self.pr.comments_list.append(response) return response - def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None): + def publish_inline_comment( + self, body: str, relevant_file: str, relevant_line_in_file: str, original_suggestion=None + ): body = self.limit_output_characters(body, self.max_comment_chars) self.publish_inline_comments([self.create_inline_comment(body, relevant_file, relevant_line_in_file)]) - - def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str, - absolute_position: int = None): + def 
create_inline_comment( + self, body: str, relevant_file: str, relevant_line_in_file: str, absolute_position: int = None + ): body = self.limit_output_characters(body, self.max_comment_chars) - position, absolute_position = find_line_number_of_relevant_line_in_file(self.diff_files, - relevant_file.strip('`'), - relevant_line_in_file, - absolute_position) + position, absolute_position = find_line_number_of_relevant_line_in_file( + self.diff_files, relevant_file.strip("`"), relevant_line_in_file, absolute_position + ) if position == -1: get_logger().info(f"Could not find position for {relevant_file} {relevant_line_in_file}") subject_type = "FILE" @@ -418,49 +462,52 @@ def publish_inline_comments(self, comments: list[dict], disable_fallback: bool = except Exception as e: get_logger().info(f"Initially failed to publish inline comments as committable") - if (getattr(e, "status", None) == 422 and not disable_fallback): + if getattr(e, "status", None) == 422 and not disable_fallback: pass # continue to try _publish_inline_comments_fallback_with_verification else: - raise e # will end up with publishing the comments one by one + raise e # will end up with publishing the comments one by one try: self._publish_inline_comments_fallback_with_verification(comments) except Exception as e: get_logger().error(f"Failed to publish inline code comments fallback, error: {e}") - raise e - + raise e + def get_review_thread_comments(self, comment_id: int) -> list[dict]: """ Retrieves all comments in the same thread as the given comment. 
- + Args: comment_id: Review comment ID - + Returns: List of comments in the same thread """ try: # Fetch all comments with a single API call all_comments = list(self.pr.get_comments()) - + # Find the target comment by ID target_comment = next((c for c in all_comments if c.id == comment_id), None) if not target_comment: return [] - + # Get root comment id root_comment_id = target_comment.raw_data.get("in_reply_to_id", target_comment.id) # Build the thread - include the root comment and all replies to it thread_comments = [ - c for c in all_comments if - c.id == root_comment_id or c.raw_data.get("in_reply_to_id") == root_comment_id + c + for c in all_comments + if c.id == root_comment_id or c.raw_data.get("in_reply_to_id") == root_comment_id ] - - + return thread_comments - + except Exception as e: - get_logger().exception(f"Failed to get review comments for an inline ask command", artifact={"comment_id": comment_id, "error": e}) + get_logger().exception( + f"Failed to get review comments for an inline ask command", + artifact={"comment_id": comment_id, "error": e}, + ) return [] def _publish_inline_comments_fallback_with_verification(self, comments: list[dict]): @@ -481,7 +528,8 @@ def _publish_inline_comments_fallback_with_verification(self, comments: list[dic # try to publish one by one the invalid comments as a one-line code comment if invalid_comments and get_settings().github.try_fix_invalid_inline_comments: fixed_comments_as_one_liner = self._try_fix_invalid_inline_comments( - [comment for comment, _ in invalid_comments]) + [comment for comment, _ in invalid_comments] + ) for comment in fixed_comments_as_one_liner: try: self.publish_inline_comments([comment], disable_fallback=True) @@ -495,8 +543,7 @@ def _verify_code_comment(self, comment: dict): try: # event ="" # By leaving this blank, you set the review action state to PENDING input = dict(commit_id=self.last_commit_id.sha, comments=[comment]) - headers, data = self.pr._requester.requestJsonAndCheck( - 
"POST", f"{self.pr.url}/reviews", input=input) + headers, data = self.pr._requester.requestJsonAndCheck("POST", f"{self.pr.url}/reviews", input=input) pending_review_id = data["id"] is_verified = True except Exception as err: @@ -530,6 +577,7 @@ def _try_fix_invalid_inline_comments(self, invalid_comments: list[dict]) -> list This is a best-effort attempt to fix invalid comments, and should be verified accordingly. """ import copy + fixed_comments = [] for comment in invalid_comments: try: @@ -557,20 +605,23 @@ def publish_code_suggestions(self, code_suggestions: list) -> bool: code_suggestions_validated = self.validate_comments_inside_hunks(code_suggestions) for suggestion in code_suggestions_validated: - body = suggestion['body'] - relevant_file = suggestion['relevant_file'] - relevant_lines_start = suggestion['relevant_lines_start'] - relevant_lines_end = suggestion['relevant_lines_end'] + body = suggestion["body"] + relevant_file = suggestion["relevant_file"] + relevant_lines_start = suggestion["relevant_lines_start"] + relevant_lines_end = suggestion["relevant_lines_end"] if not relevant_lines_start or relevant_lines_start == -1: get_logger().exception( - f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}") + f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}" + ) continue if relevant_lines_end < relevant_lines_start: - get_logger().exception(f"Failed to publish code suggestion, " - f"relevant_lines_end is {relevant_lines_end} and " - f"relevant_lines_start is {relevant_lines_start}") + get_logger().exception( + f"Failed to publish code suggestion, " + f"relevant_lines_end is {relevant_lines_end} and " + f"relevant_lines_start is {relevant_lines_start}" + ) continue if relevant_lines_end > relevant_lines_start: @@ -605,8 +656,8 @@ def edit_comment(self, comment, body: str): if hasattr(e, "status") and e.status == 403: # Log as warning for permission-related issues (usually due to polling) 
get_logger().warning( - "Failed to edit github comment due to permission restrictions", - artifact={"error": e}) + "Failed to edit github comment due to permission restrictions", artifact={"error": e} + ) else: get_logger().exception(f"Failed to edit github comment", artifact={"error": e}) @@ -615,8 +666,7 @@ def edit_comment_from_comment_id(self, comment_id: int, body: str): # self.pr.get_issue_comment(comment_id).edit(body) body = self.limit_output_characters(body, self.max_comment_chars) headers, data_patch = self.pr._requester.requestJsonAndCheck( - "PATCH", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}", - input={"body": body} + "PATCH", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}", input={"body": body} ) except Exception as e: get_logger().exception(f"Failed to edit comment, error: {e}") @@ -626,8 +676,9 @@ def reply_to_comment_from_comment_id(self, comment_id: int, body: str): # self.pr.get_issue_comment(comment_id).edit(body) body = self.limit_output_characters(body, self.max_comment_chars) headers, data_patch = self.pr._requester.requestJsonAndCheck( - "POST", f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies", - input={"body": body} + "POST", + f"{self.base_url}/repos/{self.repo}/pulls/{self.pr_num}/comments/{comment_id}/replies", + input={"body": body}, ) except Exception as e: get_logger().exception(f"Failed to reply comment, error: {e}") @@ -638,33 +689,36 @@ def get_comment_body_from_comment_id(self, comment_id: int): headers, data_patch = self.pr._requester.requestJsonAndCheck( "GET", f"{self.base_url}/repos/{self.repo}/issues/comments/{comment_id}" ) - return data_patch.get("body","") + return data_patch.get("body", "") except Exception as e: get_logger().exception(f"Failed to edit comment, error: {e}") return None def publish_file_comments(self, file_comments: list) -> bool: try: - headers, existing_comments = self.pr._requester.requestJsonAndCheck( - "GET", 
f"{self.pr.url}/comments" - ) + headers, existing_comments = self.pr._requester.requestJsonAndCheck("GET", f"{self.pr.url}/comments") for comment in file_comments: - comment['commit_id'] = self.last_commit_id.sha - comment['body'] = self.limit_output_characters(comment['body'], self.max_comment_chars) + comment["commit_id"] = self.last_commit_id.sha + comment["body"] = self.limit_output_characters(comment["body"], self.max_comment_chars) found = False for existing_comment in existing_comments: - comment['commit_id'] = self.last_commit_id.sha + comment["commit_id"] = self.last_commit_id.sha our_app_name = get_settings().get("GITHUB.APP_NAME", "") same_comment_creator = False - if self.deployment_type == 'app': - same_comment_creator = our_app_name.lower() in existing_comment['user']['login'].lower() - elif self.deployment_type == 'user': - same_comment_creator = self.github_user_id == existing_comment['user']['login'] - if existing_comment['subject_type'] == 'file' and comment['path'] == existing_comment['path'] and same_comment_creator: - + if self.deployment_type == "app": + same_comment_creator = our_app_name.lower() in existing_comment["user"]["login"].lower() + elif self.deployment_type == "user": + same_comment_creator = self.github_user_id == existing_comment["user"]["login"] + if ( + existing_comment["subject_type"] == "file" + and comment["path"] == existing_comment["path"] + and same_comment_creator + ): headers, data_patch = self.pr._requester.requestJsonAndCheck( - "PATCH", f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", input={"body":comment['body']} + "PATCH", + f"{self.base_url}/repos/{self.repo}/pulls/comments/{existing_comment['id']}", + input={"body": comment["body"]}, ) found = True break @@ -679,7 +733,7 @@ def publish_file_comments(self, file_comments: list) -> bool: def remove_initial_comment(self): try: - for comment in getattr(self.pr, 'comments_list', []): + for comment in getattr(self.pr, "comments_list", []): 
if comment.is_temporary: self.remove_comment(comment) except Exception as e: @@ -695,8 +749,10 @@ def get_title(self): return self.pr.title def get_languages(self): - languages = self._get_repo().get_languages() - return languages + if self._languages is not None: + return self._languages + self._languages = self._get_repo().get_languages() + return self._languages def get_pr_branch(self): return self.pr.head.ref @@ -704,7 +760,7 @@ def get_pr_branch(self): def get_pr_owner_id(self) -> str | None: if not self.repo: return None - return self.repo.split('/')[0] + return self.repo.split("/")[0] def get_pr_description_full(self): return self.pr.body @@ -712,7 +768,7 @@ def get_pr_description_full(self): def get_user_id(self): if not self.github_user_id: try: - self.github_user_id = self.github_client.get_user().raw_data['login'] + self.github_user_id = self.github_client.get_user().raw_data["login"] except Exception as e: self.github_user_id = "" # logging.exception(f"Failed to get user id, error: {e}") @@ -721,7 +777,7 @@ def get_user_id(self): def get_notifications(self, since: datetime): deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user") - if deployment_type != 'user': + if deployment_type != "user": raise ValueError("Deployment mode must be set to 'user' to get notifications") notifications = self.github_client.get_user().get_notifications(since=since) @@ -741,15 +797,16 @@ def get_repo_settings(self): return "" def get_workspace_name(self): - return self.repo.split('/')[0] + return self.repo.split("/")[0] def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False) -> Optional[int]: if disable_eyes: return None try: headers, data_patch = self.pr._requester.requestJsonAndCheck( - "POST", f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions", - input={"content": "eyes"} + "POST", + f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions", + input={"content": "eyes"}, ) return 
data_patch.get("id", None) except Exception as e: @@ -761,7 +818,7 @@ def remove_reaction(self, issue_comment_id: int, reaction_id: str) -> bool: # self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id) headers, data_patch = self.pr._requester.requestJsonAndCheck( "DELETE", - f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}" + f"{self.base_url}/repos/{self.repo}/issues/comments/{issue_comment_id}/reactions/{reaction_id}", ) return True except Exception as e: @@ -771,24 +828,24 @@ def remove_reaction(self, issue_comment_id: int, reaction_id: str) -> bool: def _parse_pr_url(self, pr_url: str) -> Tuple[str, int]: parsed_url = urlparse(pr_url) - if parsed_url.path.startswith('/api/v3'): + if parsed_url.path.startswith("/api/v3"): parsed_url = urlparse(pr_url.replace("/api/v3", "")) - path_parts = parsed_url.path.strip('/').split('/') - if 'api.github.com' in parsed_url.netloc or '/api/v3' in pr_url: - if len(path_parts) < 5 or path_parts[3] != 'pulls': + path_parts = parsed_url.path.strip("/").split("/") + if "api.github.com" in parsed_url.netloc or "/api/v3" in pr_url: + if len(path_parts) < 5 or path_parts[3] != "pulls": raise ValueError("The provided URL does not appear to be a GitHub PR URL") - repo_name = '/'.join(path_parts[1:3]) + repo_name = "/".join(path_parts[1:3]) try: pr_number = int(path_parts[4]) except ValueError as e: raise ValueError("Unable to convert PR number to integer") from e return repo_name, pr_number - if len(path_parts) < 4 or path_parts[2] != 'pull': + if len(path_parts) < 4 or path_parts[2] != "pull": raise ValueError("The provided URL does not appear to be a GitHub PR URL") - repo_name = '/'.join(path_parts[:2]) + repo_name = "/".join(path_parts[:2]) try: pr_number = int(path_parts[3]) except ValueError as e: @@ -799,24 +856,24 @@ def _parse_pr_url(self, pr_url: str) -> Tuple[str, int]: def _parse_issue_url(self, issue_url: str) -> Tuple[str, int]: parsed_url = 
urlparse(issue_url) - if parsed_url.path.startswith('/api/v3'): #Check if came from github app + if parsed_url.path.startswith("/api/v3"): # Check if came from github app parsed_url = urlparse(issue_url.replace("/api/v3", "")) - path_parts = parsed_url.path.strip('/').split('/') - if 'api.github.com' in parsed_url.netloc or '/api/v3' in issue_url: #Check if came from github app - if len(path_parts) < 5 or path_parts[3] != 'issues': + path_parts = parsed_url.path.strip("/").split("/") + if "api.github.com" in parsed_url.netloc or "/api/v3" in issue_url: # Check if came from github app + if len(path_parts) < 5 or path_parts[3] != "issues": raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL") - repo_name = '/'.join(path_parts[1:3]) + repo_name = "/".join(path_parts[1:3]) try: issue_number = int(path_parts[4]) except ValueError as e: raise ValueError("Unable to convert issue number to integer") from e return repo_name, issue_number - if len(path_parts) < 4 or path_parts[2] != 'issues': + if len(path_parts) < 4 or path_parts[2] != "issues": raise ValueError("The provided URL does not appear to be a GitHub PR issue") - repo_name = '/'.join(path_parts[:2]) + repo_name = "/".join(path_parts[:2]) try: issue_number = int(path_parts[3]) except ValueError as e: @@ -827,7 +884,7 @@ def _parse_issue_url(self, issue_url: str) -> Tuple[str, int]: def _get_github_client(self): self.deployment_type = get_settings().get("GITHUB.DEPLOYMENT_TYPE", "user") self.auth = None - if self.deployment_type == 'app': + if self.deployment_type == "app": try: private_key = get_settings().github.private_key app_id = get_settings().github.app_id @@ -835,16 +892,16 @@ def _get_github_client(self): raise ValueError("GitHub app ID and private key are required when using GitHub app deployment") from e if not self.installation_id: raise ValueError("GitHub app installation ID is required when using GitHub app deployment") - auth = AppAuthentication(app_id=app_id, 
private_key=private_key, - installation_id=self.installation_id) + auth = AppAuthentication(app_id=app_id, private_key=private_key, installation_id=self.installation_id) self.auth = auth - elif self.deployment_type == 'user': + elif self.deployment_type == "user": try: token = get_settings().github.user_token except AttributeError as e: raise ValueError( "GitHub token is required when using user deployment. See: " - "https://github.com/Codium-ai/pr-agent#method-2-run-from-source") from e + "https://github.com/Codium-ai/pr-agent#method-2-run-from-source" + ) from e self.auth = Auth.Token(token) if self.auth: return Github(auth=self.auth, base_url=self.base_url) @@ -852,37 +909,28 @@ def _get_github_client(self): raise ValueError("Could not authenticate to GitHub") def _get_repo(self): - if hasattr(self, 'repo_obj') and \ - hasattr(self.repo_obj, 'full_name') and \ - self.repo_obj.full_name == self.repo: + if hasattr(self, "repo_obj") and hasattr(self.repo_obj, "full_name") and self.repo_obj.full_name == self.repo: return self.repo_obj else: self.repo_obj = self.github_client.get_repo(self.repo) return self.repo_obj - def _get_pr(self): return self._get_repo().get_pull(self.pr_num) def get_pr_file_content(self, file_path: str, branch: str) -> str: try: - file_content_str = str( - self._get_repo() - .get_contents(file_path, ref=branch) - .decoded_content.decode() - ) + file_content_str = str(self._get_repo().get_contents(file_path, ref=branch).decoded_content.decode()) except Exception: file_content_str = "" return file_content_str - def create_or_update_pr_file( - self, file_path: str, branch: str, contents="", message="" - ) -> None: + def create_or_update_pr_file(self, file_path: str, branch: str, contents="", message="") -> None: try: file_obj = self._get_repo().get_contents(file_path, ref=branch) - sha1=file_obj.sha + sha1 = file_obj.sha except Exception: - sha1="" + sha1 = "" self.repo_obj.update_file( path=file_path, message=message, @@ -896,9 +944,14 @@ def 
_get_pr_file_content(self, file: FilePatchInfo, sha: str) -> str: def publish_labels(self, pr_types): try: - label_color_map = {"Bug fix": "1d76db", "Tests": "e99695", "Bug fix with tests": "c5def5", - "Enhancement": "bfd4f2", "Documentation": "d4c5f9", - "Other": "d1bcf9"} + label_color_map = { + "Bug fix": "1d76db", + "Tests": "e99695", + "Bug fix with tests": "c5def5", + "Enhancement": "bfd4f2", + "Documentation": "d4c5f9", + "Other": "d1bcf9", + } post_parameters = [] for p in pr_types: color = label_color_map.get(p, "d1bcf9") # default to "Other" color @@ -912,12 +965,11 @@ def publish_labels(self, pr_types): def get_pr_labels(self, update=False): try: if not update: - labels =self.pr.labels + labels = self.pr.labels return [label.name for label in labels] - else: # obtain the latest labels. Maybe they changed while the AI was running - headers, labels = self.pr._requester.requestJsonAndCheck( - "GET", f"{self.pr.issue_url}/labels") - return [label['name'] for label in labels] + else: # obtain the latest labels. Maybe they changed while the AI was running + headers, labels = self.pr._requester.requestJsonAndCheck("GET", f"{self.pr.issue_url}/labels") + return [label["name"] for label in labels] except Exception as e: get_logger().exception(f"Failed to get labels, error: {e}") @@ -936,7 +988,7 @@ def get_commit_messages(self): """ max_tokens = get_settings().get("CONFIG.MAX_COMMITS_TOKENS", None) try: - commit_list = self.pr.get_commits() + commit_list = self.pr_commits if self.pr_commits is not None else list(self.pr.get_commits()) commit_messages = [commit.commit.message for commit in commit_list] commit_messages_str = "\n".join([f"{i + 1}. 
{message}" for i, message in enumerate(commit_messages)]) except Exception: @@ -947,13 +999,14 @@ def get_commit_messages(self): def generate_link_to_relevant_line_number(self, suggestion) -> str: try: - relevant_file = suggestion['relevant_file'].strip('`').strip("'").strip('\n') - relevant_line_str = suggestion['relevant_line'].strip('\n') + relevant_file = suggestion["relevant_file"].strip("`").strip("'").strip("\n") + relevant_line_str = suggestion["relevant_line"].strip("\n") if not relevant_line_str: return "" - position, absolute_position = find_line_number_of_relevant_line_in_file \ - (self.diff_files, relevant_file, relevant_line_str) + position, absolute_position = find_line_number_of_relevant_line_in_file( + self.diff_files, relevant_file, relevant_line_str + ) if absolute_position != -1: # # link to right file only @@ -961,7 +1014,7 @@ def generate_link_to_relevant_line_number(self, suggestion) -> str: # + "#" + f"L{absolute_position}" # link to diff - sha_file = hashlib.sha256(relevant_file.encode('utf-8')).hexdigest() + sha_file = hashlib.sha256(relevant_file.encode("utf-8")).hexdigest() link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}R{absolute_position}" return link except Exception as e: @@ -970,7 +1023,7 @@ def generate_link_to_relevant_line_number(self, suggestion) -> str: return "" def get_line_link(self, relevant_file: str, relevant_line_start: int, relevant_line_end: int = None) -> str: - sha_file = hashlib.sha256(relevant_file.encode('utf-8')).hexdigest() + sha_file = hashlib.sha256(relevant_file.encode("utf-8")).hexdigest() if relevant_line_start == -1: link = f"{self.base_url_html}/{self.repo}/pull/{self.pr_num}/files#diff-{sha_file}" elif relevant_line_end: @@ -1001,8 +1054,7 @@ def get_lines_link_original_file(self, filepath: str, component_range: Range) -> line_end = component_range.line_end + 1 # link = (f"https://github.com/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/" # 
f"#L{line_start}-L{line_end}") - link = (f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/" - f"#L{line_start}-L{line_end}") + link = f"{self.base_url_html}/{self.repo}/blob/{self.last_commit_id.sha}/{filepath}/#L{line_start}-L{line_end}" return link @@ -1034,8 +1086,9 @@ def fetch_sub_issues(self, issue_url): }} }} """ - response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", - input={"query": query}) + response_tuple = self.github_client._Github__requester.requestJson( + "POST", "/graphql", input={"query": query} + ) # Extract the JSON response from the tuple and parses it if isinstance(response_tuple, tuple) and len(response_tuple) == 3: @@ -1044,7 +1097,6 @@ def fetch_sub_issues(self, issue_url): get_logger().error(f"Unexpected response format: {response_tuple}") return sub_issues - issue_id = response_json.get("data", {}).get("repository", {}).get("issue", {}).get("id") if not issue_id: @@ -1065,20 +1117,23 @@ def fetch_sub_issues(self, issue_url): }} }} """ - sub_issues_response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", input={ - "query": sub_issues_query}) + sub_issues_response_tuple = self.github_client._Github__requester.requestJson( + "POST", "/graphql", input={"query": sub_issues_query} + ) # Extract the JSON response from the tuple and parses it if isinstance(sub_issues_response_tuple, tuple) and len(sub_issues_response_tuple) == 3: sub_issues_response_json = json.loads(sub_issues_response_tuple[2]) else: - get_logger().error("Unexpected sub-issues response format", artifact={"response": sub_issues_response_tuple}) + get_logger().error( + "Unexpected sub-issues response format", artifact={"response": sub_issues_response_tuple} + ) return sub_issues if not sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues"): get_logger().error("Invalid sub-issues response structure") return sub_issues - + nodes = sub_issues_response_json.get("data", 
{}).get("node", {}).get("subIssues", {}).get("nodes", []) get_logger().info(f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes}) @@ -1102,7 +1157,7 @@ def auto_approve(self) -> bool: return False def calc_pr_statistics(self, pull_request_data: dict): - return {} + return {} def validate_comments_inside_hunks(self, code_suggestions): """ @@ -1110,45 +1165,43 @@ def validate_comments_inside_hunks(self, code_suggestions): """ code_suggestions_copy = copy.deepcopy(code_suggestions) diff_files = self.get_diff_files() - RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") diff_files = set_file_languages(diff_files) for suggestion in code_suggestions_copy: try: - relevant_file_path = suggestion['relevant_file'] + relevant_file_path = suggestion["relevant_file"] for file in diff_files: if file.filename == relevant_file_path: - # generate on-demand the patches range for the relevant file patch_str = file.patch - if not hasattr(file, 'patches_range'): + if not hasattr(file, "patches_range"): file.patches_range = [] patch_lines = patch_str.splitlines() for i, line in enumerate(patch_lines): - if line.startswith('@@'): + if line.startswith("@@"): match = RE_HUNK_HEADER.match(line) # identify hunk header if match: section_header, size1, size2, start1, start2 = extract_hunk_headers(match) - file.patches_range.append({'start': start2, 'end': start2 + size2 - 1}) + file.patches_range.append({"start": start2, "end": start2 + size2 - 1}) patches_range = file.patches_range - comment_start_line = suggestion.get('relevant_lines_start', None) - comment_end_line = suggestion.get('relevant_lines_end', None) - original_suggestion = suggestion.get('original_suggestion', None) # needed for diff code + comment_start_line = suggestion.get("relevant_lines_start", None) + comment_end_line = suggestion.get("relevant_lines_end", None) + original_suggestion = 
suggestion.get("original_suggestion", None) # needed for diff code if not comment_start_line or not comment_end_line or not original_suggestion: continue # check if the comment is inside a valid hunk is_valid_hunk = False - min_distance = float('inf') + min_distance = float("inf") patch_range_min = None # find the hunk that contains the comment, or the closest one for i, patch_range in enumerate(patches_range): - d1 = comment_start_line - patch_range['start'] - d2 = patch_range['end'] - comment_end_line + d1 = comment_start_line - patch_range["start"] + d2 = patch_range["end"] - comment_end_line if d1 >= 0 and d2 >= 0: # found a valid hunk is_valid_hunk = True min_distance = 0 @@ -1162,41 +1215,50 @@ def validate_comments_inside_hunks(self, code_suggestions): patch_range_min = patch_range min_distance = min(min_distance, d) if not is_valid_hunk: - if min_distance < 10: # 10 lines - a reasonable distance to consider the comment inside the hunk + if ( + min_distance < 10 + ): # 10 lines - a reasonable distance to consider the comment inside the hunk # make the suggestion non-committable, yet multi line - suggestion['relevant_lines_start'] = max(suggestion['relevant_lines_start'], patch_range_min['start']) - suggestion['relevant_lines_end'] = min(suggestion['relevant_lines_end'], patch_range_min['end']) - body = suggestion['body'].strip() + suggestion["relevant_lines_start"] = max( + suggestion["relevant_lines_start"], patch_range_min["start"] + ) + suggestion["relevant_lines_end"] = min( + suggestion["relevant_lines_end"], patch_range_min["end"] + ) + body = suggestion["body"].strip() # present new diff code in collapsible - existing_code = original_suggestion['existing_code'].rstrip() + "\n" - improved_code = original_suggestion['improved_code'].rstrip() + "\n" - diff = difflib.unified_diff(existing_code.split('\n'), - improved_code.split('\n'), n=999) + existing_code = original_suggestion["existing_code"].rstrip() + "\n" + improved_code = 
original_suggestion["improved_code"].rstrip() + "\n" + diff = difflib.unified_diff(existing_code.split("\n"), improved_code.split("\n"), n=999) patch_orig = "\n".join(diff) - patch = "\n".join(patch_orig.splitlines()[5:]).strip('\n') + patch = "\n".join(patch_orig.splitlines()[5:]).strip("\n") diff_code = f"\n\n
New proposed code:\n\n```diff\n{patch.rstrip()}\n```" # replace ```suggestion ... ``` with diff_code, using regex: - body = re.sub(r'```suggestion.*?```', diff_code, body, flags=re.DOTALL) + body = re.sub(r"```suggestion.*?```", diff_code, body, flags=re.DOTALL) body += "\n\n
" - suggestion['body'] = body - get_logger().info(f"Comment was moved to a valid hunk, " - f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}") + suggestion["body"] = body + get_logger().info( + f"Comment was moved to a valid hunk, " + f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}" + ) else: - get_logger().error(f"Comment is not inside a valid hunk, " - f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}") + get_logger().error( + f"Comment is not inside a valid hunk, " + f"start_line={suggestion['relevant_lines_start']}, end_line={suggestion['relevant_lines_end']}, file={file.filename}" + ) except Exception as e: get_logger().error(f"Failed to process patch for committable comment, error: {e}") return code_suggestions_copy - #Clone related + # Clone related def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: scheme = "https://" - #For example, to clone: - #https://github.com/Codium-ai/pr-agent-pro.git - #Need to embed inside the github token: - #https://@github.com/Codium-ai/pr-agent-pro.git + # For example, to clone: + # https://github.com/Codium-ai/pr-agent-pro.git + # Need to embed inside the github token: + # https://@github.com/Codium-ai/pr-agent-pro.git github_token = self.auth.token github_base_url = self.base_url_html @@ -1219,7 +1281,7 @@ def _prepare_clone_url_with_token(self, repo_url_to_clone: str) -> str | None: return None clone_url = scheme - if self.deployment_type == 'app': + if self.deployment_type == "app": clone_url += "git:" clone_url += f"{github_token}@{github_com}{repo_full_name}" return clone_url diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 16ffbcae2a..baa2622158 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -224,6 
+224,10 @@ publish_inline_comments_fallback_with_verification = true try_fix_invalid_inline_comments = true app_name = "pr-agent" ignore_bot_pr = true +# When true, reduces GitHub API calls by skipping file content fetches for /review +# and suppressing temporary "Preparing review..." comments. Useful for GHES deployments +# where appliance load is a concern. Only affects /review; /improve and /describe are unaffected. +minimize_api_calls = false [github_action_config] # auto_review = true # set as env var in .github/workflows/pr-agent.yaml diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index c4917f3597..01c9aa71ec 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -9,22 +9,22 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler -from pr_agent.algo.pr_processing import (add_ai_metadata_to_diff_files, - get_pr_diff, - retry_with_fallback_models) +from pr_agent.algo.pr_processing import add_ai_metadata_to_diff_files, get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import (ModelType, PRReviewHeader, - convert_to_markdown_v2, github_action_output, - load_yaml, show_relevant_configurations) +from pr_agent.algo.utils import ( + ModelType, + PRReviewHeader, + convert_to_markdown_v2, + github_action_output, + load_yaml, + show_relevant_configurations, +) from pr_agent.config_loader import get_settings -from pr_agent.git_providers import (get_git_provider, - get_git_provider_with_context) -from pr_agent.git_providers.git_provider import (IncrementalPR, - get_main_pr_language) +from pr_agent.git_providers import get_git_provider, get_git_provider_with_context +from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language from pr_agent.log import get_logger from pr_agent.servers.help import HelpMessage -from 
pr_agent.tools.ticket_pr_compliance_check import ( - extract_and_cache_pr_tickets, extract_tickets) +from pr_agent.tools.ticket_pr_compliance_check import extract_and_cache_pr_tickets, extract_tickets class PRReviewer: @@ -32,8 +32,14 @@ class PRReviewer: The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, - ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): + def __init__( + self, + pr_url: str, + is_answer: bool = False, + is_auto: bool = False, + args: list = None, + ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler, + ): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. @@ -50,9 +56,7 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, if self.incremental and self.incremental.is_incremental: self.git_provider.get_incremental_commits(self.incremental) - self.main_language = get_main_pr_language( - self.git_provider.get_languages(), self.git_provider.get_files() - ) + self.main_language = get_main_pr_language(self.git_provider.get_languages(), self.git_provider.get_files()) self.pr_url = pr_url self.is_answer = is_answer self.is_auto = is_auto @@ -64,10 +68,14 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, self.patches_diff = None self.prediction = None answer_str, question_str = self._get_user_answers() - self.pr_description, self.pr_description_files = ( - self.git_provider.get_pr_description(split_changes_walkthrough=True)) - if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and - get_settings().get("config.enable_ai_metadata", False)): + self.pr_description, self.pr_description_files = self.git_provider.get_pr_description( + split_changes_walkthrough=True + ) + if ( + self.pr_description_files + and 
get_settings().get("config.is_auto_command", False) + and get_settings().get("config.enable_ai_metadata", False) + ): add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files) get_logger().debug(f"AI metadata added to the this command") else: @@ -86,26 +94,26 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, "require_tests": get_settings().pr_reviewer.require_tests_review, "require_estimate_effort_to_review": get_settings().pr_reviewer.require_estimate_effort_to_review, "require_estimate_contribution_time_cost": get_settings().pr_reviewer.require_estimate_contribution_time_cost, - 'require_can_be_split_review': get_settings().pr_reviewer.require_can_be_split_review, - 'require_security_review': get_settings().pr_reviewer.require_security_review, - 'require_todo_scan': get_settings().pr_reviewer.get("require_todo_scan", False), - 'question_str': question_str, - 'answer_str': answer_str, + "require_can_be_split_review": get_settings().pr_reviewer.require_can_be_split_review, + "require_security_review": get_settings().pr_reviewer.require_security_review, + "require_todo_scan": get_settings().pr_reviewer.get("require_todo_scan", False), + "question_str": question_str, + "answer_str": answer_str, "extra_instructions": get_settings().pr_reviewer.extra_instructions, "commit_messages_str": self.git_provider.get_commit_messages(), "custom_labels": "", "enable_custom_labels": get_settings().config.enable_custom_labels, - "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), - "related_tickets": get_settings().get('related_tickets', []), - 'duplicate_prompt_examples': get_settings().config.get('duplicate_prompt_examples', False), - "date": datetime.datetime.now().strftime('%Y-%m-%d'), + "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False), + "related_tickets": get_settings().get("related_tickets", []), + "duplicate_prompt_examples": 
get_settings().config.get("duplicate_prompt_examples", False), + "date": datetime.datetime.now().strftime("%Y-%m-%d"), } self.token_handler = TokenHandler( self.git_provider.pr, self.vars, get_settings().pr_review_prompt.system, - get_settings().pr_review_prompt.user + get_settings().pr_review_prompt.user, ) def parse_incremental(self, args: List[str]): @@ -131,25 +139,35 @@ async def run(self) -> None: # self.auto_approve_logic() # return None - get_logger().info(f'Reviewing PR: {self.pr_url} ...') - relevant_configs = {'pr_reviewer': dict(get_settings().pr_reviewer), - 'config': dict(get_settings().config)} + get_logger().info(f"Reviewing PR: {self.pr_url} ...") + relevant_configs = {"pr_reviewer": dict(get_settings().pr_reviewer), "config": dict(get_settings().config)} get_logger().debug("Relevant configs", artifacts=relevant_configs) # ticket extraction if exists await extract_and_cache_pr_tickets(self.git_provider, self.vars) - if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set: + if ( + self.incremental.is_incremental + and hasattr(self.git_provider, "unreviewed_files_set") + and not self.git_provider.unreviewed_files_set + ): get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files") previous_review_url = "" if hasattr(self.git_provider, "previous_review"): previous_review_url = self.git_provider.previous_review.html_url if get_settings().config.publish_output: - self.git_provider.publish_comment(f"Incremental Review Skipped\n" - f"No files were changed since the [previous PR Review]({previous_review_url})") + self.git_provider.publish_comment( + f"Incremental Review Skipped\n" + f"No files were changed since the [previous PR Review]({previous_review_url})" + ) return None - if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False): + minimize_api_calls = 
get_settings().get("github.minimize_api_calls", False) + if ( + get_settings().config.publish_output + and not get_settings().config.get("is_auto_command", False) + and not minimize_api_calls + ): self.git_provider.publish_comment("Preparing review...", is_temporary=True) await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR) @@ -160,7 +178,9 @@ async def run(self) -> None: pr_review = self._prepare_pr_review() get_logger().debug(f"PR output", artifact=pr_review) - should_publish = get_settings().config.publish_output and self._should_publish_review_no_suggestions(pr_review) + should_publish = get_settings().config.publish_output and self._should_publish_review_no_suggestions( + pr_review + ) if not should_publish: reason = "Review output is not published" if get_settings().config.publish_output: @@ -172,10 +192,12 @@ async def run(self) -> None: # publish the review if get_settings().pr_reviewer.persistent_comment and not self.incremental.is_incremental: final_update_message = get_settings().pr_reviewer.final_update_message - self.git_provider.publish_persistent_comment(pr_review, - initial_header=f"{PRReviewHeader.REGULAR.value} 🔍", - update_header=True, - final_update_message=final_update_message, ) + self.git_provider.publish_persistent_comment( + pr_review, + initial_header=f"{PRReviewHeader.REGULAR.value} 🔍", + update_header=True, + final_update_message=final_update_message, + ) else: self.git_provider.publish_comment(pr_review) @@ -184,14 +206,20 @@ async def run(self) -> None: get_logger().error(f"Failed to review PR: {e}") def _should_publish_review_no_suggestions(self, pr_review: str) -> bool: - return get_settings().pr_reviewer.get('publish_output_no_suggestions', True) or "No major issues detected" not in pr_review + return ( + get_settings().pr_reviewer.get("publish_output_no_suggestions", True) + or "No major issues detected" not in pr_review + ) async def _prepare_prediction(self, model: str) -> None: - 
self.patches_diff = get_pr_diff(self.git_provider, - self.token_handler, - model, - add_line_numbers_to_hunks=True, - disable_extra_lines=False,) + minimize = get_settings().get("github.minimize_api_calls", False) + self.patches_diff = get_pr_diff( + self.git_provider, + self.token_handler, + model, + add_line_numbers_to_hunks=True, + disable_extra_lines=minimize, + ) if self.patches_diff: get_logger().debug(f"PR diff", diff=self.patches_diff) @@ -218,10 +246,7 @@ async def _get_prediction(self, model: str) -> str: user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables) response, finish_reason = await self.ai_handler.chat_completion( - model=model, - temperature=get_settings().config.temperature, - system=system_prompt, - user=user_prompt + model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt ) return response @@ -231,34 +256,48 @@ def _prepare_pr_review(self) -> str: Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback. 
""" - first_key = 'review' - last_key = 'security_concerns' - data = load_yaml(self.prediction.strip(), - keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:", - "relevant_file:", "relevant_line:", "suggestion:"], - first_key=first_key, last_key=last_key) - github_action_output(data, 'review') - - if 'review' not in data: + first_key = "review" + last_key = "security_concerns" + data = load_yaml( + self.prediction.strip(), + keys_fix_yaml=[ + "ticket_compliance_check", + "estimated_effort_to_review_[1-5]:", + "security_concerns:", + "key_issues_to_review:", + "relevant_file:", + "relevant_line:", + "suggestion:", + ], + first_key=first_key, + last_key=last_key, + ) + github_action_output(data, "review") + + if "review" not in data: get_logger().exception("Failed to parse review data", artifact={"data": data}) return "" # move data['review'] 'key_issues_to_review' key to the end of the dictionary - if 'key_issues_to_review' in data['review']: - key_issues_to_review = data['review'].pop('key_issues_to_review') - data['review']['key_issues_to_review'] = key_issues_to_review + if "key_issues_to_review" in data["review"]: + key_issues_to_review = data["review"].pop("key_issues_to_review") + data["review"]["key_issues_to_review"] = key_issues_to_review incremental_review_markdown_text = None # Add incremental review section if self.incremental.is_incremental: - last_commit_url = f"{self.git_provider.get_pr_url()}/commits/" \ - f"{self.git_provider.incremental.first_new_commit_sha}" + last_commit_url = ( + f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}" + ) incremental_review_markdown_text = f"Starting from commit {last_commit_url}" - markdown_text = convert_to_markdown_v2(data, self.git_provider.is_supported("gfm_markdown"), - incremental_review_markdown_text, - git_provider=self.git_provider, - files=self.git_provider.get_diff_files()) + markdown_text 
= convert_to_markdown_v2( + data, + self.git_provider.is_supported("gfm_markdown"), + incremental_review_markdown_text, + git_provider=self.git_provider, + files=self.git_provider.get_diff_files(), + ) # Add help text if gfm_markdown is supported if self.git_provider.is_supported("gfm_markdown") and get_settings().pr_reviewer.enable_help_text: @@ -267,8 +306,8 @@ def _prepare_pr_review(self) -> str: markdown_text += "\n\n" # Output the relevant configurations if enabled - if get_settings().get('config', {}).get('output_relevant_configurations', False): - markdown_text += show_relevant_configurations(relevant_section='pr_reviewer') + if get_settings().get("config", {}).get("output_relevant_configurations", False): + markdown_text += show_relevant_configurations(relevant_section="pr_reviewer") # Add custom labels from the review prediction (effort, security) self.set_review_labels(data) @@ -294,7 +333,7 @@ def _get_user_answers(self) -> Tuple[str, str]: for message in discussion_messages.reversed: if "Questions to better understand the PR:" in message.body: question_str = message.body - elif '/answer' in message.body: + elif "/answer" in message.body: answer_str = message.body if answer_str and question_str: @@ -367,20 +406,22 @@ def set_review_labels(self, data): return if not get_settings().pr_reviewer.require_estimate_effort_to_review: - get_settings().pr_reviewer.enable_review_labels_effort = False # we did not generate this output + get_settings().pr_reviewer.enable_review_labels_effort = False # we did not generate this output if not get_settings().pr_reviewer.require_security_review: - get_settings().pr_reviewer.enable_review_labels_security = False # we did not generate this output + get_settings().pr_reviewer.enable_review_labels_security = False # we did not generate this output - if (get_settings().pr_reviewer.enable_review_labels_security or - get_settings().pr_reviewer.enable_review_labels_effort): + if ( + 
get_settings().pr_reviewer.enable_review_labels_security + or get_settings().pr_reviewer.enable_review_labels_effort + ): try: review_labels = [] if get_settings().pr_reviewer.enable_review_labels_effort: - estimated_effort = data['review']['estimated_effort_to_review_[1-5]'] + estimated_effort = data["review"]["estimated_effort_to_review_[1-5]"] estimated_effort_number = 0 if isinstance(estimated_effort, str): try: - estimated_effort_number = int(estimated_effort.split(',')[0]) + estimated_effort_number = int(estimated_effort.split(",")[0]) except ValueError: get_logger().warning(f"Invalid estimated_effort value: {estimated_effort}") elif isinstance(estimated_effort, int): @@ -388,23 +429,33 @@ def set_review_labels(self, data): else: get_logger().warning(f"Unexpected type for estimated_effort: {type(estimated_effort)}") if 1 <= estimated_effort_number <= 5: # 1, because ... - review_labels.append(f'Review effort {estimated_effort_number}/5') - if get_settings().pr_reviewer.enable_review_labels_security and get_settings().pr_reviewer.require_security_review: - security_concerns = data['review']['security_concerns'] # yes, because ... - security_concerns_bool = 'yes' in security_concerns.lower() or 'true' in security_concerns.lower() + review_labels.append(f"Review effort {estimated_effort_number}/5") + if ( + get_settings().pr_reviewer.enable_review_labels_security + and get_settings().pr_reviewer.require_security_review + ): + security_concerns = data["review"]["security_concerns"] # yes, because ... + security_concerns_bool = "yes" in security_concerns.lower() or "true" in security_concerns.lower() if security_concerns_bool: - review_labels.append('Possible security concern') - - current_labels = self.git_provider.get_pr_labels(update=True) + review_labels.append("Possible security concern") + + # Use labels from PR object (already fetched during init) to avoid + # an extra GET call. 
When minimize_api_calls is enabled this + # eliminates the GET; PUT only fires when labels actually change. + minimize = get_settings().get("github.minimize_api_calls", False) + if minimize: + current_labels = self.git_provider.get_pr_labels(update=False) + else: + current_labels = self.git_provider.get_pr_labels(update=True) if not current_labels: current_labels = [] get_logger().debug(f"Current labels:\n{current_labels}") - if current_labels: - current_labels_filtered = [label for label in current_labels if - not label.lower().startswith('review effort') and not label.lower().startswith( - 'possible security concern')] - else: - current_labels_filtered = [] + current_labels_filtered = [ + label + for label in current_labels + if not label.lower().startswith("review effort") + and not label.lower().startswith("possible security concern") + ] new_labels = review_labels + current_labels_filtered if (current_labels or review_labels) and sorted(new_labels) != sorted(current_labels): get_logger().info(f"Setting review labels:\n{review_labels + current_labels_filtered}") @@ -425,5 +476,7 @@ def auto_approve_logic(self): self.git_provider.publish_comment("Auto-approved PR") else: get_logger().info("Auto-approval option is disabled") - self.git_provider.publish_comment("Auto-approval option for PR-Agent is disabled. " - "You can enable it via a [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)") + self.git_provider.publish_comment( + "Auto-approval option for PR-Agent is disabled. " + "You can enable it via a [configuration file](https://github.com/Codium-ai/pr-agent/blob/main/docs/REVIEW.md#auto-approval-1)" + ) diff --git a/tests/unittest/test_minimize_api_calls.py b/tests/unittest/test_minimize_api_calls.py new file mode 100644 index 0000000000..9e73a039bf --- /dev/null +++ b/tests/unittest/test_minimize_api_calls.py @@ -0,0 +1,251 @@ +""" +Tests for the minimize_api_calls optimization phases. 
+ +Covers: + Phase 1: Commit caching in get_commit_messages() + Phase 2: Language caching in get_languages() + Phase 4: handle_patch_deletions guard for empty content + Phase 5: Temporary comment suppression + Phase 6: Label caching via PR object (TODO: not yet exercised by a test in this module) + Config: Default value of github.minimize_api_calls +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from pr_agent.algo.git_patch_processing import handle_patch_deletions +from pr_agent.algo.types import EDIT_TYPE + + +# --------------------------------------------------------------------------- +# Phase 1: Cache commits in get_commit_messages() +# --------------------------------------------------------------------------- + + +class TestCacheCommits: + """Phase 1 — get_commit_messages() should reuse self.pr_commits.""" + + def _make_provider(self, pr_commits=None): + """Create a minimal GithubProvider-like object with mocked internals.""" + with patch("pr_agent.git_providers.github_provider.GithubProvider.__init__", return_value=None): + from pr_agent.git_providers.github_provider import GithubProvider + + provider = GithubProvider.__new__(GithubProvider) + + # Wire up the attributes get_commit_messages() depends on + provider.pr = MagicMock() + provider.pr_commits = pr_commits + return provider + + @patch("pr_agent.git_providers.github_provider.get_settings") + def test_get_commit_messages_uses_cached_commits(self, mock_settings): + """When pr_commits is populated, get_commit_messages() must NOT call pr.get_commits().""" + mock_settings.return_value.get.return_value = None # MAX_COMMITS_TOKENS + + commit = MagicMock() + commit.commit.message = "feat: add widget" + provider = self._make_provider(pr_commits=[commit]) + + result = provider.get_commit_messages() + + provider.pr.get_commits.assert_not_called() + assert "add widget" in result + + @patch("pr_agent.git_providers.github_provider.get_settings") + def test_get_commit_messages_falls_back_when_no_cache(self, mock_settings): + """When pr_commits
is None, get_commit_messages() should call pr.get_commits().""" + mock_settings.return_value.get.return_value = None + + commit = MagicMock() + commit.commit.message = "fix: resolve bug" + provider = self._make_provider(pr_commits=None) + provider.pr.get_commits.return_value = [commit] + + result = provider.get_commit_messages() + + provider.pr.get_commits.assert_called_once() + assert "resolve bug" in result + + +# --------------------------------------------------------------------------- +# Phase 2: Cache languages in get_languages() +# --------------------------------------------------------------------------- + + +class TestCacheLanguages: + """Phase 2 — get_languages() should cache the result after the first call.""" + + def _make_provider(self): + with patch("pr_agent.git_providers.github_provider.GithubProvider.__init__", return_value=None): + from pr_agent.git_providers.github_provider import GithubProvider + + provider = GithubProvider.__new__(GithubProvider) + + provider._languages = None + provider.repo_obj = MagicMock() + return provider + + def _get_repo_stub(self, provider): + """Stub _get_repo() to return repo_obj.""" + provider._get_repo = MagicMock(return_value=provider.repo_obj) + + def test_get_languages_caches_result(self): + """Second call returns cached value; _get_repo().get_languages() called once.""" + provider = self._make_provider() + self._get_repo_stub(provider) + provider.repo_obj.get_languages.return_value = {"Python": 80, "Go": 20} + + first = provider.get_languages() + second = provider.get_languages() + + assert first == {"Python": 80, "Go": 20} + assert second is first # same object (cached) + provider.repo_obj.get_languages.assert_called_once() + + +# --------------------------------------------------------------------------- +# Phase 4: handle_patch_deletions guard +# --------------------------------------------------------------------------- + + +class TestHandlePatchDeletions: + """Phase 4 — handle_patch_deletions must 
respect minimize_api_calls flag.""" + + SAMPLE_PATCH = "@@ -1,3 +1,3 @@\n-old\n+new\n context" + + @patch("pr_agent.algo.git_patch_processing.get_settings") + def test_minimize_mode_preserves_patch_for_unknown_edit_type(self, mock_settings): + """With minimize_api_calls=True, empty content + UNKNOWN edit type must NOT null the patch.""" + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "github.minimize_api_calls": True, + }.get(key, default) + settings.config.verbosity_level = 0 + mock_settings.return_value = settings + + result = handle_patch_deletions( + patch=self.SAMPLE_PATCH, + original_file_content_str="", + new_file_content_str="", + file_name="test.py", + edit_type=EDIT_TYPE.UNKNOWN, + ) + + assert result is not None, "Patch should be preserved when minimize_api_calls is active" + + @patch("pr_agent.algo.git_patch_processing.get_settings") + def test_minimize_mode_nulls_patch_for_deleted_file(self, mock_settings): + """With minimize_api_calls=True, DELETED edit type must still null the patch.""" + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "github.minimize_api_calls": True, + }.get(key, default) + settings.config.verbosity_level = 0 + mock_settings.return_value = settings + + result = handle_patch_deletions( + patch=self.SAMPLE_PATCH, + original_file_content_str="", + new_file_content_str="", + file_name="deleted.py", + edit_type=EDIT_TYPE.DELETED, + ) + + assert result is None, "Patch should be None for explicitly deleted files" + + @patch("pr_agent.algo.git_patch_processing.get_settings") + def test_default_mode_nulls_patch_for_empty_content(self, mock_settings): + """With minimize_api_calls=False (default), empty content + UNKNOWN nulls the patch.""" + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "github.minimize_api_calls": False, + }.get(key, default) + settings.config.verbosity_level = 0 + mock_settings.return_value = settings + + result = 
handle_patch_deletions( + patch=self.SAMPLE_PATCH, + original_file_content_str="", + new_file_content_str="", + file_name="missing.py", + edit_type=EDIT_TYPE.UNKNOWN, + ) + + assert result is None, "Default mode should null patch when content is empty" + + +# --------------------------------------------------------------------------- +# Phase 5: Skip temporary "Preparing review..." comment +# --------------------------------------------------------------------------- + + +class TestSkipTempComment: + """Phase 5 — temporary comment suppressed when minimize_api_calls is active.""" + + @patch("pr_agent.tools.pr_reviewer.extract_and_cache_pr_tickets", return_value=None) + @patch("pr_agent.tools.pr_reviewer.retry_with_fallback_models") + @patch("pr_agent.tools.pr_reviewer.get_settings") + def test_skip_temp_comment_when_minimizing(self, mock_settings, mock_retry, mock_tickets): + """Temp-comment gate must be False when minimize_api_calls=True (PRReviewer.run() is not invoked).""" + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "github.minimize_api_calls": True, + }.get(key, default) + settings.config.publish_output = True + settings.config.get.return_value = False # is_auto_command + mock_settings.return_value = settings + + from pr_agent.tools.pr_reviewer import PRReviewer + + reviewer = MagicMock(spec=PRReviewer) + reviewer.incremental = MagicMock() + reviewer.incremental.is_incremental = False + reviewer.git_provider = MagicMock() + reviewer.is_auto = False + reviewer.is_answer = False + + # Verify the condition: minimize_api_calls prevents the temp comment + publish_output = settings.config.publish_output + is_auto = settings.config.get("is_auto_command", False) + minimize = settings.get("github.minimize_api_calls", False) + + should_publish_temp = publish_output and not is_auto and not minimize + assert not should_publish_temp, "Temp comment should be suppressed" + + @patch("pr_agent.tools.pr_reviewer.get_settings") + def
test_temp_comment_when_not_minimizing(self, mock_settings): + """Temp-comment gate must be True when minimize_api_calls=False (PRReviewer.run() is not invoked).""" + settings = MagicMock() + settings.get.side_effect = lambda key, default=None: { + "github.minimize_api_calls": False, + }.get(key, default) + settings.config.publish_output = True + settings.config.get.return_value = False # is_auto_command + mock_settings.return_value = settings + + publish_output = settings.config.publish_output + is_auto = settings.config.get("is_auto_command", False) + minimize = settings.get("github.minimize_api_calls", False) + + should_publish_temp = publish_output and not is_auto and not minimize + assert should_publish_temp, "Temp comment should be published when flag is off" + + +# --------------------------------------------------------------------------- +# Config: Default value +# --------------------------------------------------------------------------- + + +class TestConfigDefault: + """The minimize_api_calls flag must default to false in configuration.toml.""" + + def test_config_flag_defaults_false(self): + """Verify the config file contains minimize_api_calls = false.""" + import os + + config_path = os.path.join(os.path.dirname(__file__), "..", "..", "pr_agent", "settings", "configuration.toml") + with open(config_path) as f: + content = f.read() + + assert "minimize_api_calls = false" in content, "configuration.toml must contain 'minimize_api_calls = false'"