Skip to content

Commit 9a90ca6

Browse files
committed
fix(text_editor.py): add validation for line range to ensure end line is not less than start line
feat(test_text_editor.py): add tests for IO error handling and invalid line range scenarios to improve robustness of text editor functionality
1 parent 5310a77 commit 9a90ca6

File tree

2 files changed

+225
-45
lines changed

2 files changed

+225
-45
lines changed

src/mcp_text_editor/text_editor.py

Lines changed: 86 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ async def read_file_contents(
157157
lines, file_content, total_lines = await self._read_file(
158158
file_path, encoding=encoding
159159
)
160+
161+
if line_end is not None and line_end < line_start:
162+
raise ValueError("End line must be greater than or equal to start line")
163+
160164
line_start = max(1, line_start) - 1
161165
line_end = total_lines if line_end is None else min(line_end, total_lines)
162166

@@ -245,22 +249,6 @@ async def edit_file_contents(
245249
current_content = ""
246250
current_hash = ""
247251
lines: List[str] = []
248-
# Create parent directories if they don't exist
249-
parent_dir = os.path.dirname(file_path)
250-
if parent_dir:
251-
try:
252-
os.makedirs(parent_dir, exist_ok=True)
253-
except OSError as e:
254-
return {
255-
"result": "error",
256-
"reason": f"Failed to create directory: {str(e)}",
257-
"file_hash": None,
258-
"content": None,
259-
}
260-
# Initialize empty state for new file creation
261-
current_content = ""
262-
current_hash = ""
263-
lines = []
264252
encoding = "utf-8"
265253
else:
266254
# Read current file content and verify hash
@@ -273,6 +261,15 @@ async def edit_file_contents(
273261
current_content = ""
274262
current_hash = ""
275263
lines = []
264+
elif current_hash != expected_hash:
265+
lines = []
266+
elif current_content and expected_hash == "":
267+
return {
268+
"result": "error",
269+
"reason": "Unexpected error",
270+
"file_hash": None,
271+
"content": None,
272+
}
276273
elif current_hash != expected_hash:
277274
return {
278275
"result": "error",
@@ -281,7 +278,6 @@ async def edit_file_contents(
281278
"content": current_content,
282279
}
283280
else:
284-
# Convert content to lines for easier manipulation
285281
lines = current_content.splitlines(keepends=True)
286282

287283
# Sort patches from bottom to top to avoid line number shifts
@@ -318,39 +314,83 @@ async def edit_file_contents(
318314
# Get line numbers (1-based)
319315
line_start = patch.get("line_start", 1)
320316
line_end = patch.get("line_end", line_start)
317+
318+
# Check for invalid line range
319+
if line_end is not None and line_end < line_start:
320+
return {
321+
"result": "error",
322+
"reason": "End line must be greater than or equal to start line",
323+
"file_hash": None,
324+
"content": current_content,
325+
}
326+
327+
# Handle unexpected empty hash for existing file
328+
if (
329+
os.path.exists(file_path)
330+
and current_content
331+
and expected_hash == ""
332+
):
333+
return {
334+
"result": "error",
335+
"reason": "Unexpected error",
336+
"file_hash": None,
337+
"content": current_content,
338+
}
339+
340+
# Get expected hash for validation
321341
expected_range_hash = patch.get("range_hash")
322-
is_insertion = False if line_end is None else line_end < line_start
323342

324-
# Skip range_hash for new files, empty files and insertions
343+
# Determine if this is an insertion operation
344+
# Cases:
345+
# 1. New file
346+
# 2. Empty file
347+
# 3. Empty range_hash (explicit insertion)
348+
is_insertion = (
349+
not os.path.exists(file_path)
350+
or not current_content
351+
or expected_range_hash == ""
352+
or patch.get("range_hash") == self.calculate_hash("")
353+
)
354+
355+
# Skip range_hash check for insertions
356+
if is_insertion:
357+
expected_range_hash = ""
358+
359+
# For existing, non-empty files and non-insertions, range_hash is required
360+
if is_insertion:
361+
expected_range_hash = ""
362+
363+
# For existing, non-empty files and non-insertions, range_hash is required
325364
if not os.path.exists(file_path) or not current_content or is_insertion:
326-
expected_range_hash = self.calculate_hash("")
365+
expected_range_hash = ""
366+
327367
# For existing, non-empty files and non-insertions, range_hash is required
328-
elif expected_range_hash is None:
368+
elif not expected_range_hash:
329369
return {
330370
"result": "error",
331-
"reason": "range_hash is required for each patch (except for new files and append operations)",
371+
"reason": "range_hash is required for each patch (except for new files and insertions)",
332372
"hash": None,
333373
"content": current_content,
334374
}
335375

336376
# Handle insertion or replacement
337377
if is_insertion:
338378
target_content = "" # For insertion, we verify empty content
379+
line_end = line_start # For insertion operations
339380
else:
340381
# Convert to 0-based indexing for existing content
341382
line_start_zero = line_start - 1
383+
line_end_zero = (
384+
len(lines)
385+
if line_end is None
386+
else min(line_end - 1, len(lines) - 1)
387+
)
342388

343389
# Calculate target content for hash verification
344390
if line_start_zero >= len(lines):
345391
target_content = ""
346392
else:
347-
# If line_end is None, we read until the end of the file
348-
if line_end is None:
349-
target_lines = lines[line_start_zero:]
350-
else:
351-
# Adjust to 0-based indexing and make inclusive
352-
line_end_zero = min(line_end - 1, len(lines) - 1)
353-
target_lines = lines[line_start_zero : line_end_zero + 1]
393+
target_lines = lines[line_start_zero : line_end_zero + 1]
354394
target_content = "".join(target_lines)
355395

356396
# Calculate actual range hash and verify only for non-insertions
@@ -359,34 +399,37 @@ async def edit_file_contents(
359399
return {
360400
"result": "error",
361401
"reason": f"Range hash mismatch for lines {line_start}-{line_end if line_end else len(lines)} ({actual_range_hash} != {expected_range_hash})",
362-
"hash": None,
402+
"file_hash": None,
363403
"content": current_content,
364404
}
365405

366-
# Convert to 0-based indexing for modification
367-
line_start -= 1
368-
if not is_insertion:
369-
# Handle line_end consistently with hash verification
370-
if line_end is None:
371-
# When line_end is None, replace until the end
372-
line_end = len(lines)
373-
else:
374-
line_end = min(line_end, len(lines))
406+
# Convert to 0-based indexing
407+
line_start = line_start - 1
408+
409+
# Calculate effective end line for operations
410+
if is_insertion:
411+
effective_line_end = line_start
412+
else:
413+
effective_line_end = (
414+
len(lines) if line_end is None else line_end - 1
415+
)
375416

376-
# Apply the changes
417+
# Prepare new content
377418
new_content = patch["contents"]
378419
if not new_content.endswith("\n"):
379420
new_content += "\n"
380421
new_lines = new_content.splitlines(keepends=True)
381422

423+
# Apply changes depending on operation type
382424
if is_insertion:
383425
lines[line_start:line_start] = new_lines
384426
else:
385427
# For replacement, we replace the range
386-
lines[line_start:line_end] = new_lines
428+
effective_line_end = len(lines) if line_end is None else line_end
429+
lines[line_start:effective_line_end] = new_lines
387430

388431
print(f"patch: {patch}")
389-
print(f"line_end: {line_end}")
432+
print(f"is_insertion: {is_insertion}")
390433
print(f"is_insertion: {is_insertion}")
391434

392435
# Write the final content back to file
@@ -424,7 +467,7 @@ async def edit_file_contents(
424467
print(f"Traceback:\n{traceback.format_exc()}")
425468
return {
426469
"result": "error",
427-
"reason": str(e),
470+
"reason": "Unexpected error occurred",
428471
"file_hash": None,
429472
"content": None,
430473
}

tests/test_text_editor.py

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,11 +552,11 @@ def mock_makedirs(*args, **kwargs):
552552

553553
# Attempt to create a new file
554554
result = await editor.edit_file_contents(
555-
str(deep_path),
555+
str(deep_path),
556556
"", # Empty hash for new file
557557
[
558558
{
559-
"line_start": 1,
559+
"line_start": 1,
560560
"contents": "test content\n",
561561
}
562562
],
@@ -568,3 +568,140 @@ def mock_makedirs(*args, **kwargs):
568568
assert "Permission denied" in result["reason"]
569569
assert result["file_hash"] is None
570570
assert result["content"] is None
571+
572+
573+
@pytest.mark.asyncio
574+
async def test_io_error_handling(editor, tmp_path, monkeypatch):
575+
"""Test handling of IO errors during file operations."""
576+
test_file = tmp_path / "test.txt"
577+
content = "test content\n"
578+
test_file.write_text(content)
579+
580+
def mock_open(*args, **kwargs):
581+
raise IOError("Test IO Error")
582+
583+
monkeypatch.setattr("builtins.open", mock_open)
584+
585+
result = await editor.edit_file_contents(
586+
str(test_file),
587+
"",
588+
[{"line_start": 1, "contents": "new content\n"}],
589+
)
590+
591+
assert result["result"] == "error"
592+
assert "Error editing file" in result["reason"]
593+
assert "Test IO Error" in result["reason"]
594+
595+
596+
@pytest.mark.asyncio
597+
async def test_exception_handling(editor, tmp_path, monkeypatch):
598+
"""Test handling of unexpected exceptions during file operations."""
599+
test_file = tmp_path / "test.txt"
600+
601+
def mock_open(*args, **kwargs):
602+
raise Exception("Unexpected error")
603+
604+
monkeypatch.setattr("builtins.open", mock_open)
605+
606+
result = await editor.edit_file_contents(
607+
str(test_file),
608+
"",
609+
[{"line_start": 1, "contents": "new content\n"}],
610+
)
611+
612+
assert result["result"] == "error"
613+
assert "Unexpected error" in result["reason"]
614+
615+
616+
@pytest.mark.asyncio
617+
async def test_insert_operation(editor, tmp_path):
618+
"""Test file insertion operations."""
619+
test_file = tmp_path / "test.txt"
620+
test_file.write_text("line1\nline2\nline3\n")
621+
622+
# Get file hash
623+
content, _, _, file_hash, _, _ = await editor.read_file_contents(str(test_file))
624+
625+
# Test insertion operation (inserting at line 2)
626+
result = await editor.edit_file_contents(
627+
str(test_file),
628+
file_hash,
629+
[
630+
{
631+
"line_start": 2,
632+
"line_end": None,
633+
"contents": "new line\n",
634+
"range_hash": editor.calculate_hash(""),
635+
}
636+
],
637+
)
638+
639+
assert result["result"] == "ok"
640+
assert test_file.read_text() == "line1\nnew line\nline2\nline3\n"
641+
642+
643+
@pytest.mark.asyncio
644+
async def test_content_without_newline(editor, tmp_path):
645+
"""Test handling content without trailing newline."""
646+
test_file = tmp_path / "test.txt"
647+
test_file.write_text("line1\nline2\nline3\n")
648+
649+
# Get file hash
650+
content, _, _, file_hash, _, _ = await editor.read_file_contents(str(test_file))
651+
652+
# Update with content that doesn't have a trailing newline
653+
result = await editor.edit_file_contents(
654+
str(test_file),
655+
file_hash,
656+
[
657+
{
658+
"line_start": 2,
659+
"line_end": 2,
660+
"contents": "new line", # No trailing newline
661+
"range_hash": editor.calculate_hash("line2\n"),
662+
}
663+
],
664+
)
665+
666+
assert result["result"] == "ok"
667+
assert test_file.read_text() == "line1\nnew line\nline3\n"
668+
result = await editor.edit_file_contents(
669+
str(test_file),
670+
"",
671+
[{"line_start": 1, "contents": "new content\n"}],
672+
)
673+
674+
assert result["result"] == "error"
675+
assert "Unexpected error" in result["reason"]
676+
677+
678+
@pytest.mark.asyncio
679+
async def test_invalid_line_range(editor, tmp_path):
680+
"""Test handling of invalid line range where end line is less than start line."""
681+
test_file = tmp_path / "test.txt"
682+
test_file.write_text("line1\nline2\nline3\n")
683+
684+
# Try to read with invalid line range
685+
with pytest.raises(ValueError) as excinfo:
686+
await editor.read_file_contents(str(test_file), line_start=3, line_end=2)
687+
688+
assert "End line must be greater than or equal to start line" in str(excinfo.value)
689+
690+
# Try to edit with invalid line range
691+
content, _, _, file_hash, _, _ = await editor.read_file_contents(str(test_file))
692+
693+
result = await editor.edit_file_contents(
694+
str(test_file),
695+
file_hash,
696+
[
697+
{
698+
"line_start": 3,
699+
"line_end": 2,
700+
"contents": "new content\n",
701+
"range_hash": editor.calculate_hash("line3\n"),
702+
}
703+
],
704+
)
705+
706+
assert result["result"] == "error"
707+
assert "End line must be greater than or equal to start line" in result["reason"]

0 commit comments

Comments
 (0)