diff --git a/proj/server/src/intercept.py b/proj/server/src/intercept.py index 6124df6..8b51082 100755 --- a/proj/server/src/intercept.py +++ b/proj/server/src/intercept.py @@ -19,13 +19,13 @@ class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer): class Handler(StreamRequestHandler): pid: int - stack: list[tuple[str, tuple]] + stack: list[tuple[int, str, tuple]] + ret_addr: int - def before(self) -> None: - pass - - def after(self) -> None: - pass + def before(self) -> None: pass + def after(self) -> None: pass + def before_fallback(self, func_name: str, *args) -> str: pass + def after_fallback(self, func_name: str, *args, **kwargs) -> None: pass def handle(self): first = self.rfile.readline() @@ -93,7 +93,7 @@ class Handler(StreamRequestHandler): idx = len(m.group(0)) if a.startswith('0x'): val = int(a[2:], 16) - elif a.startswith('0'): + elif a.startswith('0') and len(a) > 1: val = int(a[1:], 8) else: val = int(a, 10) @@ -157,68 +157,131 @@ class Handler(StreamRequestHandler): call = data.decode('utf-8') print(f'[{self.pid}] {call}') func_name = call[:call.find('(')] - args, _ = Handler.parse_args(call[call.find('(') + 1:-1]) - self.stack.append((func_name, args)) + self.ret_addr = int(call[call.rfind(':') + 1:], 0) + args, _ = Handler.parse_args(call[call.find('(') + 1:call.rfind(':') - 1]) + self.stack.append((self.ret_addr, func_name, args)) try: func = getattr(self, f'before_{func_name}') - if callable(func): - command = func(*args) or 'ok' - else: - command = 'ok' + if not callable(func): + func = None except AttributeError: - command = 'ok' + func = None + try: + if func is None: + raise NotImplementedError() + command = func(*args) or self.before_fallback(func_name, *args) or 'ok' + except NotImplementedError: + command = self.before_fallback(func_name, *args) or 'ok' print(f'[{self.pid}] -> {command}') self.wfile.write(command.encode('utf-8') + b'\n') else: ret = data.decode('utf-8') - ret_value, _ = Handler.parse_arg(ret[7:]) - func_name, args = self.stack.pop() + ret_value, _ = Handler.parse_arg(ret[7:].split(';')[0]) + self.ret_addr, func_name, args = self.stack.pop() try: func = getattr(self, f'after_{func_name}') - if callable(func): - if ret_value is None: - func(*args) - else: - func(*args, ret_value) + if not callable(func): + func = None except AttributeError: - pass + func = None + try: + if func is None: + raise NotImplementedError() + if ret_value is None: + func(*args) + else: + func(*args, ret_value) + except NotImplementedError: + if ret_value is None: + self.after_fallback(func_name, *args) + else: + self.after_fallback(func_name, *args, ret_value) print(f'[{self.pid}] -> {ret}') - def before_malloc(self, size: int) -> str: pass + def before_malloc(self, size: int) -> str: + raise NotImplementedError() def after_malloc(self, size: int, - ret_value: int, errno: str = None) -> None: pass - def before_calloc(self, nmemb: int, size: int) -> str: pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_calloc(self, nmemb: int, size: int) -> str: + raise NotImplementedError() def after_calloc(self, nmemb: int, size: int, - ret_value: int, errno: str = None) -> None: pass - def before_realloc(self, ptr: int, size: int) -> str: pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_realloc(self, ptr: int, size: int) -> str: + raise NotImplementedError() def after_realloc(self, ptr: int, size: int, - ret_value: int, errno: str = None) -> None: pass - def before_reallocarray(self, ptr: int, nmemb: int, size: int) -> str: pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_reallocarray(self, ptr: int, nmemb: int, size: int) -> str: + raise NotImplementedError() def after_reallocarray(self, ptr: int, nmemb: int, size: int, - ret_value: int, errno: str = None) -> None: pass - def before_free(self, ptr: int) -> str: pass - def after_free(self, ptr: int) -> None: pass - def before_getopt(self, argc: int, argv: Pointer[list[Pointer[bytes]]], optstring: Pointer[bytes]) -> str: pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_free(self, ptr: int) -> str: + raise NotImplementedError() + def after_free(self, ptr: int) -> None: + raise NotImplementedError() + def before_getopt(self, argc: int, argv: Pointer[list[Pointer[bytes]]], optstring: Pointer[bytes]) -> str: + raise NotImplementedError() def after_getopt(self, argc: int, argv: Pointer[list[Pointer[bytes]]], optstring: Pointer[bytes], - ret_value: int) -> None: pass - def before_close(self, fildes: int) -> str: pass + ret_value: int) -> None: + raise NotImplementedError() + def before_close(self, fildes: int) -> str: + raise NotImplementedError() def after_close(self, fildes: int, - ret_value: int, errno: str = None) -> None: pass - def before_sem_init(self, sem: int, pshared: int, value: int) -> str: pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_init(self, sem: int, pshared: int, value: int) -> str: + raise NotImplementedError() def after_sem_init(self, sem: int, pshared: int, value: int, - ret_value: int, errno: str = None) -> None: pass - def before_sem_open(self, name: str, oflag: Flags, mode: Optional[int], value: Optional[int]) -> str: pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_open(self, name: str, oflag: Flags, mode: Optional[int], value: Optional[int]) -> str: + raise NotImplementedError() def after_sem_open(self, name: str, oflag: Flags, mode: Optional[int], value: Optional[int], - ret_value: int, errno: str = None) -> None: pass - def before_sem_post(self, sem: int) -> str: pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_post(self, sem: int) -> str: + raise NotImplementedError() def after_sem_post(self, sem: int, - ret_value: int, errno: str = None) -> None: pass - def before_sem_wait(self, sem: int) -> str: pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_wait(self, sem: int) -> str: + raise NotImplementedError() def after_sem_wait(self, sem: int, - ret_value: int, errno: str = None) -> None: pass - def before_sem_timedwait(self, sem: int, abs_timeout: Pointer[StructTimeSpec]): pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_trywait(self, sem: int) -> str: + raise NotImplementedError() + def after_sem_trywait(self, sem: int, + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_timedwait(self, sem: int, abs_timeout: Pointer[StructTimeSpec]) -> str: + raise NotImplementedError() def after_sem_timedwait(self, sem: int, abs_timeout: Pointer[StructTimeSpec], - ret_value: int, errno: str = None): pass + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_getvalue(self, sem: int, value_ptr: int) -> str: + raise NotImplementedError() + def after_sem_getvalue(self, sem:int, value_ptr: int, + ret_value: int, errno: str = None, value: int = None) -> None: + raise NotImplementedError() + def before_sem_close(self, sem: int) -> str: + raise NotImplementedError() + def after_sem_close(self, sem: int, + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_unlink(self, name: Pointer[bytes]) -> str: + raise NotImplementedError() + def after_sem_unlink(self, name: Pointer[bytes], + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() + def before_sem_destroy(self, sem: int) -> str: + raise NotImplementedError() + def after_sem_destroy(self, sem: int, + ret_value: int, errno: str = None) -> None: + raise NotImplementedError() class MemoryAllocationTester(Handler): @@ -284,6 +347,82 @@ class MemoryAllocationTester(Handler): del self.allocated[ptr] +class ReturnValueCheckTester(Handler): + pass + + +class InterruptedCheckTester(Handler): + cycles: int = 50 + counter: int = 0 + last_func_name: Optional[str] = None + last_ret_addr: Optional[int] = None + tested_functions: dict[tuple[str, int], str] + + @property + def while_testing(self) -> bool: + return self.counter % self.cycles != 0 + + def before(self) -> None: + self.tested_functions = {} + + def after(self) -> None: + if self.while_testing: + self.error() + for (name, ret_addr), status in self.tested_functions.items(): + print(f'{name} (0x{ret_addr:x}) -> {status}') + + def error(self): + print(f'Error: Return value and errno EINTR not handled correctly in {self.last_func_name} (return address 0x{self.last_ret_addr:x})') + self.tested_functions[(self.last_func_name, self.last_ret_addr)] = 'failed' + self.counter = 0 + self.last_func_name = None + self.last_ret_addr = None + + def after_fallback(self, func_name: str, *args, **kwargs) -> None: + if self.while_testing and self.last_func_name != func_name: + self.error() + + def before_sem_wait(self, sem: int) -> str: + if self.last_ret_addr and self.last_ret_addr != self.ret_addr: + self.error() + self.counter += 1 + if self.while_testing: + self.last_ret_addr = self.ret_addr + self.last_func_name = 'sem_wait' + self.tested_functions[(self.last_func_name, self.last_ret_addr)] = 'running' + return 'fail EINTR' + else: + self.tested_functions[(self.last_func_name, self.last_ret_addr)] = 'passed' + self.last_ret_addr = None + self.last_func_name = None + return 'return 0' + + def before_sem_trywait(self, sem: int) -> str: + self.counter += 1 + if self.while_testing: + self.last_ret_addr = self.ret_addr + self.last_func_name = 'sem_trywait' + return 'fail EINTR' + else: + self.last_ret_addr = None + self.last_func_name = None + return 'return 0' + + def before_sem_timedwait(self, sem: int, abs_timeout: Pointer[StructTimeSpec]) -> str: + self.counter += 1 + if self.while_testing: + self.last_ret_addr = self.ret_addr + self.last_func_name = 'sem_timedwait' + return 'fail EINTR' + else: + self.last_ret_addr = None + self.last_func_name = None + return 'return 0' + + def before_sem_post(self, sem: int) -> str: + return 'return 0' + + def intercept(socket: str, handler: type[Handler]) -> None: try: with ThreadedUnixStreamServer(socket, handler) as server: diff --git a/proj/server/src/server.py b/proj/server/src/server.py index 260c825..221d111 100755 --- a/proj/server/src/server.py +++ b/proj/server/src/server.py @@ -39,7 +39,7 @@ def main() -> None: parser = argparse.ArgumentParser() parser.add_argument('socket', metavar='FILE') args = parser.parse_args() - intercept.intercept(args.socket, intercept.MemoryAllocationTester) + intercept.intercept(args.socket, intercept.InterruptedCheckTester) if __name__ == '__main__': diff --git a/proj/test1/src/intercept.c b/proj/test1/src/intercept.c index dfae993..cf55db6 100644 --- a/proj/test1/src/intercept.c +++ b/proj/test1/src/intercept.c @@ -499,7 +499,7 @@ static void init(void) { void *__sym(malloc)(size_t size) { init(); - msg("malloc(%li)", size); + msg("malloc(%li): %p", size, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -515,7 +515,7 @@ void *__sym(malloc)(size_t size) { void *__sym(calloc)(size_t nmemb, size_t size) { init(); - msg("calloc(%li, %li)", nmemb, size); + msg("calloc(%li, %li): %p", nmemb, size, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -531,7 +531,7 @@ void *__sym(calloc)(size_t nmemb, size_t size) { void *__sym(realloc)(void *ptr, size_t size) { init(); - msg("realloc(%p, %li)", ptr, size); + msg("realloc(%p, %li): %p", ptr, size, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -547,7 +547,7 @@ void *__sym(realloc)(void *ptr, size_t size) { void *__sym(reallocarray)(void *ptr, size_t nmemb, size_t size) { init(); - msg("reallocarray(%p, %li)", ptr, size); + msg("reallocarray(%p, %li): %p", ptr, size, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -563,7 +563,7 @@ void *__sym(reallocarray)(void *ptr, size_t nmemb, size_t size) { void __sym(free)(void *ptr) { init(); - msg("free(%p)", ptr); + msg("free(%p): %p", ptr, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -577,7 +577,7 @@ void __sym(free)(void *ptr) { int __sym(getopt)(const int argc, char *const argv[], const char *shortopts) { init(); - msg("getopt(%i, %as, %es)", argc, argv, argc, shortopts); + msg("getopt(%i, %as, %es): %p", argc, argv, argc, shortopts, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -595,7 +595,7 @@ int __sym(getopt)(const int argc, char *const argv[], const char *shortopts) { int __sym(close)(int fildes) { init(); - msg("close(%i)", fildes); + msg("close(%i): %p", fildes, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -611,7 +611,7 @@ int __sym(close)(int fildes) { int __sym(sem_init)(sem_t *sem, int pshared, unsigned int value) { init(); - msg("sem_init(%p, %i, %u)", sem, pshared, value); + msg("sem_init(%p, %i, %u): %p", sem, pshared, value, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -640,9 +640,9 @@ sem_t *__sym(sem_open)(const char *name, int oflag, ...) { mode_arg = va_arg(args, mode_t); value = va_arg(args, unsigned int); va_end(args); - msg("sem_open(%es, 0%o:|%s, 0%03o, %u)", name, oflag, ostr, mode_arg, value); + msg("sem_open(%es, 0%o:|%s, 0%03o, %u): %p", name, oflag, ostr, mode_arg, value, __builtin_return_address(0)); } else { - msg("sem_open(%es, 0%o:|%s)", name, oflag, ostr); + msg("sem_open(%es, 0%o:|%s): %p", name, oflag, ostr, __builtin_return_address(0)); } char overwrite[BUFFER_SIZE]; @@ -714,7 +714,7 @@ sem_t *__sym(sem_open)(const char *name, int oflag, ...) { int __sym(sem_post)(sem_t *sem) { init(); - msg("sem_post(%p)", sem); + msg("sem_post(%p): %p", sem, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -730,7 +730,7 @@ int __sym(sem_post)(sem_t *sem) { int __sym(sem_wait)(sem_t *sem) { init(); - msg("sem_wait(%p)", sem); + msg("sem_wait(%p): %p", sem, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -746,7 +746,7 @@ int __sym(sem_wait)(sem_t *sem) { int __sym(sem_trywait)(sem_t *sem) { init(); - msg("sem_trywait(%p)", sem); + msg("sem_trywait(%p): %p", sem, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -762,7 +762,7 @@ int __sym(sem_trywait)(sem_t *sem) { int __sym(sem_timedwait)(sem_t *restrict sem, const struct timespec *restrict abs_timeout) { init(); - msg("sem_timedwait(%p, %p:{tv_sec: %li, tv_nsec: %li})", sem, abs_timeout, abs_timeout->tv_sec, abs_timeout->tv_nsec); + msg("sem_timedwait(%p, %p:{tv_sec: %li, tv_nsec: %li}): %p", sem, abs_timeout, abs_timeout->tv_sec, abs_timeout->tv_nsec, __builtin_return_address(0)); struct timespec overwrite; if (mode >= 4) { char buf[BUFFER_SIZE]; @@ -807,7 +807,7 @@ int __sym(sem_timedwait)(sem_t *restrict sem, const struct timespec *restrict ab int __sym(sem_getvalue)(sem_t *restrict sem, int *restrict value) { init(); - msg("sem_getvalue(%p, %p)", sem, value); + msg("sem_getvalue(%p, %p): %p", sem, value, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -823,7 +823,7 @@ int __sym(sem_getvalue)(sem_t *restrict sem, int *restrict value) { int __sym(sem_close)(sem_t *sem) { init(); - msg("sem_close(%p)", sem); + msg("sem_close(%p): %p", sem, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); @@ -839,7 +839,7 @@ int __sym(sem_close)(sem_t *sem) { int __sym(sem_unlink)(const char *name) { init(); - msg("sem_unlink(%es)", name); + msg("sem_unlink(%es): %p", name, __builtin_return_address(0)); char overwrite[BUFFER_SIZE]; if (mode >= 4) { char buf[BUFFER_SIZE]; @@ -870,7 +870,7 @@ int __sym(sem_unlink)(const char *name) { int __sym(sem_destroy)(sem_t *sem) { init(); - msg("sem_destroy(%p)", sem); + msg("sem_destroy(%p): %p", sem, __builtin_return_address(0)); if (mode >= 4) { char buf[BUFFER_SIZE]; rcv(buf, sizeof(buf)); diff --git a/proj/test1/src/main.c b/proj/test1/src/main.c index 5c6395e..c4e15ba 100644 --- a/proj/test1/src/main.c +++ b/proj/test1/src/main.c @@ -5,6 +5,7 @@ #include #include #include +#include static void usage(const char *prog_name) { fprintf(stderr, "usage: %s [-a]\n", prog_name); @@ -22,6 +23,29 @@ void do_something(void) { free(mem); } +void do_sem(void) { + + sem_t sem; + sem_init(&sem, 0, 1); + + sem_post(&sem); + + while (sem_wait(&sem) == -1) { + if (errno != EINTR) { + // TODO fail + } + } + + do_something(); + + sem_wait(&sem); + + do_something(); + + sem_destroy(&sem); + +} + int main(const int argc, char *const argv[]) { for (int ch; (ch = getopt(argc, argv, "a")) != -1;) { switch (ch) { @@ -37,4 +61,5 @@ int main(const int argc, char *const argv[]) { do_something(); do_something(); do_something(); + do_sem(); }