#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import argparse
import threading
import subprocess

import intercept
import intercept.standard
from intercept.standard import CallTreeNode


def neutral(text: str) -> None:
    """Print *text* on its own line to stderr, with no colour or styling."""
    stream = sys.stderr
    print(text, file=stream)

def color(text: str, col: str) -> None:
    """Print *text* to stderr wrapped in the ANSI SGR sequence ``ESC[<col>m``.

    The opening sequence and the trailing reset (``ESC[0m``) are only added
    when stderr is a terminal; otherwise *text* is printed unchanged.
    """
    on_tty = sys.stderr.isatty()
    start = '\x1B[' + col + 'm' if on_tty else ''
    reset = '\x1B[0m' if on_tty else ''
    print(start + text + reset, file=sys.stderr)

def bold(text: str) -> None:
    """Print *text* to stderr in bold (SGR attribute 1) when stderr is a terminal."""
    color(text, col='1')

def test_result(test: list[str], result: 'bool | None') -> None:
    """Print a bold, dot-padded ``:: TEST :: ...`` line ending in a verdict.

    Args:
        test: name components of the test, joined with `` :: ``.
        result: ``True`` -> ``PASSED`` (green), ``False`` -> ``FAILED`` (red),
            ``None`` -> ``SKIPPED`` (no colour). Colours are only emitted when
            stderr is a terminal.
    """
    text = ':: TEST :: ' + ' :: '.join(test)
    # No escape at all for SKIPPED: an empty SGR code ("\x1B[m") would reset
    # every attribute, including the bold applied by bold().
    colour_code = {True: '32', False: '31', None: ''}[result]
    esc = '\x1B[' + colour_code + 'm' if colour_code and sys.stderr.isatty() else ''
    verdict = {True: 'PASSED', False: 'FAILED', None: 'SKIPPED'}[result]
    # Pad with dots to a fixed column; a negative count yields no dots.
    bold(f'{text} {"." * (80 - len(text))} {esc}{verdict}')


def socket_thread(socket: str, handler: type[intercept.Handler]) -> None:
    """Thread entry point: serve intercept requests on *socket* using *handler*.

    Simply delegates to ``intercept.intercept`` and returns when it does;
    main() runs it as the target of a daemon thread.
    """
    serve = intercept.intercept
    serve(socket, handler)


# A Ctrl-C only stops the run loop once more than this many runs have been
# started; earlier interrupts merely abort the current run.
_INTERRUPT_AFTER_RUNS = 64

# Functions an error-handling path may call and still count as having handled
# the error (cleaning up and/or exiting).
_ALLOWED_CLEANUP_FUNCTIONS = frozenset({'malloc', 'free', 'freeaddrinfo', 'close', 'exit'})


def _iter_tree(node: CallTreeNode, ret=None, layer: int = 0):
    """Yield ``(call, data, ret, layer, children)`` for *node* and its descendants, pre-order.

    ``ret`` is the outcome label of the edge leading into a node (``None`` for
    the start node). Non-``'ok'`` children are visited first, one layer deeper;
    the ``'ok'`` child is visited last on the same layer, so layer 0 is the
    chain of calls that all took the ``'ok'`` outcome.

    Uses an explicit stack instead of recursive ``yield from`` so that long
    call chains cannot exceed the interpreter's recursion limit.
    """
    stack = [(node, ret, layer)]
    while stack:
        current, current_ret, current_layer = stack.pop()
        children = current.children
        yield current.call, current.data, current_ret, current_layer, children
        # Push in reverse visiting order: the 'ok' continuation must come out last,
        # and the first non-'ok' child must come out first.
        if 'ok' in children:
            stack.append((children['ok'], 'ok', current_layer))
        for child_ret, child in reversed(list(children.items())):
            if child_ret != 'ok':
                stack.append((child, child_ret, current_layer + 1))


def _iter_tree_ok(node: CallTreeNode):
    """Yield the calls along the all-``'ok'`` chain starting at *node* (inclusive)."""
    current = node
    while True:
        yield current.call
        if 'ok' not in current.children:
            return
        current = current.children['ok']


def _run_until_finished(command: list[str], stdin, socket_name: str, ctx) -> None:
    """Re-run *command* under the preloaded interceptor until ``ctx['state']`` is ``'finished'``.

    Args:
        command: argv of the program under test.
        stdin: open file fed to every run (rewound first), or ``None``.
        socket_name: UNIX socket path the interceptor reports to.
        ctx: shared state from ``init_with_ctx``; presumably updated by the
            handler serving *socket_name* on the background thread.
    """
    # NOTE(review): ``env`` *replaces* the whole environment, so the child gets
    # no PATH/HOME/etc. -- confirm this isolation is intended. LD_PRELOAD is
    # resolved relative to the current working directory.
    env = {
        'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so',
        'INTERCEPT': 'unix:' + socket_name,
        'INTERCEPT_VERBOSE': '1',
        'INTERCEPT_FUNCTIONS': '*',
        'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']),
    }
    runs = 0
    while ctx.get('state', None) != 'finished':
        runs += 1
        if stdin is not None:
            stdin.seek(0, 0)  # every run must see the complete input
        try:
            subprocess.run(command, stdin=stdin, env=env)
        except KeyboardInterrupt:
            # subprocess.run kills the child and re-raises; keep going unless
            # enough runs have been made already.
            if runs > _INTERRUPT_AFTER_RUNS:
                break


def _print_call_tree(root: CallTreeNode) -> None:
    """Print the call tree, indenting each error branch one level deeper."""
    bold(':: CALL TREE ::')
    for call, _, ret, layer, _ in _iter_tree(root):
        indent = '  ' * layer
        if ret not in (None, 'ok'):
            neutral(f'::{indent}{ret}')
        neutral(f'::{indent} {call}')
    neutral('::')


def _collect_paths(root: CallTreeNode) -> dict:
    """Map each call on the all-ok chain to ``{outcome: set of continuation tuples}``.

    A continuation is the tuple of calls made along the all-``'ok'`` chain
    after the call produced that outcome.
    """
    calls = {}
    for call, _, _, layer, children in _iter_tree(root):
        if layer != 0:
            continue
        entry = calls.setdefault(call, {})
        for ret, child in children.items():
            entry.setdefault(ret, set()).add(tuple(_iter_tree_ok(child)))
    return calls


def _report(root: CallTreeNode) -> None:
    """Print one ``Return value check`` verdict per call with several observed outcomes.

    A call fails when, for some non-``'ok'`` outcome, the program either went on
    exactly as after an ``'ok'`` outcome (the error was ignored) or called
    anything outside ``_ALLOWED_CLEANUP_FUNCTIONS`` afterwards.
    """
    bold(':: REPORT ::')
    for call, outcomes in _collect_paths(root).items():
        if len(outcomes) <= 1:
            continue  # a single outcome gives nothing to compare
        # The call may never have been observed succeeding: then there are no ok paths.
        ok_paths = outcomes.get('ok', set())
        failed = False
        for ret, paths in outcomes.items():
            if ret == 'ok':
                continue
            ignored = any(path in ok_paths for path in paths)
            uncleaned = any(
                c.func_name not in _ALLOWED_CLEANUP_FUNCTIONS for path in paths for c in path
            )
            if ignored or uncleaned:
                failed = True
                break
        test_result(['Return value check', str(call)], not failed)


def main() -> None:
    """Check that a program handles failing library calls correctly.

    Usage: ``<script> [-i FILE] [--] COMMAND [ARGS...]``. The command is run
    repeatedly with the interceptor preloaded (optionally with FILE as stdin)
    until the ReturnValueCheckTester reports it is finished. Then the observed
    call tree and a per-call report are printed to stderr.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--stdin')
    args, extra = parser.parse_known_args()
    if extra and extra[0] == '--':
        extra.pop(0)
    if not extra:
        parser.error("command expected after arguments or '--'")

    # Binary mode: the file is only handed to the child by descriptor, never decoded here.
    stdin = open(args.stdin, 'rb') if args.stdin else None
    try:
        socket_name = f'/tmp/intercept.return-values.{os.getpid()}.sock'
        ctx, handler = intercept.standard.init_with_ctx(intercept.standard.ReturnValueCheckTester)
        # Daemon thread: it must not keep the process alive once reporting is done.
        server = threading.Thread(target=socket_thread, args=(socket_name, handler), daemon=True)
        server.start()
        _run_until_finished(extra, stdin, socket_name, ctx)
    finally:
        if stdin is not None:
            stdin.close()

    root = ctx['call_tree']
    _print_call_tree(root)
    _report(root)




# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()
