diff --git a/proj/server/src/intercept/__init__.py b/proj/server/src/intercept/__init__.py index 3f929d5..ca89eb8 100644 --- a/proj/server/src/intercept/__init__.py +++ b/proj/server/src/intercept/__init__.py @@ -93,7 +93,7 @@ class Parser: pid: Optional[int] tid: Optional[int] path: Optional[str] - stack: list[tuple[int, int, str, Optional[str], Optional[str], Optional[int], str, tuple]] + stack: dict[tuple[int, int], list[tuple[int, int, str, Optional[str], Optional[str], Optional[int], str, tuple]]] ret_addr: int rel_ret_addr: int dli_file_name: str @@ -104,7 +104,7 @@ class Parser: def __init__(self, rfile: BinaryIO, wfile: BinaryIO = None): self.rfile = rfile self.wfile = wfile - self.stack = [] + self.stack = {} self.pid = None self.path = None @@ -255,14 +255,22 @@ class Parser: idx += 1 return PointerTo(val, l), idx else: + m = re.match(r'0x[0-9a-fA-F]+|[0-9]+|\(nil\)', argument[idx:]) + if m is not None: + value = m.group(0) + idx += len(value) + if idx < len(argument) and argument[idx] == ',': + idx += 1 + value = int(value, 0) if value != '(nil)' else 0 + return PointerTo(val, value), idx m = re.match(r'[A-Z0-9_]+', argument[idx:]) - if not m: - raise ValueError() - value = m.group(0) - idx += len(value) - if idx < len(argument) and argument[idx] == ',': - idx += 1 - return Constant(val, value), idx + if m is not None: + value = m.group(0) + idx += len(value) + if idx < len(argument) and argument[idx] == ',': + idx += 1 + return Constant(val, value), idx + raise ValueError() @staticmethod def parse_args(arguments: str, named: bool = False, ret: bool = False) -> tuple[tuple or dict, int]: @@ -292,6 +300,8 @@ class Parser: def handle_msg(self, msg: bytes): timestamp, pid, tid, data = msg.rstrip(b'\n').split(b' ', 3) self.pid, self.tid = int(pid), int(tid) + if len(self.stack) == 0: + self.stack[(self.pid, self.tid)] = [] if not data.startswith(b'return ') and not data == b'return': call = data.decode('utf-8') #print(f'[{self.pid}][{self.tid}] {call}') @@ -306,11 +316,13 @@ class Parser: self.src_file_name = src_fname self.src_line_num = int(src_line, 0) if src_line else None args, _ = Parser.parse_args(call[call.find('(') + 1:call.rfind(': ') - 1]) - self.stack.append( + self.stack[(self.pid, self.tid)].append( (self.ret_addr, self.rel_ret_addr, self.dli_file_name, self.dli_sym_name, self.src_file_name, self.src_line_num, func_name, args)) + if func_name == 'fork': + self.stack[(0, 0)] = self.stack[(self.pid, self.tid)][:] try: func = getattr(self, f'before_{func_name}') if not callable(func): @@ -338,10 +350,13 @@ class Parser: other_vals = ret[1].strip() if len(ret) > 1 else '' if len(other_vals) > 0: kwargs, _ = Parser.parse_args(other_vals, named=True, ret=True) + if (self.pid, self.tid) not in self.stack: + self.stack[(self.pid, self.tid)] = self.stack[(0, 0)] + del self.stack[(0, 0)] (self.ret_addr, self.rel_ret_addr, self.dli_file_name, self.dli_sym_name, self.src_file_name, self.src_line_num, - func_name, args) = self.stack.pop() + func_name, args) = self.stack[(self.pid, self.tid)].pop() try: func = getattr(self, f'after_{func_name}') if not callable(func): @@ -354,12 +369,12 @@ class Parser: if ret_value is None: func(*args, **kwargs) else: - func(*args, ret_value, **kwargs) + func(*args, ret_value=ret_value, **kwargs) except NotImplementedError: if ret_value is None: self.after_fallback(func_name, *args, **kwargs) else: - self.after_fallback(func_name, *args, ret_value, **kwargs) + self.after_fallback(func_name, *args, ret_value=ret_value, **kwargs) #print(f'[{self.pid}][{self.tid}] -> {ret}') def before_malloc(self, size: int) -> str: @@ -658,7 +673,7 @@ class Handler(StreamRequestHandler, Parser): self.pid = int(meta['PID']) if 'PID' in meta else None self.path = meta['PATH'] if 'PATH' in meta else None print(f'Process with PID {self.pid} connected ({self.path})') - self.stack = [] + self.stack = {} self.parse() diff --git a/proj/server/src/intercept/standard.py b/proj/server/src/intercept/standard.py index 7e7ec87..dfef42b 100644 --- a/proj/server/src/intercept/standard.py +++ b/proj/server/src/intercept/standard.py @@ -128,14 +128,14 @@ class MemoryAllocationParser(Parser): def after_getline(self, line_ptr, n_ptr, stream, ret_value, errno=None, n=None, line=None) -> None: if ret_value >= 0 and n is not None and line is not None: - if line_ptr.target == 0: + if n_ptr.target == 0 or line_ptr.target == 0: self.num_alloc += 1 self.allocated[line.ptr] = (self.get_call_id('getline'), n) elif line_ptr.target != line.ptr: self.num_realloc += 1 - v = self.allocated[line.ptr] - del self.allocated[line.ptr] - self.allocated[ret_value] = (v[0], n) + v = self.allocated[line_ptr.target] + del self.allocated[line_ptr.target] + self.allocated[line.ptr] = (v[0], n) self.update_max_allocated() def after_getdelim(self, line_ptr, n_ptr, delim, stream, ret_value, errno=None, n=None, line=None) -> None: @@ -293,6 +293,7 @@ class ReturnValueCheckParser(Parser): return (last_call, errors[i + 1]) if i + 1 < len(errors) else self.get_next((last_call, None)) def before(self) -> None: + self.context['pid'] = self.pid if 'state' not in self.context: self.context['state'] = 'init' self.context['call_tree'] = None @@ -303,6 +304,8 @@ class ReturnValueCheckParser(Parser): self.context['last_return'] = None def after(self) -> None: + if self.context['pid'] != self.pid: + return if self.is_init: self.context['state'] = 'testing' self.context['next'] = self.get_next(None) @@ -312,6 +315,9 @@ class ReturnValueCheckParser(Parser): self.last_call.children[self.last_return] = (CallTreeNode(FunctionCallId.for_exit(), self.last_call.depth + 1, {}, {})) def before_fallback(self, func_name: str, *args) -> str: + if self.context['pid'] != self.pid: + return 'ok' + self.call_depth += 1 call = self.get_call_id(func_name) if self.last_call is None: diff --git a/proj/server/src/test-return-values b/proj/server/src/test-return-values index 2c8929f..094c1ea 100755 --- a/proj/server/src/test-return-values +++ b/proj/server/src/test-return-values @@ -51,7 +51,10 @@ def main() -> None: t1 = threading.Thread(target=socket_thread, args=(socket_name, handler)) t1.daemon = True t1.start() + + n = 0 while ctx.get('state', None) != 'finished': + n += 1 if stdin: stdin.seek(0, 0) try: @@ -63,7 +66,8 @@ def main() -> None: 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), }) except KeyboardInterrupt: - pass + if n > 64: + break def iter_tree(node: CallTreeNode, ret=None, layer=0): yield node.call, node.data, ret, layer, node.children