Skip to content

Commit 84c6c82

Browse files
committed
Refactor pr_helper tests: simplify parser, mock UUID for exact output checks
1 parent 7aa6fe8 commit 84c6c82

File tree

1 file changed

+40
-38
lines changed

1 file changed

+40
-38
lines changed

infra/pr_helper_test.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
"""PR helper env variable injection and URL sanitization tests."""
1818

1919
import os
20-
import re
2120
import tempfile
2221
import unittest
2322
from unittest import mock
@@ -28,41 +27,41 @@
2827

2928

3029
def _parse_github_env(content):
31-
"""Parses a GITHUB_ENV file and returns a dict of env var names to values.
32-
33-
Supports both KEY=value and KEY<<DELIMITER formats.
34-
"""
30+
"""Parses GITHUB_ENV content into a dict of env var names to values."""
3531
env_vars = {}
3632
lines = content.split('\n')
3733
i = 0
3834
while i < len(lines):
3935
line = lines[i]
40-
if not line:
41-
i += 1
42-
continue
43-
44-
# Check for delimiter format: NAME<<DELIMITER
45-
delim_match = re.match(r'^([A-Z_]+)<<(.+)$', line)
46-
if delim_match:
47-
name = delim_match.group(1)
48-
delimiter = delim_match.group(2)
49-
value_lines = []
36+
if '<<' in line:
37+
name, delim = line.split('<<', 1)
38+
vals = []
5039
i += 1
51-
while i < len(lines) and lines[i] != delimiter:
52-
value_lines.append(lines[i])
40+
while i < len(lines) and lines[i] != delim:
41+
vals.append(lines[i])
5342
i += 1
54-
env_vars[name] = '\n'.join(value_lines)
55-
i += 1 # skip the closing delimiter
56-
continue
43+
env_vars[name] = '\n'.join(vals)
44+
elif '=' in line:
45+
name, val = line.split('=', 1)
46+
env_vars[name] = val
47+
i += 1
48+
return env_vars
5749

58-
# Check for simple KEY=value format
59-
eq_match = re.match(r'^([A-Z_]+)=(.*)$', line)
60-
if eq_match:
61-
env_vars[eq_match.group(1)] = eq_match.group(2)
6250

63-
i += 1
51+
class ParseGithubEnvTest(unittest.TestCase):
52+
"""Verify the test helper parses both GITHUB_ENV formats correctly."""
6453

65-
return env_vars
54+
def test_key_value_format(self):
55+
"""KEY=value lines are parsed correctly."""
56+
content = 'FOO=bar\nBAZ=qux\n'
57+
self.assertEqual(_parse_github_env(content), {'FOO': 'bar', 'BAZ': 'qux'})
58+
59+
def test_delimiter_format(self):
60+
"""KEY<<DELIM blocks are parsed correctly, including multiline values."""
61+
content = 'MSG<<EOF\nhello\nworld\nEOF\nOTHER<<END\nval\nEND\n'
62+
env_vars = _parse_github_env(content)
63+
self.assertEqual(env_vars['MSG'], 'hello\nworld')
64+
self.assertEqual(env_vars['OTHER'], 'val')
6665

6766

6867
class SaveEnvTest(unittest.TestCase):
@@ -85,22 +84,22 @@ def _read_env_file(self):
8584
with open(self.env_file.name, 'r', encoding='utf-8') as env_file:
8685
return env_file.read()
8786

88-
def test_save_env_basic(self):
89-
"""Normal values produce correct key=value output."""
87+
@mock.patch('pr_helper.uuid.uuid4')
88+
def test_save_env_basic(self, mock_uuid):
89+
"""Normal values produce correct delimiter-based output."""
90+
mock_uuid.return_value.hex = 'deadbeef'
9091
pr_helper.save_env('hello world', True, False)
91-
env_vars = _parse_github_env(self._read_env_file())
92-
self.assertEqual(env_vars['MESSAGE'], 'hello world')
93-
self.assertEqual(env_vars['IS_READY_FOR_MERGE'], 'True')
94-
self.assertEqual(env_vars['IS_INTERNAL'], 'False')
92+
expected = ('MESSAGE<<deadbeef\nhello world\ndeadbeef\n'
93+
'IS_READY_FOR_MERGE<<deadbeef\nTrue\ndeadbeef\n'
94+
'IS_INTERNAL<<deadbeef\nFalse\ndeadbeef\n')
95+
self.assertEqual(self._read_env_file(), expected)
9596

9697
def test_save_env_newline_injection_blocked(self):
9798
"""Newlines in message must not inject extra env vars."""
9899
malicious = 'hello\nGITHUB_API_URL=https://evil.com'
99100
pr_helper.save_env(malicious, True, False)
100101
env_vars = _parse_github_env(self._read_env_file())
101-
# The injected env var must NOT appear as a separate variable.
102102
self.assertNotIn('GITHUB_API_URL', env_vars)
103-
# There must be exactly 3 env vars.
104103
self.assertEqual(len(env_vars), 3)
105104

106105
def test_save_env_carriage_return_injection_blocked(self):
@@ -117,12 +116,15 @@ def test_save_env_injection_via_all_fields(self):
117116
self.assertNotIn('EVIL', env_vars)
118117
self.assertEqual(len(env_vars), 3)
119118

120-
def test_save_env_none_values(self):
119+
@mock.patch('pr_helper.uuid.uuid4')
120+
def test_save_env_none_values(self, mock_uuid):
121121
"""None values (internal member path) are written safely."""
122+
mock_uuid.return_value.hex = 'deadbeef'
122123
pr_helper.save_env(None, None, True)
123-
env_vars = _parse_github_env(self._read_env_file())
124-
self.assertEqual(env_vars['MESSAGE'], 'None')
125-
self.assertEqual(env_vars['IS_INTERNAL'], 'True')
124+
expected = ('MESSAGE<<deadbeef\nNone\ndeadbeef\n'
125+
'IS_READY_FOR_MERGE<<deadbeef\nNone\ndeadbeef\n'
126+
'IS_INTERNAL<<deadbeef\nTrue\ndeadbeef\n')
127+
self.assertEqual(self._read_env_file(), expected)
126128

127129
def test_save_env_full_attack_scenario(self):
128130
"""Reproduces the reported attack: malicious main_repo

0 commit comments

Comments
 (0)