1717import os
1818import pathlib
1919import random
20+ import subprocess
2021import sys
2122
22- import tree_sitter_cpp
23- from tree_sitter import Language , Parser , Query , QueryCursor
23+ try :
24+ import tree_sitter_cpp
25+ from tree_sitter import Language , Parser , Query , QueryCursor
26+ except ModuleNotFoundError :
27+ # pass. Allow this module to be imported even when tree-sitter
28+ # is not available.
29+ pass
2430
25- LANGUAGE = Language (tree_sitter_cpp .language ())
26- PARSER = Parser (LANGUAGE )
27- EXCLUDE_DIRS = ['tests' , 'test' , 'examples'
28- 'example' , 'build' ]
31+ EXCLUDE_DIRS = ['tests' , 'test' , 'examples' , 'example' , 'build' ]
2932ROOT_PATH = os .path .abspath (pathlib .Path .cwd ().resolve ())
30- MAX_COUNT = 50
33+ MAX_FILES_TO_PATCH = 50
3134
3235
3336def _add_payload_random_functions (exts : list [str ], payload : str ) -> str :
3437 """Helper to attach payload to random functions found in any source."""
3538 count = 0
3639
40+ treesitter_parser = Parser (Language (tree_sitter_cpp .language ()))
3741 # Walk and insert payload on the random line of random functions
3842 for cur , dirs , files in os .walk (ROOT_PATH ):
3943 dirs [:] = [d for d in dirs if d not in EXCLUDE_DIRS ]
4044 for file in files :
4145 # Only change some files randomly
42- if count > MAX_COUNT :
46+ if count > MAX_FILES_TO_PATCH :
4347 return
4448
4549 if any (file .endswith (ext ) for ext in exts ):
@@ -51,15 +55,17 @@ def _add_payload_random_functions(exts: list[str], payload: str) -> str:
5155 with open (path , 'r' , encoding = 'utf-8' ) as f :
5256 source = f .read ()
5357 if source :
54- node = PARSER .parse (source .encode ()).root_node
58+ node = treesitter_parser .parse (source .encode ()).root_node
5559 except Exception :
5660 pass
5761
5862 if not node :
5963 continue
6064
6165 # Insert payload to random line in the function
62- cursor = QueryCursor (Query (LANGUAGE , '( function_definition ) @funcs' ))
66+ cursor = QueryCursor (
67+ Query (Language (tree_sitter_cpp .language ()),
68+ '( function_definition ) @funcs' ))
6369 for func in cursor .captures (node ).get ('funcs' , []):
6470 body = func .child_by_field_name ('body' )
6571
@@ -72,8 +78,10 @@ def _add_payload_random_functions(exts: list[str], payload: str) -> str:
7278
7379 if body and body .text and random .choice ([True , False ]):
7480 func_source = body .text .decode ()
75- new_func_source = f'{{{ payload } { func_source [1 :]} '
76- source = source .replace (func_source , new_func_source )
81+ # new_func_source = f'{{ {payload} {func_source[1:]}'
82+ if len (func_source ) > 10 :
83+ new_func_source = f'{{ { payload } { func_source [1 :]} '
84+ source = source .replace (func_source , new_func_source )
7785 try :
7886 with open (path , 'w' , encoding = 'utf-8' ) as f :
7987 f .write (source )
@@ -89,37 +97,37 @@ def normal_patch():
8997
9098def signal_abort_crash ():
9199 """Insert abort call to force a crash in source files found in the /src/directory."""
92- exts = ['.c' , '.cc' , '.cpp' , '.cxx' , '.h' , '.hpp' ]
100+ exts = ['.c' , '.cc' , '.cpp' , '.cxx' ]
93101 _add_payload_random_functions (exts , 'abort();' )
94102
95103
96104def builtin_trap_crash ():
97105 """Insert builtin trap to force a crash in source files found in the /src/directory."""
98- exts = ['.c' , '.cc' , '.cpp' , '.cxx' , '.h' , '.hpp' ]
106+ exts = ['.c' , '.cc' , '.cpp' , '.cxx' ]
99107 _add_payload_random_functions (exts , '__builtin_trap();' )
100108
101109
102110def null_write_crash ():
103111 """Insert null pointer write to force a crash in source files found in the /src/directory."""
104- exts = ['.c' , '.cc' , '.cpp' , '.cxx' , '.h' , '.hpp' ]
112+ exts = ['.c' , '.cc' , '.cpp' , '.cxx' ]
105113 _add_payload_random_functions (exts , '*(volatile int*)0 = 0;' )
106114
107115
108116def wrong_return_value ():
109117 """modify random return statement to force an unit test failed in source files found in the /src/directory."""
110- exts = ['.c' , '.cc' , '.cpp' , '.cxx' , '.h' , '.hpp' ]
118+ exts = ['.c' , '.cc' , '.cpp' , '.cxx' ]
111119 primitives = {
112120 'bool' , 'char' , 'signed' , 'unsigned' , 'short' , 'int' , 'long' , 'float' ,
113121 'double' , 'wchar_t' , 'char8_t' , 'char16_t' , 'char32_t' , 'size_t'
114122 }
115123 count = 0
116-
124+ treesitter_parser = Parser ( Language ( tree_sitter_cpp . language ()))
117125 # Walk and insert payload on the random line of random functions
118126 for cur , dirs , files in os .walk (ROOT_PATH ):
119127 dirs [:] = [d for d in dirs if d not in EXCLUDE_DIRS ]
120128 for file in files :
121129 # Only change some files randomly
122- if count > MAX_COUNT :
130+ if count > MAX_FILES_TO_PATCH :
123131 return
124132
125133 if any (file .endswith (ext ) for ext in exts ):
@@ -131,15 +139,17 @@ def wrong_return_value():
131139 with open (path , 'r' , encoding = 'utf-8' ) as f :
132140 source = f .read ()
133141 if source :
134- node = PARSER .parse (source .encode ()).root_node
142+ node = treesitter_parser .parse (source .encode ()).root_node
135143 except Exception :
136144 pass
137145
138146 if not node :
139147 continue
140148
141149 # Try simulate wrong return statement
142- cursor = QueryCursor (Query (LANGUAGE , '( function_definition ) @funcs' ))
150+ cursor = QueryCursor (
151+ Query (Language (tree_sitter_cpp .language ()),
152+ '( function_definition ) @funcs' ))
143153 for func in cursor .captures (node ).get ('funcs' , []):
144154 # Get return type
145155 rtn_node = func .child_by_field_name ('type' )
@@ -160,7 +170,7 @@ def wrong_return_value():
160170 body = func .child_by_field_name ('body' )
161171 if body and body .text and (is_pointer or rtn in primitives ):
162172 func_source = body .text .decode ()
163- new_func_source = f'{{return 0; { func_source [1 :]} '
173+ new_func_source = f'{{ { func_source [1 :]} '
164174 source = source .replace (func_source , new_func_source )
165175
166176 try :
@@ -186,13 +196,13 @@ class LogicErrorPatch:
186196 expected_result = True ,
187197 ),
188198 LogicErrorPatch (
189- name = 'sigabrt_crash ' ,
190- func = signal_abort_crash ,
199+ name = 'sigkill_crash ' ,
200+ func = builtin_trap_crash ,
191201 expected_result = False ,
192202 ),
193203 LogicErrorPatch (
194- name = 'sigkill_crash ' ,
195- func = builtin_trap_crash ,
204+ name = 'sigabrt_crash ' ,
205+ func = signal_abort_crash ,
196206 expected_result = False ,
197207 ),
198208 LogicErrorPatch (
@@ -208,11 +218,81 @@ class LogicErrorPatch:
208218]
209219
210220
221+ def diff_patch_analysis (stage : str ) -> int :
222+ """Check if run_tests.sh generates patches that affect
223+ source control versioning.
224+
225+
226+ Returns: int: 0 if no patch found, 1 if patch found and -1 on
227+ unkonwn (such as due to unsupported version control).
228+ """
229+
230+ print (
231+ f'Diff patch analysis begin. Stage: { stage } , Current working dir: { os .getcwd ()} '
232+ )
233+ if stage == 'before' :
234+ if os .path .isdir ('.git' ):
235+ print ('Git repo found.' )
236+ try :
237+ subprocess .check_call ('git diff ./ >> /tmp/chronos-before.diff' ,
238+ shell = True )
239+ except subprocess .CalledProcessError :
240+ pass
241+ return 0
242+ print ('Unknown version control system.' )
243+ return - 1
244+ elif stage == 'after' :
245+ if os .path .isdir ('.git' ):
246+ print ('Git repo found.' )
247+ subprocess .check_call ('git diff ./ >> /tmp/chronos-after.diff' ,
248+ shell = True )
249+ try :
250+ subprocess .check_call (
251+ 'diff /tmp/chronos-before.diff /tmp/chronos-after.diff > /tmp/chronos-diff.patch' ,
252+ shell = True )
253+ except subprocess .CalledProcessError :
254+ pass
255+ print ('Diff patch generated at /tmp/chronos-diff.patch' )
256+ print ('Difference between diffs:' )
257+ with open ('/tmp/chronos-diff.patch' , 'r' , encoding = 'utf-8' ) as f :
258+ diff_content = f .read ()
259+ if diff_content .strip ():
260+ patch_found = True
261+ print (diff_content )
262+ else :
263+ patch_found = False
264+
265+ if patch_found :
266+ print (
267+ 'Patch result: failed. Patch found that affects source control versioning.'
268+ )
269+ return 1
270+ else :
271+ print (
272+ 'Patch result: success. No patch found that affects source control versioning.'
273+ )
274+ return 0
275+ print ('Patch result: failed. Unknown version control system.' )
276+ return - 1
277+
278+ else :
279+ print (
280+ f'Patch result: failed. Unknown stage { stage } for diff patch analysis.' )
281+ return - 1
282+
283+
211284def main ():
212- target = sys .argv [1 ]
213- for logic_error_patch in LOGIC_ERROR_PATCHES :
214- if logic_error_patch .name == target :
215- logic_error_patch .func ()
285+ """Main entrypoint."""
286+ command = sys .argv [1 ]
287+ if command == 'semantic-patch' :
288+ target_patch = sys .argv [2 ]
289+ for logic_error_patch in LOGIC_ERROR_PATCHES :
290+ if logic_error_patch .name == target_patch :
291+ logic_error_patch .func ()
292+ elif command == 'diff-patch' :
293+ print (f'Diff patch not implemented yet { sys .argv [2 ]} .' )
294+ result = diff_patch_analysis (sys .argv [2 ])
295+ sys .exit (result )
216296
217297
218298if __name__ == "__main__" :
0 commit comments