Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@
import random
import sys

import tree_sitter_cpp
from tree_sitter import Language, Parser, Query, QueryCursor

LANGUAGE = Language(tree_sitter_cpp.language())
PARSER = Parser(LANGUAGE)
EXCLUDE_DIRS = ['tests', 'test', 'examples', 'example', 'build']
ROOT_PATH = os.path.abspath(pathlib.Path.cwd().resolve())
MAX_COUNT = 50
Expand Down Expand Up @@ -128,9 +123,15 @@ def missing_header_error():

def duplicate_symbol_error():
"""Insert duplicate symbol to all found source files in the /src/ directory."""
import tree_sitter_cpp
from tree_sitter import Language, Parser, Query, QueryCursor

exts = ['.c', '.cc', '.cpp', '.cxx']
count = 0

treesitter_lang = Language(tree_sitter_cpp.language())
treesitter_parser = Parser(treesitter_lang)

# Walk the source tree and insert a duplicated declaration into matching files
for cur, dirs, files in os.walk(ROOT_PATH):
dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
Expand All @@ -149,15 +150,15 @@ def duplicate_symbol_error():
with open(path, 'r', encoding='utf-8') as f:
source = f.read()
if source:
node = PARSER.parse(source.encode()).root_node
node = treesitter_parser.parse(source.encode()).root_node
except Exception:
pass

if not node:
continue

# Find a random declaration and duplicate it
cursor = QueryCursor(Query(LANGUAGE, '( declaration ) @decl'))
cursor = QueryCursor(Query(treesitter_lang, '( declaration ) @decl'))
for declaration in cursor.captures(node).get('decl', []):
if declaration.text:
target = declaration.text.decode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,33 @@
import os
import pathlib
import random
import subprocess
import sys

import tree_sitter_cpp
from tree_sitter import Language, Parser, Query, QueryCursor
try:
import tree_sitter_cpp
from tree_sitter import Language, Parser, Query, QueryCursor
except ModuleNotFoundError:
# Allow this module to be imported even when tree-sitter
# is not available.
pass

LANGUAGE = Language(tree_sitter_cpp.language())
PARSER = Parser(LANGUAGE)
EXCLUDE_DIRS = ['tests', 'test', 'examples'
'example', 'build']
EXCLUDE_DIRS = ['tests', 'test', 'examples', 'example', 'build']
ROOT_PATH = os.path.abspath(pathlib.Path.cwd().resolve())
MAX_COUNT = 50
MAX_FILES_TO_PATCH = 50


def _add_payload_random_functions(exts: list[str], payload: str) -> str:
"""Helper to attach payload to random functions found in any source."""
count = 0

treesitter_parser = Parser(Language(tree_sitter_cpp.language()))
# Walk and insert payload on the random line of random functions
for cur, dirs, files in os.walk(ROOT_PATH):
dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
for file in files:
# Only change some files randomly
if count > MAX_COUNT:
if count > MAX_FILES_TO_PATCH:
return

if any(file.endswith(ext) for ext in exts):
Expand All @@ -51,15 +55,17 @@ def _add_payload_random_functions(exts: list[str], payload: str) -> str:
with open(path, 'r', encoding='utf-8') as f:
source = f.read()
if source:
node = PARSER.parse(source.encode()).root_node
node = treesitter_parser.parse(source.encode()).root_node
except Exception:
pass

if not node:
continue

# Insert payload to random line in the function
cursor = QueryCursor(Query(LANGUAGE, '( function_definition ) @funcs'))
cursor = QueryCursor(
Query(Language(tree_sitter_cpp.language()),
'( function_definition ) @funcs'))
for func in cursor.captures(node).get('funcs', []):
body = func.child_by_field_name('body')

Expand All @@ -72,8 +78,10 @@ def _add_payload_random_functions(exts: list[str], payload: str) -> str:

if body and body.text and random.choice([True, False]):
func_source = body.text.decode()
new_func_source = f'{{{payload} {func_source[1:]}'
source = source.replace(func_source, new_func_source)
# new_func_source = f'{{ {payload} {func_source[1:]}'
if len(func_source) > 10:
new_func_source = f'{{ {payload} {func_source[1:]}'
source = source.replace(func_source, new_func_source)
try:
with open(path, 'w', encoding='utf-8') as f:
f.write(source)
Expand All @@ -89,37 +97,37 @@ def normal_patch():

def signal_abort_crash():
    """Insert an abort() call into random functions to force a crash.

    Delegates to _add_payload_random_functions, restricted to C/C++
    implementation files; header extensions are intentionally not in the
    list.
    """
    exts = ['.c', '.cc', '.cpp', '.cxx']
    _add_payload_random_functions(exts, 'abort();')


def builtin_trap_crash():
    """Insert a __builtin_trap() call into random functions to force a crash.

    Delegates to _add_payload_random_functions, restricted to C/C++
    implementation files; header extensions are intentionally not in the
    list.
    """
    exts = ['.c', '.cc', '.cpp', '.cxx']
    _add_payload_random_functions(exts, '__builtin_trap();')


def null_write_crash():
    """Insert a null pointer write into random functions to force a crash.

    Delegates to _add_payload_random_functions, restricted to C/C++
    implementation files; header extensions are intentionally not in the
    list.
    """
    exts = ['.c', '.cc', '.cpp', '.cxx']
    _add_payload_random_functions(exts, '*(volatile int*)0 = 0;')


def wrong_return_value():
"""modify random return statement to force an unit test failed in source files found in the /src/directory."""
exts = ['.c', '.cc', '.cpp', '.cxx', '.h', '.hpp']
exts = ['.c', '.cc', '.cpp', '.cxx']
primitives = {
'bool', 'char', 'signed', 'unsigned', 'short', 'int', 'long', 'float',
'double', 'wchar_t', 'char8_t', 'char16_t', 'char32_t', 'size_t'
}
count = 0

treesitter_parser = Parser(Language(tree_sitter_cpp.language()))
# Walk and insert payload on the random line of random functions
for cur, dirs, files in os.walk(ROOT_PATH):
dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]
for file in files:
# Only change some files randomly
if count > MAX_COUNT:
if count > MAX_FILES_TO_PATCH:
return

if any(file.endswith(ext) for ext in exts):
Expand All @@ -131,15 +139,17 @@ def wrong_return_value():
with open(path, 'r', encoding='utf-8') as f:
source = f.read()
if source:
node = PARSER.parse(source.encode()).root_node
node = treesitter_parser.parse(source.encode()).root_node
except Exception:
pass

if not node:
continue

# Try simulate wrong return statement
cursor = QueryCursor(Query(LANGUAGE, '( function_definition ) @funcs'))
cursor = QueryCursor(
Query(Language(tree_sitter_cpp.language()),
'( function_definition ) @funcs'))
for func in cursor.captures(node).get('funcs', []):
# Get return type
rtn_node = func.child_by_field_name('type')
Expand All @@ -160,7 +170,7 @@ def wrong_return_value():
body = func.child_by_field_name('body')
if body and body.text and (is_pointer or rtn in primitives):
func_source = body.text.decode()
new_func_source = f'{{return 0; {func_source[1:]}'
new_func_source = f'{{ {func_source[1:]}'
source = source.replace(func_source, new_func_source)

try:
Expand All @@ -186,13 +196,13 @@ class LogicErrorPatch:
expected_result=True,
),
LogicErrorPatch(
name='sigabrt_crash',
func=signal_abort_crash,
name='sigkill_crash',
func=builtin_trap_crash,
expected_result=False,
),
LogicErrorPatch(
name='sigkill_crash',
func=builtin_trap_crash,
name='sigabrt_crash',
func=signal_abort_crash,
expected_result=False,
),
LogicErrorPatch(
Expand All @@ -208,11 +218,81 @@ class LogicErrorPatch:
]


def diff_patch_analysis(stage: str) -> int:
    """Check if run_tests.sh generates patches that affect source control.

    Intended to be called twice from the repository root: once with stage
    'before' to snapshot `git diff`, and once with stage 'after' to snapshot
    it again and compare the two snapshots with `diff`.

    Args:
      stage: 'before' to record the baseline snapshot, 'after' to record the
        second snapshot and compare it against the baseline.

    Returns:
      int: 0 if no patch found, 1 if patch found and -1 on unknown (such as
      an unsupported version control system, an unknown stage, or a failure
      to produce or compare the snapshots).
    """
    before_path = '/tmp/chronos-before.diff'
    after_path = '/tmp/chronos-after.diff'
    patch_path = '/tmp/chronos-diff.patch'

    print(
        f'Diff patch analysis begin. Stage: {stage}, Current working dir: {os.getcwd()}'
    )

    if stage not in ('before', 'after'):
        print(
            f'Patch result: failed. Unknown stage {stage} for diff patch analysis.')
        return -1

    # Only git checkouts are supported; the check is relative to the cwd.
    if not os.path.isdir('.git'):
        if stage == 'before':
            print('Unknown version control system.')
        else:
            print('Patch result: failed. Unknown version control system.')
        return -1

    print('Git repo found.')
    snapshot_path = before_path if stage == 'before' else after_path
    try:
        # Append mode mirrors the historical `>>` shell redirection.
        with open(snapshot_path, 'ab') as snapshot:
            git_ok = subprocess.run(['git', 'diff', './'],
                                    stdout=snapshot,
                                    check=False).returncode == 0
    except OSError:
        # git missing or snapshot file not writable.
        git_ok = False

    if stage == 'before':
        # Best effort: a failed baseline is reported but not fatal here; the
        # 'after' stage detects a missing/unusable baseline when comparing.
        if not git_ok:
            print('Warning: could not record the baseline git diff.')
        return 0

    if not git_ok:
        # Return -1 rather than raising: an uncaught exception would exit
        # with status 1, which callers interpret as "patch found".
        print('Patch result: failed. Could not run git diff.')
        return -1

    try:
        with open(patch_path, 'wb') as patch_file:
            diff_status = subprocess.run(['diff', before_path, after_path],
                                         stdout=patch_file,
                                         check=False).returncode
    except OSError:
        diff_status = 2

    # diff exits 0 (identical) or 1 (different); anything else means trouble,
    # e.g. the 'before' snapshot does not exist. Do not report that as success.
    if diff_status not in (0, 1):
        print(
            'Patch result: failed. Could not compare the before/after snapshots.')
        return -1

    print(f'Diff patch generated at {patch_path}')
    print('Difference between diffs:')
    # Diffs may contain arbitrary bytes; never fail on decoding.
    with open(patch_path, 'r', encoding='utf-8', errors='replace') as f:
        diff_content = f.read()

    if diff_content.strip():
        print(diff_content)
        print(
            'Patch result: failed. Patch found that affects source control versioning.'
        )
        return 1

    print(
        'Patch result: success. No patch found that affects source control versioning.'
    )
    return 0


def main():
    """Main entrypoint.

    Supported invocations (arguments taken from sys.argv):
      semantic-patch <patch_name>  Run the entry of LOGIC_ERROR_PATCHES whose
                                   name matches; unknown names are a no-op.
      diff-patch <stage>           Run diff_patch_analysis(stage) and exit
                                   with its result as the process status.
      <patch_name>                 Legacy single-argument form, equivalent to
                                   `semantic-patch <patch_name>`.

    Exits with status 2 when required arguments are missing, so that a usage
    error is not confused with diff-patch's status 1 ("patch found").
    """
    usage = (f'Usage: {sys.argv[0]} semantic-patch <patch_name> | '
             'diff-patch <before|after>')

    if len(sys.argv) < 2:
        print(usage)
        sys.exit(2)

    command = sys.argv[1]
    if command in ('semantic-patch', 'diff-patch') and len(sys.argv) < 3:
        print(usage)
        sys.exit(2)

    if command == 'diff-patch':
        sys.exit(diff_patch_analysis(sys.argv[2]))

    # Anything other than the explicit sub-command is treated as a patch name
    # to stay compatible with the older `<script> <patch_name>` call form.
    target_patch = sys.argv[2] if command == 'semantic-patch' else command
    for logic_error_patch in LOGIC_ERROR_PATCHES:
        if logic_error_patch.name == target_patch:
            logic_error_patch.func()


if __name__ == "__main__":
Expand Down
Loading
Loading