1
0

proj/test-return-values: Analyse call tree

This commit is contained in:
2025-05-07 10:33:46 +02:00
parent 2fff206f6a
commit 0c9e255433
5 changed files with 162 additions and 72 deletions

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations
from typing import Optional, TypedDict, NamedTuple, NotRequired, BinaryIO from typing import Optional, TypedDict, NamedTuple, NotRequired, BinaryIO
from socketserver import UnixStreamServer, StreamRequestHandler, ThreadingMixIn from socketserver import UnixStreamServer, StreamRequestHandler, ThreadingMixIn
import os import os
@@ -41,14 +42,23 @@ class FunctionCallId:
ret_addr: int ret_addr: int
obj_path: str obj_path: str
rel_ret_addr: int rel_ret_addr: int
sym_name: Optional[str] sym_name: Optional[str] = None
src_file_name: Optional[str] src_file_name: Optional[str] = None
src_line_num: Optional[int] src_line_num: Optional[int] = None
@property @property
def obj_name(self) -> str: def obj_name(self) -> str:
return self.obj_path.split('/')[-1] return self.obj_path.split('/')[-1]
@staticmethod
def for_exit() -> FunctionCallId:
call_id = FunctionCallId()
call_id.func_name = 'exit'
call_id.obj_path = 'sys'
call_id.ret_addr = 0
call_id.rel_ret_addr = 0
return call_id
@property @property
def discriminator(self) -> str: def discriminator(self) -> str:
discr = [f'{self.obj_name}+0x{self.rel_ret_addr:x}'] discr = [f'{self.obj_name}+0x{self.rel_ret_addr:x}']
@@ -69,6 +79,9 @@ class FunctionCallId:
def __str__(self) -> str: def __str__(self) -> str:
return self.func_name + ', ' + self.discriminator return self.func_name + ', ' + self.discriminator
def __repr__(self) -> str:
return f'<{self}>'
class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer): class ThreadedUnixStreamServer(ThreadingMixIn, UnixStreamServer):
pass pass

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import annotations
from intercept import * from intercept import *
@@ -45,6 +46,13 @@ FUNCTION_ERRORS: dict[str, list[str]] = {
'dup': ['EBADF', 'EMFILE'], 'dup': ['EBADF', 'EMFILE'],
'dup2': ['EBADF', 'EBUSY', 'EINTR', 'EMFILE'], 'dup2': ['EBADF', 'EBUSY', 'EINTR', 'EMFILE'],
'dup3': ['EBADF', 'EBUSY', 'EINTR', 'EMFILE'], 'dup3': ['EBADF', 'EBUSY', 'EINTR', 'EMFILE'],
'socket': ['EAFNOSUPPORT', 'EMFILE', 'ENFILE', 'ENOMEM', 'ENOBUFS', 'EACCES'],
'bind': ['EADDRINUSE', 'EADDRNOTAVAIL', 'EAFNOSUPPORT', 'EACCES'],
'listen': ['EACCES', 'ENOBUFS', 'EADDRINUSE'],
'accept': ['ECONNABORTED', 'EINTR', 'EPERM', 'EPROTO', 'ENOMEM'],
'connect': ['EACCES', 'EPERM', 'ECONNREFUSED', 'ENETUNREACH'],
'getaddrinfo': [],
'freeaddrinfo': [],
'send': ['EINTR'], 'send': ['EINTR'],
'recv': ['EINTR'], 'recv': ['EINTR'],
} }
@@ -220,6 +228,13 @@ class InterruptedCheckTester(InterruptedCheckParser, Handler):
pass pass
class CallTreeNode(NamedTuple):
call: FunctionCallId
depth: int
data: dict
children: dict[str, CallTreeNode]
class ReturnValueCheckParser(Parser): class ReturnValueCheckParser(Parser):
functions: dict[str, list[str]] = { functions: dict[str, list[str]] = {
fn: [e for e in errors if e not in SKIP_ERRORS] fn: [e for e in errors if e not in SKIP_ERRORS]
@@ -227,11 +242,23 @@ class ReturnValueCheckParser(Parser):
if len(set(errors) - set(SKIP_ERRORS)) > 0 if len(set(errors) - set(SKIP_ERRORS)) > 0
} }
context: dict context: dict
num: int = 0 call_depth: int = 0
@property @property
def call_sequence(self) -> list[str]: def call_tree(self) -> CallTreeNode:
return self.context['call_sequence'] return self.context['call_tree']
@property
def last_call(self) -> CallTreeNode:
return self.context['last_call']
@property
def last_return(self) -> str:
return self.context['last_return']
@property
def next(self) -> tuple[Optional[CallTreeNode], Optional[str]]:
return self.context['next']
@property @property
def is_init(self) -> bool: def is_init(self) -> bool:
@@ -249,60 +276,66 @@ class ReturnValueCheckParser(Parser):
def is_failed(self) -> bool: def is_failed(self) -> bool:
return self.context['state'] == 'failed' return self.context['state'] == 'failed'
def next(self, last: Optional[tuple[int, Optional[str]]]) -> tuple[int, Optional[str]]: def get_next(self, last: Optional[tuple[CallTreeNode, Optional[str]]]) -> tuple[Optional[CallTreeNode], Optional[str]]:
if last is None: if last is None:
if len(self.call_sequence) == 0: if self.call_tree is None:
return 0, None return None, None
err = self.functions.get(self.call_sequence[0], [None])[0] last_call, last_err = self.call_tree, None
return (0, err) if err else self.next((0, err)) else:
last_fn, last_er = last last_call, last_err = last
if last_er is None: if last_err is None:
if last_fn + 1 >= len(self.call_sequence): if len(last_call.children) == 0:
return last_fn + 1, None return None, None
err = self.functions.get(self.call_sequence[last_fn + 1], [None])[0] err = self.functions.get(last_call.children['ok'].call.func_name, [None])[0]
return (last_fn + 1, err) if err else self.next((last_fn + 1, err)) return (last_call.children['ok'], err) if err else self.get_next((last_call.children['ok'], err))
errors = self.functions.get(self.call_sequence[last_fn], []) errors = self.functions.get(last_call.call.func_name, [])
i = errors.index(last_er) i = errors.index(last_err)
return (last_fn, errors[i + 1]) if i + 1 < len(errors) else self.next((last_fn, None)) return (last_call, errors[i + 1]) if i + 1 < len(errors) else self.get_next((last_call, None))
def before(self) -> None: def before(self) -> None:
if 'state' not in self.context: if 'state' not in self.context:
self.context['state'] = 'init' self.context['state'] = 'init'
self.context['state'] = 'init' self.context['call_tree'] = None
self.context['call_sequence'] = []
self.context['next'] = None self.context['next'] = None
self.context['last_error'] = None
self.context['results'] = {}
if self.is_waiting or self.is_failed: if self.is_waiting or self.is_failed:
self.context['state'] = 'testing' self.context['state'] = 'testing'
self.context['last_call'] = None
self.context['last_return'] = None
def after(self) -> None: def after(self) -> None:
if self.is_init: if self.is_init:
self.context['state'] = 'testing' self.context['state'] = 'testing'
self.context['next'] = self.next(None) self.context['next'] = self.get_next(None)
elif self.is_testing and self.context['next'][0] > len(self.context['call_sequence']): elif self.is_testing and self.next[0] is None:
self.context['state'] = 'finished' self.context['state'] = 'finished'
elif self.is_waiting: if self.last_call and self.last_call.call.func_name != 'exit':
self.context['results'][(self.num - 1, self.context['last_error'])] = 'passed' self.last_call.children[self.last_return] = (CallTreeNode(FunctionCallId.for_exit(), self.last_call.depth + 1, {}, {}))
def before_fallback(self, func_name: str, *args) -> str: def before_fallback(self, func_name: str, *args) -> str:
self.call_depth += 1
call = self.get_call_id(func_name)
if self.last_call is None:
if self.call_tree is None:
self.context['call_tree'] = CallTreeNode(call, 0, {}, {})
self.context['last_call'] = self.call_tree
else:
if self.last_return not in self.last_call.children:
self.last_call.children[self.last_return] = CallTreeNode(call, self.last_call.depth + 1, {}, {})
self.context['last_call'] = self.last_call.children[self.last_return]
if self.is_init: if self.is_init:
self.call_sequence.append(func_name) self.context['last_return'] = 'ok'
return 'ok' return 'ok'
elif self.is_waiting or self.is_failed:
print(func_name) nxt = self.next
if func_name not in ('malloc', 'free', 'close', 'exit'): if nxt[0] is None or call != nxt[0].call or self.call_depth - 1 != nxt[0].depth or nxt[1] is None:
self.context['state'] = 'failed' self.context['last_return'] = 'ok'
self.context['results'][(self.num - 1, self.context['last_error'])] = 'failed'
return 'ok' return 'ok'
self.num += 1
nxt = self.context['next'] self.context['next'] = self.get_next(nxt)
if self.num - 1 != nxt[0] or nxt[1] is None:
return 'ok'
self.context['next'] = self.next(nxt)
self.context['state'] = 'waiting' if self.context['next'][1] else 'finished' self.context['state'] = 'waiting' if self.context['next'][1] else 'finished'
self.context['last_error'] = nxt[1] self.context['last_return'] = 'fail ' + nxt[1]
return 'fail ' + nxt[1] return self.context['last_return']
class ReturnValueCheckTester(ReturnValueCheckParser, Handler): class ReturnValueCheckTester(ReturnValueCheckParser, Handler):

View File

@@ -47,13 +47,16 @@ def main() -> None:
t1 = threading.Thread(target=socket_thread, args=(socket_name, handler)) t1 = threading.Thread(target=socket_thread, args=(socket_name, handler))
t1.daemon = True t1.daemon = True
t1.start() t1.start()
subprocess.run(extra, env={ try:
'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so', subprocess.run(extra, env={
'INTERCEPT': 'unix:' + socket_name, 'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so',
'INTERCEPT_VERBOSE': '1', 'INTERCEPT': 'unix:' + socket_name,
'INTERCEPT_FUNCTIONS': '*', 'INTERCEPT_VERBOSE': '1',
'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), 'INTERCEPT_FUNCTIONS': '*',
}) 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']),
})
except KeyboardInterrupt:
pass
bold(':: REPORT ::') bold(':: REPORT ::')
for call, status in ctx['results'].items(): for call, status in ctx['results'].items():

View File

@@ -49,6 +49,8 @@ def main() -> None:
'INTERCEPT_FUNCTIONS': ','.join(['malloc', 'calloc', 'realloc', 'reallocarray', 'free', 'getaddrinfo', 'freeaddrinfo', 'getline', 'getdelim']), 'INTERCEPT_FUNCTIONS': ','.join(['malloc', 'calloc', 'realloc', 'reallocarray', 'free', 'getaddrinfo', 'freeaddrinfo', 'getline', 'getdelim']),
'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']),
}) })
except KeyboardInterrupt:
pass
finally: finally:
with open(log_file, 'rb') as file: with open(log_file, 'rb') as file:
parser = intercept.standard.MemoryAllocationParser(file) parser = intercept.standard.MemoryAllocationParser(file)

View File

@@ -9,6 +9,7 @@ import subprocess
import intercept import intercept
import intercept.standard import intercept.standard
from intercept.standard import CallTreeNode
def neutral(text: str) -> None: def neutral(text: str) -> None:
@@ -53,32 +54,70 @@ def main() -> None:
while ctx.get('state', None) != 'finished': while ctx.get('state', None) != 'finished':
if stdin: if stdin:
stdin.seek(0, 0) stdin.seek(0, 0)
subprocess.run(extra, stdin=stdin, env={ try:
'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so', subprocess.run(extra, stdin=stdin, env={
'INTERCEPT': 'unix:' + socket_name, 'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so',
'INTERCEPT_VERBOSE': '1', 'INTERCEPT': 'unix:' + socket_name,
'INTERCEPT_FUNCTIONS': '*', 'INTERCEPT_VERBOSE': '1',
'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']), 'INTERCEPT_FUNCTIONS': '*',
}) 'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']),
})
except KeyboardInterrupt:
pass
def iter_tree(node: CallTreeNode, ret=None, layer=0):
yield node.call, node.data, ret, layer, node.children
for ret, child in node.children.items():
if ret != 'ok':
yield from iter_tree(child, ret, layer + 1)
if 'ok' in node.children:
yield from iter_tree(node.children['ok'], 'ok', layer)
def iter_tree_ok(node: CallTreeNode):
yield node.call
if 'ok' in node.children:
yield from iter_tree_ok(node.children['ok'])
bold(':: CALL TREE ::')
for call, _, r, l, c in iter_tree(ctx['call_tree']):
if r not in (None, 'ok'):
neutral(f'::{" " * l}{r}')
neutral(f'::{" " * l} {call}')
neutral('::')
bold(':: REPORT ::') bold(':: REPORT ::')
neutral('::') calls = {}
bold(':: SUMMARY ::') for call, _, r, l, c in iter_tree(ctx['call_tree']):
for i, name in enumerate(ctx['call_sequence']): if l != 0:
errors = [r[1] for r in ctx['results'] if r[0] == i] continue
results = [ctx['results'][(i, e)] for e in errors] if call not in calls:
res = '-' if len(results) == 0 else 'passed' if all(r == 'passed' for r in results) else 'failed' calls[call] = {}
esc, unesc = '', '' entry = calls[call]
errs = [e for e in errors if ctx['results'][(i, e)] == res] for ret, child in c.items():
errs = f' ({", ".join(errs)})' if len(errs) > 0 else '' if ret not in entry:
if sys.stdout.isatty(): entry[ret] = set()
if res == '-': entry[ret].add(tuple(c for c in iter_tree_ok(child)))
esc, unesc = '\x1B[90m', '\x1B[0m'
elif res == 'failed': allowed_cleanup_functions = ['malloc', 'free', 'freeaddrinfo', 'close', 'exit']
res = '\x1B[31;1m' 'FAILED' '\x1B[0m' for call, errors in calls.items():
elif res == 'passed': if len(errors) <= 1:
res = '\x1B[32;1m' 'PASSED' '\x1B[0m' continue
neutral(f':: {esc}{i:3} {name:16} {res}{unesc}{errs}', ) tested = []
failed = []
default_path = errors['ok']
for ret, paths in errors.items():
if ret == 'ok':
continue
errno = ret.split(' ')[1]
tested.append(errno)
if any(p == default_path for p in paths) or any(c.func_name not in allowed_cleanup_functions for p in paths for c in p):
failed.append(errno)
if len(failed) > 0:
red(f':: TEST :: RETURN VALUE CHECK :: {call} :: FAILED')
else:
green(f':: TEST :: RETURN VALUE CHECK :: {call} :: PASSED')
if __name__ == '__main__': if __name__ == '__main__':