proj/test-return-values: Try to handle fork() correctly
This commit is contained in:
@@ -93,7 +93,7 @@ class Parser:
|
|||||||
pid: Optional[int]
|
pid: Optional[int]
|
||||||
tid: Optional[int]
|
tid: Optional[int]
|
||||||
path: Optional[str]
|
path: Optional[str]
|
||||||
stack: list[tuple[int, int, str, Optional[str], Optional[str], Optional[int], str, tuple]]
|
stack: dict[tuple[int, int], list[tuple[int, int, str, Optional[str], Optional[str], Optional[int], str, tuple]]]
|
||||||
ret_addr: int
|
ret_addr: int
|
||||||
rel_ret_addr: int
|
rel_ret_addr: int
|
||||||
dli_file_name: str
|
dli_file_name: str
|
||||||
@@ -104,7 +104,7 @@ class Parser:
|
|||||||
def __init__(self, rfile: BinaryIO, wfile: BinaryIO = None):
|
def __init__(self, rfile: BinaryIO, wfile: BinaryIO = None):
|
||||||
self.rfile = rfile
|
self.rfile = rfile
|
||||||
self.wfile = wfile
|
self.wfile = wfile
|
||||||
self.stack = []
|
self.stack = {}
|
||||||
self.pid = None
|
self.pid = None
|
||||||
self.path = None
|
self.path = None
|
||||||
|
|
||||||
@@ -255,14 +255,22 @@ class Parser:
|
|||||||
idx += 1
|
idx += 1
|
||||||
return PointerTo(val, l), idx
|
return PointerTo(val, l), idx
|
||||||
else:
|
else:
|
||||||
|
m = re.match(r'0x[0-9a-fA-F]+|[0-9]+|\(nil\)', argument[idx:])
|
||||||
|
if m is not None:
|
||||||
|
value = m.group(0)
|
||||||
|
idx += len(value)
|
||||||
|
if idx < len(argument) and argument[idx] == ',':
|
||||||
|
idx += 1
|
||||||
|
value = int(value, 0) if value != '(nil)' else 0
|
||||||
|
return PointerTo(val, value), idx
|
||||||
m = re.match(r'[A-Z0-9_]+', argument[idx:])
|
m = re.match(r'[A-Z0-9_]+', argument[idx:])
|
||||||
if not m:
|
if m is not None:
|
||||||
raise ValueError()
|
value = m.group(0)
|
||||||
value = m.group(0)
|
idx += len(value)
|
||||||
idx += len(value)
|
if idx < len(argument) and argument[idx] == ',':
|
||||||
if idx < len(argument) and argument[idx] == ',':
|
idx += 1
|
||||||
idx += 1
|
return Constant(val, value), idx
|
||||||
return Constant(val, value), idx
|
raise ValueError()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_args(arguments: str, named: bool = False, ret: bool = False) -> tuple[tuple or dict, int]:
|
def parse_args(arguments: str, named: bool = False, ret: bool = False) -> tuple[tuple or dict, int]:
|
||||||
@@ -292,6 +300,8 @@ class Parser:
|
|||||||
def handle_msg(self, msg: bytes):
|
def handle_msg(self, msg: bytes):
|
||||||
timestamp, pid, tid, data = msg.rstrip(b'\n').split(b' ', 3)
|
timestamp, pid, tid, data = msg.rstrip(b'\n').split(b' ', 3)
|
||||||
self.pid, self.tid = int(pid), int(tid)
|
self.pid, self.tid = int(pid), int(tid)
|
||||||
|
if len(self.stack) == 0:
|
||||||
|
self.stack[(self.pid, self.tid)] = []
|
||||||
if not data.startswith(b'return ') and not data == b'return':
|
if not data.startswith(b'return ') and not data == b'return':
|
||||||
call = data.decode('utf-8')
|
call = data.decode('utf-8')
|
||||||
#print(f'[{self.pid}][{self.tid}] {call}')
|
#print(f'[{self.pid}][{self.tid}] {call}')
|
||||||
@@ -306,11 +316,13 @@ class Parser:
|
|||||||
self.src_file_name = src_fname
|
self.src_file_name = src_fname
|
||||||
self.src_line_num = int(src_line, 0) if src_line else None
|
self.src_line_num = int(src_line, 0) if src_line else None
|
||||||
args, _ = Parser.parse_args(call[call.find('(') + 1:call.rfind(': ') - 1])
|
args, _ = Parser.parse_args(call[call.find('(') + 1:call.rfind(': ') - 1])
|
||||||
self.stack.append(
|
self.stack[(self.pid, self.tid)].append(
|
||||||
(self.ret_addr, self.rel_ret_addr,
|
(self.ret_addr, self.rel_ret_addr,
|
||||||
self.dli_file_name, self.dli_sym_name,
|
self.dli_file_name, self.dli_sym_name,
|
||||||
self.src_file_name, self.src_line_num,
|
self.src_file_name, self.src_line_num,
|
||||||
func_name, args))
|
func_name, args))
|
||||||
|
if func_name == 'fork':
|
||||||
|
self.stack[(0, 0)] = self.stack[(self.pid, self.tid)][:]
|
||||||
try:
|
try:
|
||||||
func = getattr(self, f'before_{func_name}')
|
func = getattr(self, f'before_{func_name}')
|
||||||
if not callable(func):
|
if not callable(func):
|
||||||
@@ -338,10 +350,13 @@ class Parser:
|
|||||||
other_vals = ret[1].strip() if len(ret) > 1 else ''
|
other_vals = ret[1].strip() if len(ret) > 1 else ''
|
||||||
if len(other_vals) > 0:
|
if len(other_vals) > 0:
|
||||||
kwargs, _ = Parser.parse_args(other_vals, named=True, ret=True)
|
kwargs, _ = Parser.parse_args(other_vals, named=True, ret=True)
|
||||||
|
if (self.pid, self.tid) not in self.stack:
|
||||||
|
self.stack[(self.pid, self.tid)] = self.stack[(0, 0)]
|
||||||
|
del self.stack[(0, 0)]
|
||||||
(self.ret_addr, self.rel_ret_addr,
|
(self.ret_addr, self.rel_ret_addr,
|
||||||
self.dli_file_name, self.dli_sym_name,
|
self.dli_file_name, self.dli_sym_name,
|
||||||
self.src_file_name, self.src_line_num,
|
self.src_file_name, self.src_line_num,
|
||||||
func_name, args) = self.stack.pop()
|
func_name, args) = self.stack[(self.pid, self.tid)].pop()
|
||||||
try:
|
try:
|
||||||
func = getattr(self, f'after_{func_name}')
|
func = getattr(self, f'after_{func_name}')
|
||||||
if not callable(func):
|
if not callable(func):
|
||||||
@@ -354,12 +369,12 @@ class Parser:
|
|||||||
if ret_value is None:
|
if ret_value is None:
|
||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
func(*args, ret_value, **kwargs)
|
func(*args, ret_value=ret_value, **kwargs)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
if ret_value is None:
|
if ret_value is None:
|
||||||
self.after_fallback(func_name, *args, **kwargs)
|
self.after_fallback(func_name, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.after_fallback(func_name, *args, ret_value, **kwargs)
|
self.after_fallback(func_name, *args, ret_value=ret_value, **kwargs)
|
||||||
#print(f'[{self.pid}][{self.tid}] -> {ret}')
|
#print(f'[{self.pid}][{self.tid}] -> {ret}')
|
||||||
|
|
||||||
def before_malloc(self, size: int) -> str:
|
def before_malloc(self, size: int) -> str:
|
||||||
@@ -658,7 +673,7 @@ class Handler(StreamRequestHandler, Parser):
|
|||||||
self.pid = int(meta['PID']) if 'PID' in meta else None
|
self.pid = int(meta['PID']) if 'PID' in meta else None
|
||||||
self.path = meta['PATH'] if 'PATH' in meta else None
|
self.path = meta['PATH'] if 'PATH' in meta else None
|
||||||
print(f'Process with PID {self.pid} connected ({self.path})')
|
print(f'Process with PID {self.pid} connected ({self.path})')
|
||||||
self.stack = []
|
self.stack = {}
|
||||||
self.parse()
|
self.parse()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -128,14 +128,14 @@ class MemoryAllocationParser(Parser):
|
|||||||
|
|
||||||
def after_getline(self, line_ptr, n_ptr, stream, ret_value, errno=None, n=None, line=None) -> None:
|
def after_getline(self, line_ptr, n_ptr, stream, ret_value, errno=None, n=None, line=None) -> None:
|
||||||
if ret_value >= 0 and n is not None and line is not None:
|
if ret_value >= 0 and n is not None and line is not None:
|
||||||
if line_ptr.target == 0:
|
if n_ptr.target == 0 or line_ptr.target == 0:
|
||||||
self.num_alloc += 1
|
self.num_alloc += 1
|
||||||
self.allocated[line.ptr] = (self.get_call_id('getline'), n)
|
self.allocated[line.ptr] = (self.get_call_id('getline'), n)
|
||||||
elif line_ptr.target != line.ptr:
|
elif line_ptr.target != line.ptr:
|
||||||
self.num_realloc += 1
|
self.num_realloc += 1
|
||||||
v = self.allocated[line.ptr]
|
v = self.allocated[line_ptr.target]
|
||||||
del self.allocated[line.ptr]
|
del self.allocated[line_ptr.target]
|
||||||
self.allocated[ret_value] = (v[0], n)
|
self.allocated[line.ptr] = (v[0], n)
|
||||||
self.update_max_allocated()
|
self.update_max_allocated()
|
||||||
|
|
||||||
def after_getdelim(self, line_ptr, n_ptr, delim, stream, ret_value, errno=None, n=None, line=None) -> None:
|
def after_getdelim(self, line_ptr, n_ptr, delim, stream, ret_value, errno=None, n=None, line=None) -> None:
|
||||||
@@ -293,6 +293,7 @@ class ReturnValueCheckParser(Parser):
|
|||||||
return (last_call, errors[i + 1]) if i + 1 < len(errors) else self.get_next((last_call, None))
|
return (last_call, errors[i + 1]) if i + 1 < len(errors) else self.get_next((last_call, None))
|
||||||
|
|
||||||
def before(self) -> None:
|
def before(self) -> None:
|
||||||
|
self.context['pid'] = self.pid
|
||||||
if 'state' not in self.context:
|
if 'state' not in self.context:
|
||||||
self.context['state'] = 'init'
|
self.context['state'] = 'init'
|
||||||
self.context['call_tree'] = None
|
self.context['call_tree'] = None
|
||||||
@@ -303,6 +304,8 @@ class ReturnValueCheckParser(Parser):
|
|||||||
self.context['last_return'] = None
|
self.context['last_return'] = None
|
||||||
|
|
||||||
def after(self) -> None:
|
def after(self) -> None:
|
||||||
|
if self.context['pid'] != self.pid:
|
||||||
|
return
|
||||||
if self.is_init:
|
if self.is_init:
|
||||||
self.context['state'] = 'testing'
|
self.context['state'] = 'testing'
|
||||||
self.context['next'] = self.get_next(None)
|
self.context['next'] = self.get_next(None)
|
||||||
@@ -312,6 +315,9 @@ class ReturnValueCheckParser(Parser):
|
|||||||
self.last_call.children[self.last_return] = (CallTreeNode(FunctionCallId.for_exit(), self.last_call.depth + 1, {}, {}))
|
self.last_call.children[self.last_return] = (CallTreeNode(FunctionCallId.for_exit(), self.last_call.depth + 1, {}, {}))
|
||||||
|
|
||||||
def before_fallback(self, func_name: str, *args) -> str:
|
def before_fallback(self, func_name: str, *args) -> str:
|
||||||
|
if self.context['pid'] != self.pid:
|
||||||
|
return 'ok'
|
||||||
|
|
||||||
self.call_depth += 1
|
self.call_depth += 1
|
||||||
call = self.get_call_id(func_name)
|
call = self.get_call_id(func_name)
|
||||||
if self.last_call is None:
|
if self.last_call is None:
|
||||||
|
|||||||
@@ -51,7 +51,10 @@ def main() -> None:
|
|||||||
t1 = threading.Thread(target=socket_thread, args=(socket_name, handler))
|
t1 = threading.Thread(target=socket_thread, args=(socket_name, handler))
|
||||||
t1.daemon = True
|
t1.daemon = True
|
||||||
t1.start()
|
t1.start()
|
||||||
|
|
||||||
|
n = 0
|
||||||
while ctx.get('state', None) != 'finished':
|
while ctx.get('state', None) != 'finished':
|
||||||
|
n += 1
|
||||||
if stdin:
|
if stdin:
|
||||||
stdin.seek(0, 0)
|
stdin.seek(0, 0)
|
||||||
try:
|
try:
|
||||||
@@ -63,7 +66,8 @@ def main() -> None:
|
|||||||
'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']),
|
'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']),
|
||||||
})
|
})
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
if n > 64:
|
||||||
|
break
|
||||||
|
|
||||||
def iter_tree(node: CallTreeNode, ret=None, layer=0):
|
def iter_tree(node: CallTreeNode, ret=None, layer=0):
|
||||||
yield node.call, node.data, ret, layer, node.children
|
yield node.call, node.data, ret, layer, node.children
|
||||||
|
|||||||
Reference in New Issue
Block a user