Skip to content

Commit f8eb995

Browse files
authored
Merge pull request #1 from setkyar/feature/async
Add pytest-asyncio dependency and refactor review process to support …
2 parents 8b4b4f9 + 0311956 commit f8eb995

File tree

7 files changed

+281
-85
lines changed

7 files changed

+281
-85
lines changed

aireview/ai_reviewer.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Module for handling AI review generation."""
22
import click
3+
import asyncio
34
from dataclasses import dataclass
4-
from openai import OpenAI
5+
from openai import AsyncOpenAI
56
from typing import List, Optional
67
from .git_handler import FileChange
78

@@ -13,44 +14,74 @@ class Review:
1314

1415
class AIReviewer:
1516
def __init__(self, model: str, api_key: str, base_url: Optional[str] = None):
16-
self.client = OpenAI(
17+
self.client = AsyncOpenAI(
1718
api_key=api_key,
1819
base_url=base_url if base_url else None
1920
)
2021
self.model = model
2122

22-
def review_changes(self, changes: List[FileChange],
23-
project_context: str, prompt_template: str) -> List[Review]:
24-
"""Generate AI reviews for all file changes."""
25-
reviews = []
23+
async def review_changes(self, changes: List[FileChange],
24+
project_context: str, prompt_template: str) -> List[Review]:
25+
"""Generate AI reviews for all file changes in parallel."""
26+
click.echo(f"Generating reviews for {len(changes)} files...")
27+
28+
# Create tasks for all reviews
29+
tasks = []
2630
for change in changes:
27-
click.echo(f"Generating review for {change.filename}...")
31+
click.echo(f"Starting review for {change.filename}...")
2832

29-
review_content = self._get_review(
30-
change.content,
31-
change.filename,
32-
project_context,
33-
prompt_template
33+
# Create prompt with filename included
34+
prompt = self._create_prompt(
35+
changes=change.content,
36+
filename=change.filename,
37+
file_content=change.file_content,
38+
project_context=project_context,
39+
prompt_template=prompt_template
3440
)
35-
reviews.append(Review(filename=change.filename, content=review_content))
41+
42+
# Create task for this review
43+
tasks.append(self._get_review(prompt, change.filename))
44+
45+
# Run all reviews concurrently
46+
review_contents = await asyncio.gather(*tasks)
47+
48+
# Create Review objects from results
49+
reviews = [
50+
Review(filename=changes[i].filename, content=content)
51+
for i, content in enumerate(review_contents)
52+
]
53+
3654
return reviews
3755

38-
def _get_review(self, changes: str, filename: str,
39-
project_context: str, prompt_template: str) -> str:
40-
"""Get AI review for a single file's changes."""
41-
prompt = f"""{project_context}
56+
def _create_prompt(self, changes: str, filename: str,
57+
file_content: Optional[str], project_context: str,
58+
prompt_template: str) -> str:
59+
"""Create the prompt for a single file review."""
60+
# Include file content in the prompt if available
61+
file_content_section = ""
62+
if file_content:
63+
file_content_section = f"""
64+
Current file content:
65+
```
66+
{file_content}
67+
```
68+
"""
69+
70+
return f"""{project_context}
4271
4372
{prompt_template}
4473
4574
Review the following changes in {filename}:
4675
```
4776
{changes}
4877
```
49-
78+
{file_content_section}
5079
Please focus your review on these specific changes."""
51-
80+
81+
async def _get_review(self, prompt: str, filename: str) -> str:
82+
"""Get AI review for the provided prompt."""
5283
try:
53-
completion = self.client.chat.completions.create(
84+
completion = await self.client.chat.completions.create(
5485
model=self.model,
5586
n=1,
5687
messages=[
@@ -59,6 +90,7 @@ def _get_review(self, changes: str, filename: str,
5990
]
6091
)
6192

93+
click.echo(f"Completed review for {filename}")
6294
return f"## Review for changes in {filename}\n\n{completion.choices[0].message.content}"
6395
except Exception as e:
6496
raise RuntimeError(f"OpenAI API error for {filename}: {str(e)}")

aireview/git_handler.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
"""Module for handling Git operations."""
22
import subprocess
33
from dataclasses import dataclass
4-
from typing import List, Optional
4+
from typing import List, Optional, Dict
5+
import os
56

67
@dataclass
78
class FileChange:
89
"""Represents changes in a single file."""
910
filename: str
1011
content: str
12+
file_content: Optional[str] = None
1113

1214
class GitHandler:
1315
@staticmethod
1416
def get_file_changes() -> List[FileChange]:
15-
"""Retrieves only staged changes from Git (after git add)."""
17+
"""Retrieves staged changes from Git and their corresponding file content efficiently."""
1618
try:
1719
# Get staged changes
1820
staged_cmd = subprocess.run(
@@ -22,12 +24,115 @@ def get_file_changes() -> List[FileChange]:
2224

2325
if not staged_cmd.stdout:
2426
return []
25-
26-
return GitHandler._parse_diff_output(staged_cmd.stdout)
27+
28+
# Parse the diff output first
29+
changes = GitHandler._parse_diff_output(staged_cmd.stdout)
30+
31+
# Get the list of files we need content for
32+
files_to_fetch = [change.filename for change in changes]
33+
34+
# Batch fetch file contents
35+
file_contents = GitHandler._batch_get_file_contents(files_to_fetch)
36+
37+
# Update FileChange objects with their content
38+
for change in changes:
39+
change.file_content = file_contents.get(change.filename)
40+
41+
return changes
2742

2843
except subprocess.CalledProcessError as e:
2944
raise RuntimeError(f"Git command failed: {e.stderr}")
3045

46+
@staticmethod
47+
def _batch_get_file_contents(filenames: List[str]) -> Dict[str, Optional[str]]:
48+
"""
49+
Efficiently get contents of multiple files using git cat-file --batch.
50+
Returns a dictionary mapping filenames to their content.
51+
"""
52+
if not filenames:
53+
return {}
54+
55+
try:
56+
# Get object IDs for staged versions of files
57+
file_revs = {}
58+
for filename in filenames:
59+
try:
60+
rev_cmd = subprocess.run(
61+
['git', 'rev-parse', f':"{filename}"'],
62+
capture_output=True, text=True, check=True
63+
)
64+
file_revs[filename] = rev_cmd.stdout.strip()
65+
except subprocess.CalledProcessError:
66+
# File might be new/deleted
67+
file_revs[filename] = None
68+
69+
# Prepare batch input
70+
valid_revs = {f: rev for f, rev in file_revs.items() if rev is not None}
71+
if not valid_revs:
72+
return {f: None for f in filenames}
73+
74+
# Start git cat-file --batch process
75+
process = subprocess.Popen(
76+
['git', 'cat-file', '--batch'],
77+
stdin=subprocess.PIPE,
78+
stdout=subprocess.PIPE,
79+
stderr=subprocess.PIPE
80+
)
81+
82+
# Write object IDs to git cat-file
83+
input_data = '\n'.join(valid_revs.values()) + '\n'
84+
stdout, stderr = process.communicate(input_data.encode())
85+
86+
if process.returncode != 0:
87+
raise subprocess.CalledProcessError(
88+
process.returncode, 'git cat-file', stderr
89+
)
90+
91+
# Parse the output
92+
contents = {}
93+
current_content = []
94+
current_file = None
95+
rev_to_file = {rev: f for f, rev in valid_revs.items()}
96+
97+
for line in stdout.decode().split('\n'):
98+
if line.strip() and ' blob ' in line:
99+
# New blob header - save previous content if any
100+
if current_file and current_content:
101+
contents[current_file] = ''.join(current_content)
102+
current_content = []
103+
104+
# Get filename for this blob
105+
obj_id = line.split()[0]
106+
current_file = rev_to_file.get(obj_id)
107+
else:
108+
current_content.append(line + '\n')
109+
110+
# Save last file's content
111+
if current_file and current_content:
112+
contents[current_file] = ''.join(current_content)
113+
114+
# Include None for files that weren't found
115+
return {f: contents.get(f) for f in filenames}
116+
117+
except Exception as e:
118+
# If batch operation fails, fall back to individual git show commands
119+
return GitHandler._fallback_get_file_contents(filenames)
120+
121+
@staticmethod
122+
def _fallback_get_file_contents(filenames: List[str]) -> Dict[str, Optional[str]]:
123+
"""Fallback method to get file contents using git show."""
124+
contents = {}
125+
for filename in filenames:
126+
try:
127+
show_cmd = subprocess.run(
128+
['git', 'show', f':{filename}'],
129+
capture_output=True, text=True, check=True
130+
)
131+
contents[filename] = show_cmd.stdout
132+
except subprocess.CalledProcessError:
133+
contents[filename] = None
134+
return contents
135+
31136
@staticmethod
32137
def _parse_diff_output(diff_output: str) -> List[FileChange]:
33138
"""Parse git diff output into FileChange objects."""

aireview/main.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Main module for the AI code review tool."""
22
import click
33
import logging
4+
import asyncio
45
from typing import List
56
from .config import ConfigLoader
67
from .git_handler import GitHandler
@@ -47,11 +48,12 @@ def main(config: str):
4748
base_url=ai_config.base_url
4849
)
4950

50-
reviews = reviewer.review_changes(
51+
# Run the async review process
52+
reviews = asyncio.run(reviewer.review_changes(
5153
file_changes,
5254
review_config.project_context,
5355
review_config.prompt_template
54-
)
56+
))
5557

5658
# Write output
5759
write_reviews(reviews, review_config.output_file)
@@ -61,4 +63,7 @@ def main(config: str):
6163
except Exception as e:
6264
click.echo(f"Error: {str(e)}", err=True)
6365
logging.error(f"Error: {str(e)}")
64-
return
66+
return
67+
68+
if __name__ == '__main__':
69+
main()

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def get_version_from_git():
8484
'pytest>=7.0.0,<8.0.0',
8585
'pytest-cov>=4.1.0,<5.0.0',
8686
'pytest-mock>=3.10.0',
87+
'pytest-asyncio>=0.23.0',
8788
],
8889
},
8990
)

tests/test_ai_reviewer.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,87 @@
11
import pytest
2-
from unittest.mock import Mock, patch
2+
from unittest.mock import Mock, patch, AsyncMock
33
from aireview.ai_reviewer import AIReviewer
44
from aireview.git_handler import FileChange
55

66
@pytest.fixture
77
def mock_openai():
8-
"""Mock OpenAI client responses."""
9-
with patch('aireview.ai_reviewer.OpenAI') as mock:
8+
"""Mock AsyncOpenAI client responses."""
9+
with patch('aireview.ai_reviewer.AsyncOpenAI') as mock:
1010
mock_client = Mock()
11-
mock_client.chat.completions.create.return_value = Mock(
12-
choices=[Mock(message=Mock(content="Test review content"))]
11+
# Use AsyncMock for async methods
12+
mock_client.chat = Mock()
13+
mock_client.chat.completions = Mock()
14+
mock_client.chat.completions.create = AsyncMock(
15+
return_value=Mock(
16+
choices=[Mock(message=Mock(content="Mock review content"))]
17+
)
1318
)
1419
mock.return_value = mock_client
1520
yield mock
1621

17-
def test_review_changes(mock_openai):
18-
"""Test AI review generation."""
22+
@pytest.mark.asyncio
23+
async def test_review_changes_with_file_content(mock_openai):
24+
"""Test AI review generation with file content."""
1925
reviewer = AIReviewer("test-model", "test-key")
2026
changes = [
2127
FileChange(
2228
filename="test.py",
23-
content="Added: print('hello world')"
29+
content="Added: print('hello world')",
30+
file_content="print('hello')\nprint('hello world')"
2431
)
2532
]
2633

27-
reviews = reviewer.review_changes(
34+
reviews = await reviewer.review_changes(
2835
changes,
2936
project_context="Test context",
3037
prompt_template="Test template"
3138
)
3239

40+
# Verify the review was created
3341
assert len(reviews) == 1
3442
assert reviews[0].filename == "test.py"
35-
assert "Test review content" in reviews[0].content
43+
44+
# Verify the prompt sent to the API includes file content
45+
api_call_args = mock_openai.return_value.chat.completions.create.call_args
46+
prompt_sent = api_call_args[1]['messages'][1]['content']
47+
assert "Current file content:" in prompt_sent
48+
assert "print('hello')" in prompt_sent
49+
50+
@pytest.mark.asyncio
51+
async def test_review_changes_without_file_content(mock_openai):
52+
"""Test AI review generation without file content."""
53+
reviewer = AIReviewer("test-model", "test-key")
54+
changes = [
55+
FileChange(
56+
filename="test.py",
57+
content="Added: print('hello world')",
58+
file_content=None # No file content
59+
)
60+
]
61+
62+
reviews = await reviewer.review_changes(
63+
changes,
64+
project_context="Test context",
65+
prompt_template="Test template"
66+
)
67+
68+
# Verify the review was created
69+
assert len(reviews) == 1
70+
assert reviews[0].filename == "test.py"
71+
72+
# Verify the prompt sent to the API doesn't include file content section
73+
api_call_args = mock_openai.return_value.chat.completions.create.call_args
74+
prompt_sent = api_call_args[1]['messages'][1]['content']
75+
assert "Current file content:" not in prompt_sent
76+
assert changes[0].content in prompt_sent
3677

37-
def test_review_changes_api_error(mock_openai):
78+
@pytest.mark.asyncio
79+
async def test_review_changes_api_error(mock_openai):
3880
"""Test handling of API errors during review."""
3981
mock_openai.return_value.chat.completions.create.side_effect = Exception("API Error")
4082

4183
reviewer = AIReviewer("test-model", "test-key")
4284
changes = [FileChange(filename="test.py", content="test content")]
4385

4486
with pytest.raises(RuntimeError, match="OpenAI API error"):
45-
reviewer.review_changes(changes, "", "")
87+
await reviewer.review_changes(changes, "", "")

0 commit comments

Comments
 (0)