diff --git a/proj/server/src/intercept/__init__.py b/proj/server/src/intercept/__init__.py index cc39631..46c86bb 100644 --- a/proj/server/src/intercept/__init__.py +++ b/proj/server/src/intercept/__init__.py @@ -26,6 +26,40 @@ StructMsgHdr = TypedDict('StructMsgHdr', {}) RET_ADDR_RE = re.compile(r': *((0x)?[0-9a-fA-Fx]+) *\((.+?)\+(.+?)(, *([^:]+?))?(, *(([^:]+?):([0-9]+)))?\)$') +class FunctionCallId: + func_name: str + ret_addr: int + obj_path: str + rel_ret_addr: int + sym_name: Optional[str] + src_file_name: Optional[str] + src_line_num: Optional[int] + + @property + def obj_name(self) -> str: + return self.obj_path.split('/')[-1] + + @property + def discriminator(self) -> str: + discr = [f'{self.obj_name}+0x{self.rel_ret_addr:x}'] + if self.sym_name: + discr.append(self.sym_name) + if self.src_file_name and self.src_line_num: + discr.append(f'{self.src_file_name}:{self.src_line_num}') + return ', '.join(discr) + + def __eq__(self, other) -> bool: + if not isinstance(other, FunctionCallId): + return False + return (self.func_name, self.obj_path, self.rel_ret_addr) == (other.func_name, other.obj_path, other.rel_ret_addr) + + def __hash__(self) -> int: + return hash((self.func_name, self.obj_path, self.rel_ret_addr)) + + def __str__(self) -> str: + return self.func_name + ', ' + self.discriminator + + class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer): pass @@ -51,6 +85,17 @@ class Parser: self.pid = None self.path = None + def get_call_id(self, func_name: str) -> FunctionCallId: + call_id = FunctionCallId() + call_id.func_name = func_name + call_id.ret_addr = self.ret_addr + call_id.obj_path = self.dli_file_name + call_id.rel_ret_addr = self.rel_ret_addr + call_id.sym_name = self.dli_sym_name + call_id.src_file_name = self.src_file_name + call_id.src_line_num = self.src_line_num + return call_id + def before(self) -> None: pass def after(self) -> None: pass def before_fallback(self, func_name: str, *args) -> str: pass diff --git a/proj/server/src/intercept/standard.py b/proj/server/src/intercept/standard.py index ab2fc65..9deaf1a 100644 --- a/proj/server/src/intercept/standard.py +++ b/proj/server/src/intercept/standard.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- from intercept import * -import sys FUNCTION_ERRORS: dict[str, list[str]] = { @@ -54,8 +53,8 @@ SKIP_ERRORS: list[str] = ['EINTR'] class MemoryAllocationParser(Parser): - allocated: dict[int, tuple[str, int, str, str, int, str, int]] - invalid_frees: list[tuple[str, int, str, str, int, str, int]] + allocated: dict[int, tuple[FunctionCallId, int]] + invalid_frees: list[tuple[FunctionCallId, int]] max_allocated: int num_alloc: int num_realloc: int @@ -77,56 +76,46 @@ class MemoryAllocationParser(Parser): def after_malloc(self, size, ret_value, errno=None) -> None: self.num_alloc += 1 if ret_value != 0: - self.allocated[ret_value] = ('malloc', size, - self.dli_file_name, self.dli_sym_name, self.rel_ret_addr, - self.src_file_name, self.src_line_num) + self.allocated[ret_value] = (self.get_call_id('malloc'), size) self.update_max_allocated() def after_calloc(self, nmemb, size, ret_value, errno=None) -> None: self.num_alloc += 1 if ret_value != 0: - self.allocated[ret_value] = ('calloc', nmemb * size, - self.dli_file_name, self.dli_sym_name, self.rel_ret_addr, - self.src_file_name, self.src_line_num) + self.allocated[ret_value] = (self.get_call_id('calloc'), nmemb * size) self.update_max_allocated() def after_realloc(self, ptr, size, ret_value, errno=None) -> None: if ptr == 0: self.num_alloc += 1 if ret_value != 0: - self.allocated[ret_value] = ('realloc', size, - self.dli_file_name, self.dli_sym_name, self.rel_ret_addr, - self.src_file_name, self.src_line_num) + self.allocated[ret_value] = (self.get_call_id('realloc'), size) else: self.num_realloc += 1 if ret_value != 0 and ptr in self.allocated: v = self.allocated[ptr] del self.allocated[ptr] - self.allocated[ret_value] = (v[0], size, v[2], v[3], v[4], v[5], v[6]) + self.allocated[ret_value] = (v[0], size) self.update_max_allocated() def after_reallocarray(self, ptr, nmemb, size, ret_value, errno=None) -> None: if ptr == 0: self.num_alloc += 1 if ret_value != 0: - self.allocated[ret_value] = ('reallocarray', nmemb * size, - self.dli_file_name, self.dli_sym_name, self.rel_ret_addr, - self.src_file_name, self.src_line_num) + self.allocated[ret_value] = (self.get_call_id('reallocarray'), nmemb * size) else: self.num_realloc += 1 if ret_value != 0: v = self.allocated[ptr] del self.allocated[ptr] - self.allocated[ret_value] = (v[0], nmemb * size, v[2], v[3], v[4], v[5], v[6]) + self.allocated[ret_value] = (v[0], nmemb * size) self.update_max_allocated() def after_getaddrinfo(self, node, service, hints, res_ptr, ret_value, errno=None, res=None) -> None: self.num_alloc += 1 if ret_value[0] == 0 and res is not None: size = sum(48 + r['ai_addrlen'] for r in res[1]) - self.allocated[res[0]] = ('getaddrinfo', size, - self.dli_file_name, self.dli_sym_name, self.rel_ret_addr, - self.src_file_name, self.src_line_num) + self.allocated[res[0]] = (self.get_call_id('getaddrinfo'), size) self.update_max_allocated() def after_free(self, ptr) -> None: @@ -136,9 +125,7 @@ class MemoryAllocationParser(Parser): del self.allocated[ptr] else: self.num_free -= 1 - self.invalid_frees.append(('free', ptr, - self.dli_file_name, self.dli_sym_name, self.rel_ret_addr, - self.src_file_name, self.src_line_num)) + self.invalid_frees.append((self.get_call_id('free'), ptr)) def after_freeaddrinfo(self, res: Pointer) -> None: self.num_free += 1 @@ -147,9 +134,7 @@ class MemoryAllocationParser(Parser): del self.allocated[res] else: self.num_free -= 1 - self.invalid_frees.append(('freeaddrinfo', res, - self.dli_file_name, self.dli_sym_name, self.rel_ret_addr, - self.src_file_name, self.src_line_num)) + self.invalid_frees.append((self.get_call_id('freeaddrinfo'), res)) class MemoryAllocationTester(MemoryAllocationParser, Handler): @@ -164,37 +149,32 @@ class InterruptedCheckParser(Parser): for fn, errors in FUNCTION_ERRORS.items() if 'EINTR' in errors or fn in ('sem_post',) } + context: dict counter: int = 0 - last_func_name: Optional[str] = None - last_ret_addr: Optional[int] = None - tested_functions: dict[tuple[str, int], str] + last_call: Optional[FunctionCallId] = None + tested_function_calls: dict[FunctionCallId, str] @property def while_testing(self) -> bool: return self.counter % self.cycles != 0 def before(self) -> None: - self.tested_functions = {} + self.tested_function_calls = {} def after(self) -> None: if self.while_testing: self.error() - for (name, ret_addr), status in self.tested_functions.items(): - if status == 'passed': - print(f'\x1B[32;1m{name} (0x{ret_addr:x}) -> {status}\x1B[0m', file=sys.stderr) - else: - print(f'\x1B[31;1m{name} (0x{ret_addr:x}) -> {status}\x1B[0m', file=sys.stderr) + self.context['results'] = self.tested_function_calls def error(self): - print(f'Error: Return value and errno EINTR not handled correctly in {self.last_func_name} (return address 0x{self.last_ret_addr:x})', file=sys.stderr) - self.tested_functions[(self.last_func_name, self.last_ret_addr)] = 'failed' + self.tested_function_calls[self.last_call] = 'failed' self.counter = 0 - self.last_func_name = None - self.last_ret_addr = None + self.last_call = None def before_fallback(self, func_name: str, *args) -> str: - if self.while_testing and (self.last_func_name != func_name or self.last_ret_addr != self.ret_addr): + call = self.get_call_id(func_name) + if self.while_testing and self.last_call != call: self.error() return 'ok' elif func_name not in self.functions: @@ -203,14 +183,12 @@ class InterruptedCheckParser(Parser): return self.functions[func_name][1] self.counter += 1 if self.while_testing: - self.last_ret_addr = self.ret_addr - self.last_func_name = func_name - self.tested_functions[(self.last_func_name, self.last_ret_addr)] = 'running' + self.last_call = call + self.tested_function_calls[self.last_call] = 'running' return self.functions[func_name][0] else: - self.tested_functions[(self.last_func_name, self.last_ret_addr)] = 'passed' - self.last_ret_addr = None - self.last_func_name = None + self.tested_function_calls[self.last_call] = 'passed' + self.last_call = None return self.functions[func_name][1] @@ -218,91 +196,97 @@ class InterruptedCheckTester(InterruptedCheckParser, Handler): pass -def get_return_value_check_tester() -> tuple[dict, type[Parser]]: - ctx = { - 'state': 'init', - 'call_sequence': [], - 'next': None, - 'last_error': None, - 'results': {}, +class ReturnValueCheckParser(Parser): + functions: dict[str, list[str]] = { + fn: [e for e in errors if e not in SKIP_ERRORS] + for fn, errors in FUNCTION_ERRORS.items() + if len(set(errors) - set(SKIP_ERRORS)) > 0 } + context: dict + num: int = 0 - class ReturnValueCheckTester(Parser): - functions: dict[str, list[str]] = { - fn: [e for e in errors if e not in SKIP_ERRORS] - for fn, errors in FUNCTION_ERRORS.items() - if len(set(errors) - set(SKIP_ERRORS)) > 0 - } - context: dict - num: int = 0 + @property + def call_sequence(self) -> list[str]: + return self.context['call_sequence'] - @property - def call_sequence(self) -> list[str]: - return self.context['call_sequence'] + @property + def is_init(self) -> bool: + return self.context['state'] == 'init' - @property - def is_init(self) -> bool: - return self.context['state'] == 'init' + @property + def is_testing(self) -> bool: + return self.context['state'] == 'testing' - @property - def is_testing(self) -> bool: - return self.context['state'] == 'testing' + @property + def is_waiting(self) -> bool: + return self.context['state'] == 'waiting' - @property - def is_waiting(self) -> bool: - return self.context['state'] == 'waiting' + @property + def is_failed(self) -> bool: + return self.context['state'] == 'failed' - @property - def is_failed(self) -> bool: - return self.context['state'] == 'failed' + def next(self, last: Optional[tuple[int, Optional[str]]]) -> tuple[int, Optional[str]]: + if last is None: + if len(self.call_sequence) == 0: + return 0, None + err = self.functions.get(self.call_sequence[0], [None])[0] + return (0, err) if err else self.next((0, err)) + last_fn, last_er = last + if last_er is None: + if last_fn + 1 >= len(self.call_sequence): + return last_fn + 1, None + err = self.functions.get(self.call_sequence[last_fn + 1], [None])[0] + return (last_fn + 1, err) if err else self.next((last_fn + 1, err)) + errors = self.functions.get(self.call_sequence[last_fn], []) + i = errors.index(last_er) + return (last_fn, errors[i + 1]) if i + 1 < len(errors) else self.next((last_fn, None)) - def next(self, last: Optional[tuple[int, Optional[str]]]) -> tuple[int, Optional[str]]: - if last is None: - if len(self.call_sequence) == 0: - return 0, None - err = self.functions.get(self.call_sequence[0], [None])[0] - return (0, err) if err else self.next((0, err)) - last_fn, last_er = last - if last_er is None: - if last_fn + 1 >= len(self.call_sequence): - return last_fn + 1, None - err = self.functions.get(self.call_sequence[last_fn + 1], [None])[0] - return (last_fn + 1, err) if err else self.next((last_fn + 1, err)) - errors = self.functions.get(self.call_sequence[last_fn], []) - i = errors.index(last_er) - return (last_fn, errors[i + 1]) if i + 1 < len(errors) else self.next((last_fn, None)) + def before(self) -> None: + if 'state' not in self.context: + self.context['state'] = 'init' + self.context['state'] = 'init' + self.context['call_sequence'] = [] + self.context['next'] = None + self.context['last_error'] = None + self.context['results'] = {} + if self.is_waiting or self.is_failed: + self.context['state'] = 'testing' - def before(self) -> None: - self.context = ctx - if self.is_waiting or self.is_failed: - self.context['state'] = 'testing' + def after(self) -> None: + if self.is_init: + self.context['state'] = 'testing' + self.context['next'] = self.next(None) + elif self.is_testing and self.context['next'][0] > len(self.context['call_sequence']): + self.context['state'] = 'finished' + elif self.is_waiting: + self.context['results'][(self.num - 1, self.context['last_error'])] = 'passed' - def after(self) -> None: - if self.is_init: - self.context['state'] = 'testing' - self.context['next'] = self.next(None) - elif self.is_testing and self.context['next'][0] > len(self.context['call_sequence']): - self.context['state'] = 'finished' - elif self.is_waiting: - self.context['results'][(self.num - 1, self.context['last_error'])] = 'passed' + def before_fallback(self, func_name: str, *args) -> str: + if self.is_init: + self.call_sequence.append(func_name) + return 'ok' + elif self.is_waiting or self.is_failed: + print(func_name) + if func_name not in ('malloc', 'free', 'close', 'exit'): + self.context['state'] = 'failed' + self.context['results'][(self.num - 1, self.context['last_error'])] = 'failed' + return 'ok' + self.num += 1 + nxt = self.context['next'] + if self.num - 1 != nxt[0] or nxt[1] is None: + return 'ok' + self.context['next'] = self.next(nxt) + self.context['state'] = 'waiting' if self.context['next'][1] else 'finished' + self.context['last_error'] = nxt[1] + return 'fail ' + nxt[1] - def before_fallback(self, func_name: str, *args) -> str: - if self.is_init: - self.call_sequence.append(func_name) - return 'ok' - elif self.is_waiting or self.is_failed: - print(func_name) - if func_name not in ('malloc', 'free', 'close', 'exit'): - self.context['state'] = 'failed' - self.context['results'][(self.num - 1, self.context['last_error'])] = 'failed' - return 'ok' - self.num += 1 - nxt = self.context['next'] - if self.num - 1 != nxt[0] or nxt[1] is None: - return 'ok' - self.context['next'] = self.next(nxt) - self.context['state'] = 'waiting' if self.context['next'][1] else 'finished' - self.context['last_error'] = nxt[1] - return 'fail ' + nxt[1] - return ctx, ReturnValueCheckTester +class ReturnValueCheckTester(ReturnValueCheckParser, Handler): + pass + + +def init_with_ctx(parser: type[Parser]) -> tuple[dict, type[Parser]]: + ctx = {} + class ContextParser(parser): + context: dict = ctx + return ctx, ContextParser diff --git a/proj/server/src/test-interrupts b/proj/server/src/test-interrupts index a4342e0..c21afe4 100755 --- a/proj/server/src/test-interrupts +++ b/proj/server/src/test-interrupts @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- import os +import sys import argparse import threading import subprocess @@ -10,8 +11,27 @@ import intercept import intercept.standard -def socket_thread(socket: str) -> None: - intercept.intercept(socket, intercept.standard.InterruptedCheckTester) +def neutral(text: str) -> None: + print(text, file=sys.stderr) + +def color(text: str, col: str) -> None: + print(('\x1B[' + col + 'm' if sys.stderr.isatty() else '') + + text + + ('\x1B[0m' if sys.stderr.isatty() else ''), + file=sys.stderr) + +def bold(text: str) -> None: + color(text, '1') + +def red(text: str) -> None: + color(text, '31;1') + +def green(text: str) -> None: + color(text, '32;1') + + +def socket_thread(socket: str, handler: type[intercept.Handler]) -> None: + intercept.intercept(socket, handler) def main() -> None: @@ -20,10 +40,11 @@ def main() -> None: if len(extra) > 0 and extra[0] == '--': extra.pop(0) if len(extra) == 0: - parser.error('command expected after arguments or \'--\'') + parser.error("command expected after arguments or '--'") socket_name = f'/tmp/intercept.interrupts.{os.getpid()}.sock' - t1 = threading.Thread(target=socket_thread, args=(socket_name,)) + ctx, handler = intercept.standard.init_with_ctx(intercept.standard.InterruptedCheckTester) + t1 = threading.Thread(target=socket_thread, args=(socket_name, handler)) t1.daemon = True t1.start() subprocess.run(extra, env={ @@ -34,6 +55,14 @@ def main() -> None: 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), }) + bold(':: REPORT ::') + for call, status in ctx['results'].items(): + if status == 'passed': + green(f':: TEST :: INTERRUPT HANDLING :: {call} :: PASSED') + else: + red(f':: TEST :: INTERRUPT HANDLING :: {call} :: FAILED') + + if __name__ == '__main__': main() diff --git a/proj/server/src/test-memory b/proj/server/src/test-memory index 7033a19..b744957 100755 --- a/proj/server/src/test-memory +++ b/proj/server/src/test-memory @@ -58,14 +58,8 @@ def main() -> None: if len(parser.allocated) > 0: red(':: TEST :: MEMORY LEAKS :: FAILED ::') red(":: Not free'd:") - for ptr, (func, size, fname, sname, ret, src, line) in parser.allocated.items(): - fname = fname.split('/')[-1] - pos = [func, f'{fname}+0x{ret:x}'] - if sname: - pos.append(sname) - if src and line: - pos.append(f'{src}:{line}') - red(f':: 0x{ptr:x}: {size:>6} bytes ({", ".join(pos)})') + for ptr, (call, size) in parser.allocated.items(): + red(f':: 0x{ptr:x}: {size:>6} bytes ({call})') else: green(':: TEST :: MEMORY LEAKS :: PASSED ::') green(":: All allocated memory blocks were free'd!") @@ -73,14 +67,8 @@ def main() -> None: if len(parser.invalid_frees) > 0: red(':: TEST :: INVALID FREES :: FAILED ::') red(':: Invalid/double frees:') - for (ptr, func, fname, sname, ret, src, line) in parser.invalid_frees: - fname = fname.split('/')[-1] - pos = [f'{fname}+0x{ret:x}'] - if sname: - pos.append(sname) - if src and line: - pos.append(f'{src}:{line}') - red(f':: {func}: 0x{ptr:x} ({", ".join(pos)})') + for (call, ptr) in parser.invalid_frees: + red(f':: {call.func_name}: 0x{ptr:x} ({call.discriminator})') else: green(':: TEST :: INVALID FREES :: PASSED ::') green(':: No invalid/double frees occured!') diff --git a/proj/server/src/test-return-values b/proj/server/src/test-return-values index b05e038..8ec442f 100755 --- a/proj/server/src/test-return-values +++ b/proj/server/src/test-return-values @@ -11,6 +11,25 @@ import intercept import intercept.standard +def neutral(text: str) -> None: + print(text, file=sys.stderr) + +def color(text: str, col: str) -> None: + print(('\x1B[' + col + 'm' if sys.stderr.isatty() else '') + + text + + ('\x1B[0m' if sys.stderr.isatty() else ''), + file=sys.stderr) + +def bold(text: str) -> None: + color(text, '1') + +def red(text: str) -> None: + color(text, '31;1') + +def green(text: str) -> None: + color(text, '32;1') + + def socket_thread(socket: str, handler: type[intercept.Handler]) -> None: intercept.intercept(socket, handler) @@ -22,16 +41,16 @@ def main() -> None: if len(extra) > 0 and extra[0] == '--': extra.pop(0) if len(extra) == 0: - parser.error('command expected after arguments or \'--\'') + parser.error("command expected after arguments or '--'") stdin = open(args.stdin) if args.stdin else None socket_name = f'/tmp/intercept.return-values.{os.getpid()}.sock' - ctx, handler = intercept.standard.get_return_value_check_tester() + ctx, handler = intercept.standard.init_with_ctx(intercept.standard.ReturnValueCheckTester) t1 = threading.Thread(target=socket_thread, args=(socket_name, handler)) t1.daemon = True t1.start() - while ctx['state'] != 'finished': + while ctx.get('state', None) != 'finished': if stdin: stdin.seek(0, 0) subprocess.run(extra, stdin=stdin, env={ @@ -41,6 +60,10 @@ def main() -> None: 'INTERCEPT_FUNCTIONS': '*', 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), }) + + bold(':: REPORT ::') + neutral('::') + bold(':: SUMMARY ::') for i, name in enumerate(ctx['call_sequence']): errors = [r[1] for r in ctx['results'] if r[0] == i] results = [ctx['results'][(i, e)] for e in errors] @@ -52,10 +75,10 @@ def main() -> None: if res == '-': esc, unesc = '\x1B[90m', '\x1B[0m' elif res == 'failed': - res = '\x1B[31;1mfailed\x1B[0m' + res = '\x1B[31;1m' 'FAILED' '\x1B[0m' elif res == 'passed': - res = '\x1B[32;1mpassed\x1B[0m' - print(f'{esc}{i:3} {name:16} {res}{unesc}{errs}') + res = '\x1B[32;1m' 'PASSED' '\x1B[0m' + neutral(f':: {esc}{i:3} {name:16} {res}{unesc}{errs}', ) if __name__ == '__main__':