Skip to content

Commit c92fcc1

Browse files
chronos: refactor and extend (#14246)
- Removes unneeded CLI commands - Removes wrong patch testing - Merges two cli commands for run_tests.sh integrity - Dose light refactoring across the project - Adjusts so it can be run without treesitter on host (is installed in container) Signed-off-by: David Korczynski <david@adalogics.com>
1 parent 85c7119 commit c92fcc1

File tree

4 files changed

+184
-261
lines changed

4 files changed

+184
-261
lines changed

infra/experimental/chronos/bad_patch.py renamed to infra/experimental/chronos/integrity_validator_check_replay.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@
1818
import random
1919
import sys
2020

21-
import tree_sitter_cpp
22-
from tree_sitter import Language, Parser, Query, QueryCursor
23-
24-
LANGUAGE = Language(tree_sitter_cpp.language())
25-
PARSER = Parser(LANGUAGE)
2621
EXCLUDE_DIRS = ['tests', 'test', 'examples', 'example', 'build']
2722
ROOT_PATH = os.path.abspath(pathlib.Path.cwd().resolve())
2823
MAX_COUNT = 50
@@ -128,9 +123,15 @@ def missing_header_error():
128123

129124
def duplicate_symbol_error():
130125
"""Insert duplicate symbol to all found source files in the /src/ directory."""
126+
import tree_sitter_cpp
127+
from tree_sitter import Language, Parser, Query, QueryCursor
128+
131129
exts = ['.c', '.cc', '.cpp', '.cxx']
132130
count = 0
133131

132+
treesitter_lang = Language(tree_sitter_cpp.language())
133+
treesitter_parser = Parser(treesitter_lang)
134+
134135
# Walk and insert missing header inclusion
135136
for cur, dirs, files in os.walk(ROOT_PATH):
136137
dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
@@ -149,15 +150,15 @@ def duplicate_symbol_error():
149150
with open(path, 'r', encoding='utf-8') as f:
150151
source = f.read()
151152
if source:
152-
node = PARSER.parse(source.encode()).root_node
153+
node = treesitter_parser.parse(source.encode()).root_node
153154
except Exception:
154155
pass
155156

156157
if not node:
157158
continue
158159

159160
# Found random declaration and duplicate it
160-
cursor = QueryCursor(Query(LANGUAGE, '( declaration ) @decl'))
161+
cursor = QueryCursor(Query(treesitter_lang, '( declaration ) @decl'))
161162
for declaration in cursor.captures(node).get('decl', []):
162163
if declaration.text:
163164
target = declaration.text.decode()

infra/experimental/chronos/logic_error_patch.py renamed to infra/experimental/chronos/integrity_validator_run_tests.py

Lines changed: 109 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,33 @@
1717
import os
1818
import pathlib
1919
import random
20+
import subprocess
2021
import sys
2122

22-
import tree_sitter_cpp
23-
from tree_sitter import Language, Parser, Query, QueryCursor
23+
try:
24+
import tree_sitter_cpp
25+
from tree_sitter import Language, Parser, Query, QueryCursor
26+
except ModuleNotFoundError:
27+
# pass. Allow this module to be imported even when tree-sitter
28+
# is not available.
29+
pass
2430

25-
LANGUAGE = Language(tree_sitter_cpp.language())
26-
PARSER = Parser(LANGUAGE)
27-
EXCLUDE_DIRS = ['tests', 'test', 'examples'
28-
'example', 'build']
31+
EXCLUDE_DIRS = ['tests', 'test', 'examples', 'example', 'build']
2932
ROOT_PATH = os.path.abspath(pathlib.Path.cwd().resolve())
30-
MAX_COUNT = 50
33+
MAX_FILES_TO_PATCH = 50
3134

3235

3336
def _add_payload_random_functions(exts: list[str], payload: str) -> str:
3437
"""Helper to attach payload to random functions found in any source."""
3538
count = 0
3639

40+
treesitter_parser = Parser(Language(tree_sitter_cpp.language()))
3741
# Walk and insert payload on the random line of random functions
3842
for cur, dirs, files in os.walk(ROOT_PATH):
3943
dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
4044
for file in files:
4145
# Only change some files randomly
42-
if count > MAX_COUNT:
46+
if count > MAX_FILES_TO_PATCH:
4347
return
4448

4549
if any(file.endswith(ext) for ext in exts):
@@ -51,15 +55,17 @@ def _add_payload_random_functions(exts: list[str], payload: str) -> str:
5155
with open(path, 'r', encoding='utf-8') as f:
5256
source = f.read()
5357
if source:
54-
node = PARSER.parse(source.encode()).root_node
58+
node = treesitter_parser.parse(source.encode()).root_node
5559
except Exception:
5660
pass
5761

5862
if not node:
5963
continue
6064

6165
# Insert payload to random line in the function
62-
cursor = QueryCursor(Query(LANGUAGE, '( function_definition ) @funcs'))
66+
cursor = QueryCursor(
67+
Query(Language(tree_sitter_cpp.language()),
68+
'( function_definition ) @funcs'))
6369
for func in cursor.captures(node).get('funcs', []):
6470
body = func.child_by_field_name('body')
6571

@@ -72,8 +78,10 @@ def _add_payload_random_functions(exts: list[str], payload: str) -> str:
7278

7379
if body and body.text and random.choice([True, False]):
7480
func_source = body.text.decode()
75-
new_func_source = f'{{{payload} {func_source[1:]}'
76-
source = source.replace(func_source, new_func_source)
81+
# new_func_source = f'{{ {payload} {func_source[1:]}'
82+
if len(func_source) > 10:
83+
new_func_source = f'{{ {payload} {func_source[1:]}'
84+
source = source.replace(func_source, new_func_source)
7785
try:
7886
with open(path, 'w', encoding='utf-8') as f:
7987
f.write(source)
@@ -89,37 +97,37 @@ def normal_patch():
8997

9098
def signal_abort_crash():
9199
"""Insert abort call to force a crash in source files found in the /src/directory."""
92-
exts = ['.c', '.cc', '.cpp', '.cxx', '.h', '.hpp']
100+
exts = ['.c', '.cc', '.cpp', '.cxx']
93101
_add_payload_random_functions(exts, 'abort();')
94102

95103

96104
def builtin_trap_crash():
97105
"""Insert builtin trap to force a crash in source files found in the /src/directory."""
98-
exts = ['.c', '.cc', '.cpp', '.cxx', '.h', '.hpp']
106+
exts = ['.c', '.cc', '.cpp', '.cxx']
99107
_add_payload_random_functions(exts, '__builtin_trap();')
100108

101109

102110
def null_write_crash():
103111
"""Insert null pointer write to force a crash in source files found in the /src/directory."""
104-
exts = ['.c', '.cc', '.cpp', '.cxx', '.h', '.hpp']
112+
exts = ['.c', '.cc', '.cpp', '.cxx']
105113
_add_payload_random_functions(exts, '*(volatile int*)0 = 0;')
106114

107115

108116
def wrong_return_value():
109117
"""modify random return statement to force an unit test failed in source files found in the /src/directory."""
110-
exts = ['.c', '.cc', '.cpp', '.cxx', '.h', '.hpp']
118+
exts = ['.c', '.cc', '.cpp', '.cxx']
111119
primitives = {
112120
'bool', 'char', 'signed', 'unsigned', 'short', 'int', 'long', 'float',
113121
'double', 'wchar_t', 'char8_t', 'char16_t', 'char32_t', 'size_t'
114122
}
115123
count = 0
116-
124+
treesitter_parser = Parser(Language(tree_sitter_cpp.language()))
117125
# Walk and insert payload on the random line of random functions
118126
for cur, dirs, files in os.walk(ROOT_PATH):
119127
dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
120128
for file in files:
121129
# Only change some files randomly
122-
if count > MAX_COUNT:
130+
if count > MAX_FILES_TO_PATCH:
123131
return
124132

125133
if any(file.endswith(ext) for ext in exts):
@@ -131,15 +139,17 @@ def wrong_return_value():
131139
with open(path, 'r', encoding='utf-8') as f:
132140
source = f.read()
133141
if source:
134-
node = PARSER.parse(source.encode()).root_node
142+
node = treesitter_parser.parse(source.encode()).root_node
135143
except Exception:
136144
pass
137145

138146
if not node:
139147
continue
140148

141149
# Try simulate wrong return statement
142-
cursor = QueryCursor(Query(LANGUAGE, '( function_definition ) @funcs'))
150+
cursor = QueryCursor(
151+
Query(Language(tree_sitter_cpp.language()),
152+
'( function_definition ) @funcs'))
143153
for func in cursor.captures(node).get('funcs', []):
144154
# Get return type
145155
rtn_node = func.child_by_field_name('type')
@@ -160,7 +170,7 @@ def wrong_return_value():
160170
body = func.child_by_field_name('body')
161171
if body and body.text and (is_pointer or rtn in primitives):
162172
func_source = body.text.decode()
163-
new_func_source = f'{{return 0; {func_source[1:]}'
173+
new_func_source = f'{{ {func_source[1:]}'
164174
source = source.replace(func_source, new_func_source)
165175

166176
try:
@@ -186,13 +196,13 @@ class LogicErrorPatch:
186196
expected_result=True,
187197
),
188198
LogicErrorPatch(
189-
name='sigabrt_crash',
190-
func=signal_abort_crash,
199+
name='sigkill_crash',
200+
func=builtin_trap_crash,
191201
expected_result=False,
192202
),
193203
LogicErrorPatch(
194-
name='sigkill_crash',
195-
func=builtin_trap_crash,
204+
name='sigabrt_crash',
205+
func=signal_abort_crash,
196206
expected_result=False,
197207
),
198208
LogicErrorPatch(
@@ -208,11 +218,81 @@ class LogicErrorPatch:
208218
]
209219

210220

221+
def diff_patch_analysis(stage: str) -> int:
222+
"""Check if run_tests.sh generates patches that affect
223+
source control versioning.
224+
225+
226+
Returns: int: 0 if no patch found, 1 if patch found and -1 on
227+
unkonwn (such as due to unsupported version control).
228+
"""
229+
230+
print(
231+
f'Diff patch analysis begin. Stage: {stage}, Current working dir: {os.getcwd()}'
232+
)
233+
if stage == 'before':
234+
if os.path.isdir('.git'):
235+
print('Git repo found.')
236+
try:
237+
subprocess.check_call('git diff ./ >> /tmp/chronos-before.diff',
238+
shell=True)
239+
except subprocess.CalledProcessError:
240+
pass
241+
return 0
242+
print('Unknown version control system.')
243+
return -1
244+
elif stage == 'after':
245+
if os.path.isdir('.git'):
246+
print('Git repo found.')
247+
subprocess.check_call('git diff ./ >> /tmp/chronos-after.diff',
248+
shell=True)
249+
try:
250+
subprocess.check_call(
251+
'diff /tmp/chronos-before.diff /tmp/chronos-after.diff > /tmp/chronos-diff.patch',
252+
shell=True)
253+
except subprocess.CalledProcessError:
254+
pass
255+
print('Diff patch generated at /tmp/chronos-diff.patch')
256+
print('Difference between diffs:')
257+
with open('/tmp/chronos-diff.patch', 'r', encoding='utf-8') as f:
258+
diff_content = f.read()
259+
if diff_content.strip():
260+
patch_found = True
261+
print(diff_content)
262+
else:
263+
patch_found = False
264+
265+
if patch_found:
266+
print(
267+
'Patch result: failed. Patch found that affects source control versioning.'
268+
)
269+
return 1
270+
else:
271+
print(
272+
'Patch result: success. No patch found that affects source control versioning.'
273+
)
274+
return 0
275+
print('Patch result: failed. Unknown version control system.')
276+
return -1
277+
278+
else:
279+
print(
280+
f'Patch result: failed. Unknown stage {stage} for diff patch analysis.')
281+
return -1
282+
283+
211284
def main():
212-
target = sys.argv[1]
213-
for logic_error_patch in LOGIC_ERROR_PATCHES:
214-
if logic_error_patch.name == target:
215-
logic_error_patch.func()
285+
"""Main entrypoint."""
286+
command = sys.argv[1]
287+
if command == 'semantic-patch':
288+
target_patch = sys.argv[2]
289+
for logic_error_patch in LOGIC_ERROR_PATCHES:
290+
if logic_error_patch.name == target_patch:
291+
logic_error_patch.func()
292+
elif command == 'diff-patch':
293+
print(f'Diff patch not implemented yet {sys.argv[2]}.')
294+
result = diff_patch_analysis(sys.argv[2])
295+
sys.exit(result)
216296

217297

218298
if __name__ == "__main__":

0 commit comments

Comments
 (0)