From 0c9e2554332aeff70be293d70a904a29b62b6c51 Mon Sep 17 00:00:00 2001 From: Lorenz Stechauner Date: Wed, 7 May 2025 10:33:46 +0200 Subject: [PATCH] proj/test-return-values: Analyse call tree --- proj/server/src/intercept/__init__.py | 19 ++++- proj/server/src/intercept/standard.py | 109 +++++++++++++++++--------- proj/server/src/test-interrupts | 17 ++-- proj/server/src/test-memory | 2 + proj/server/src/test-return-values | 87 ++++++++++++++------ 5 files changed, 162 insertions(+), 72 deletions(-) diff --git a/proj/server/src/intercept/__init__.py b/proj/server/src/intercept/__init__.py index 3739596..3f929d5 100644 --- a/proj/server/src/intercept/__init__.py +++ b/proj/server/src/intercept/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from __future__ import annotations from typing import Optional, TypedDict, NamedTuple, NotRequired, BinaryIO from socketserver import UnixStreamServer, StreamRequestHandler, ThreadingMixIn import os @@ -41,14 +42,23 @@ class FunctionCallId: ret_addr: int obj_path: str rel_ret_addr: int - sym_name: Optional[str] - src_file_name: Optional[str] - src_line_num: Optional[int] + sym_name: Optional[str] = None + src_file_name: Optional[str] = None + src_line_num: Optional[int] = None @property def obj_name(self) -> str: return self.obj_path.split('/')[-1] + @staticmethod + def for_exit() -> FunctionCallId: + call_id = FunctionCallId() + call_id.func_name = 'exit' + call_id.obj_path = 'sys' + call_id.ret_addr = 0 + call_id.rel_ret_addr = 0 + return call_id + @property def discriminator(self) -> str: discr = [f'{self.obj_name}+0x{self.rel_ret_addr:x}'] @@ -69,6 +79,9 @@ class FunctionCallId: def __str__(self) -> str: return self.func_name + ', ' + self.discriminator + def __repr__(self) -> str: + return f'<{self}>' + class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer): pass diff --git a/proj/server/src/intercept/standard.py b/proj/server/src/intercept/standard.py index 9ee588c..7e7ec87 100644 --- a/proj/server/src/intercept/standard.py +++ b/proj/server/src/intercept/standard.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from __future__ import annotations from intercept import * @@ -45,6 +46,13 @@ FUNCTION_ERRORS: dict[str, list[str]] = { 'dup': ['EBADF', 'EMFILE'], 'dup2': ['EBADF', 'EBUSY', 'EINTR', 'EMFILE'], 'dup3': ['EBADF', 'EBUSY', 'EINTR', 'EMFILE'], + 'socket': ['EAFNOSUPPORT', 'EMFILE', 'ENFILE', 'ENOMEM', 'ENOBUFS', 'EACCES'], + 'bind': ['EADDRINUSE', 'EADDRNOTAVAIL', 'EAFNOSUPPORT', 'EACCES'], + 'listen': ['EACCES', 'ENOBUFS', 'EADDRINUSE'], + 'accept': ['ECONNABORTED', 'EINTR', 'EPERM', 'EPROTO', 'ENOMEM'], + 'connect': ['EACCES', 'EPERM', 'ECONNREFUSED', 'ENETUNREACH'], + 'getaddrinfo': [], + 'freeaddrinfo': [], 'send': ['EINTR'], 'recv': ['EINTR'], } @@ -220,6 +228,13 @@ class InterruptedCheckTester(InterruptedCheckParser, Handler): pass +class CallTreeNode(NamedTuple): + call: FunctionCallId + depth: int + data: dict + children: dict[str, CallTreeNode] + + class ReturnValueCheckParser(Parser): functions: dict[str, list[str]] = { fn: [e for e in errors if e not in SKIP_ERRORS] @@ -227,11 +242,23 @@ class ReturnValueCheckParser(Parser): if len(set(errors) - set(SKIP_ERRORS)) > 0 } context: dict - num: int = 0 + call_depth: int = 0 @property - def call_sequence(self) -> list[str]: - return self.context['call_sequence'] + def call_tree(self) -> CallTreeNode: + return self.context['call_tree'] + + @property + def last_call(self) -> CallTreeNode: + return self.context['last_call'] + + @property + def last_return(self) -> str: + return self.context['last_return'] + + @property + def next(self) -> tuple[Optional[CallTreeNode], Optional[str]]: + return self.context['next'] @property def is_init(self) -> bool: @@ -249,60 +276,66 @@ class ReturnValueCheckParser(Parser): def is_failed(self) -> bool: return self.context['state'] == 'failed' - def next(self, last: Optional[tuple[int, Optional[str]]]) -> tuple[int, Optional[str]]: + def get_next(self, last: Optional[tuple[CallTreeNode, Optional[str]]]) -> tuple[Optional[CallTreeNode], Optional[str]]: if last is None: - if len(self.call_sequence) == 0: - return 0, None - err = self.functions.get(self.call_sequence[0], [None])[0] - return (0, err) if err else self.next((0, err)) - last_fn, last_er = last - if last_er is None: - if last_fn + 1 >= len(self.call_sequence): - return last_fn + 1, None - err = self.functions.get(self.call_sequence[last_fn + 1], [None])[0] - return (last_fn + 1, err) if err else self.next((last_fn + 1, err)) - errors = self.functions.get(self.call_sequence[last_fn], []) - i = errors.index(last_er) - return (last_fn, errors[i + 1]) if i + 1 < len(errors) else self.next((last_fn, None)) + if self.call_tree is None: + return None, None + last_call, last_err = self.call_tree, None + else: + last_call, last_err = last + if last_err is None: + if len(last_call.children) == 0: + return None, None + err = self.functions.get(last_call.children['ok'].call.func_name, [None])[0] + return (last_call.children['ok'], err) if err else self.get_next((last_call.children['ok'], err)) + errors = self.functions.get(last_call.call.func_name, []) + i = errors.index(last_err) + return (last_call, errors[i + 1]) if i + 1 < len(errors) else self.get_next((last_call, None)) def before(self) -> None: if 'state' not in self.context: self.context['state'] = 'init' - self.context['state'] = 'init' - self.context['call_sequence'] = [] + self.context['call_tree'] = None self.context['next'] = None - self.context['last_error'] = None - self.context['results'] = {} if self.is_waiting or self.is_failed: self.context['state'] = 'testing' + self.context['last_call'] = None + self.context['last_return'] = None def after(self) -> None: if self.is_init: self.context['state'] = 'testing' - self.context['next'] = self.next(None) - elif self.is_testing and self.context['next'][0] > len(self.context['call_sequence']): + self.context['next'] = self.get_next(None) + elif self.is_testing and self.next[0] is None: self.context['state'] = 'finished' - elif self.is_waiting: - self.context['results'][(self.num - 1, self.context['last_error'])] = 'passed' + if self.last_call and self.last_call.call.func_name != 'exit': + self.last_call.children[self.last_return] = (CallTreeNode(FunctionCallId.for_exit(), self.last_call.depth + 1, {}, {})) def before_fallback(self, func_name: str, *args) -> str: + self.call_depth += 1 + call = self.get_call_id(func_name) + if self.last_call is None: + if self.call_tree is None: + self.context['call_tree'] = CallTreeNode(call, 0, {}, {}) + self.context['last_call'] = self.call_tree + else: + if self.last_return not in self.last_call.children: + self.last_call.children[self.last_return] = CallTreeNode(call, self.last_call.depth + 1, {}, {}) + self.context['last_call'] = self.last_call.children[self.last_return] + if self.is_init: - self.call_sequence.append(func_name) + self.context['last_return'] = 'ok' return 'ok' - elif self.is_waiting or self.is_failed: - print(func_name) - if func_name not in ('malloc', 'free', 'close', 'exit'): - self.context['state'] = 'failed' - self.context['results'][(self.num - 1, self.context['last_error'])] = 'failed' + + nxt = self.next + if nxt[0] is None or call != nxt[0].call or self.call_depth - 1 != nxt[0].depth or nxt[1] is None: + self.context['last_return'] = 'ok' return 'ok' - self.num += 1 - nxt = self.context['next'] - if self.num - 1 != nxt[0] or nxt[1] is None: - return 'ok' - self.context['next'] = self.next(nxt) + + self.context['next'] = self.get_next(nxt) self.context['state'] = 'waiting' if self.context['next'][1] else 'finished' - self.context['last_error'] = nxt[1] - return 'fail ' + nxt[1] + self.context['last_return'] = 'fail ' + nxt[1] + return self.context['last_return'] class ReturnValueCheckTester(ReturnValueCheckParser, Handler): diff --git a/proj/server/src/test-interrupts b/proj/server/src/test-interrupts index c21afe4..a0ae5c9 100755 --- a/proj/server/src/test-interrupts +++ b/proj/server/src/test-interrupts @@ -47,13 +47,16 @@ def main() -> None: t1 = threading.Thread(target=socket_thread, args=(socket_name, handler)) t1.daemon = True t1.start() - subprocess.run(extra, env={ - 'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so', - 'INTERCEPT': 'unix:' + socket_name, - 'INTERCEPT_VERBOSE': '1', - 'INTERCEPT_FUNCTIONS': '*', - 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), - }) + try: + subprocess.run(extra, env={ + 'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so', + 'INTERCEPT': 'unix:' + socket_name, + 'INTERCEPT_VERBOSE': '1', + 'INTERCEPT_FUNCTIONS': '*', + 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), + }) + except KeyboardInterrupt: + pass bold(':: REPORT ::') for call, status in ctx['results'].items(): diff --git a/proj/server/src/test-memory b/proj/server/src/test-memory index 17d6d61..724ca32 100755 --- a/proj/server/src/test-memory +++ b/proj/server/src/test-memory @@ -49,6 +49,8 @@ def main() -> None: 'INTERCEPT_FUNCTIONS': ','.join(['malloc', 'calloc', 'realloc', 'reallocarray', 'free', 'getaddrinfo', 'freeaddrinfo', 'getline', 'getdelim']), 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), }) + except KeyboardInterrupt: + pass finally: with open(log_file, 'rb') as file: parser = intercept.standard.MemoryAllocationParser(file) diff --git a/proj/server/src/test-return-values b/proj/server/src/test-return-values index 8ec442f..2c8929f 100755 --- a/proj/server/src/test-return-values +++ b/proj/server/src/test-return-values @@ -9,6 +9,7 @@ import subprocess import intercept import intercept.standard +from intercept.standard import CallTreeNode def neutral(text: str) -> None: @@ -53,32 +54,70 @@ def main() -> None: while ctx.get('state', None) != 'finished': if stdin: stdin.seek(0, 0) - subprocess.run(extra, stdin=stdin, env={ - 'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so', - 'INTERCEPT': 'unix:' + socket_name, - 'INTERCEPT_VERBOSE': '1', - 'INTERCEPT_FUNCTIONS': '*', - 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), - }) + try: + subprocess.run(extra, stdin=stdin, env={ + 'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so', + 'INTERCEPT': 'unix:' + socket_name, + 'INTERCEPT_VERBOSE': '1', + 'INTERCEPT_FUNCTIONS': '*', + 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), + }) + except KeyboardInterrupt: + pass + + def iter_tree(node: CallTreeNode, ret=None, layer=0): + yield node.call, node.data, ret, layer, node.children + for ret, child in node.children.items(): + if ret != 'ok': + yield from iter_tree(child, ret, layer + 1) + if 'ok' in node.children: + yield from iter_tree(node.children['ok'], 'ok', layer) + + def iter_tree_ok(node: CallTreeNode): + yield node.call + if 'ok' in node.children: + yield from iter_tree_ok(node.children['ok']) + + bold(':: CALL TREE ::') + for call, _, r, l, c in iter_tree(ctx['call_tree']): + if r not in (None, 'ok'): + neutral(f'::{" " * l}{r}') + neutral(f'::{" " * l} {call}') + neutral('::') bold(':: REPORT ::') - neutral('::') - bold(':: SUMMARY ::') - for i, name in enumerate(ctx['call_sequence']): - errors = [r[1] for r in ctx['results'] if r[0] == i] - results = [ctx['results'][(i, e)] for e in errors] - res = '-' if len(results) == 0 else 'passed' if all(r == 'passed' for r in results) else 'failed' - esc, unesc = '', '' - errs = [e for e in errors if ctx['results'][(i, e)] == res] - errs = f' ({", ".join(errs)})' if len(errs) > 0 else '' - if sys.stdout.isatty(): - if res == '-': - esc, unesc = '\x1B[90m', '\x1B[0m' - elif res == 'failed': - res = '\x1B[31;1m' 'FAILED' '\x1B[0m' - elif res == 'passed': - res = '\x1B[32;1m' 'PASSED' '\x1B[0m' - neutral(f':: {esc}{i:3} {name:16} {res}{unesc}{errs}', ) + calls = {} + for call, _, r, l, c in iter_tree(ctx['call_tree']): + if l != 0: + continue + if call not in calls: + calls[call] = {} + entry = calls[call] + for ret, child in c.items(): + if ret not in entry: + entry[ret] = set() + entry[ret].add(tuple(c for c in iter_tree_ok(child))) + + allowed_cleanup_functions = ['malloc', 'free', 'freeaddrinfo', 'close', 'exit'] + for call, errors in calls.items(): + if len(errors) <= 1: + continue + tested = [] + failed = [] + default_path = errors['ok'] + for ret, paths in errors.items(): + if ret == 'ok': + continue + errno = ret.split(' ')[1] + tested.append(errno) + if any(p == default_path for p in paths) or any(c.func_name not in allowed_cleanup_functions for p in paths for c in p): + failed.append(errno) + if len(failed) > 0: + red(f':: TEST :: RETURN VALUE CHECK :: {call} :: FAILED') + else: + green(f':: TEST :: RETURN VALUE CHECK :: {call} :: PASSED') + + if __name__ == '__main__':