#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import subprocess
import os
import sys

import intercept
import intercept.standard


def neutral(text: str) -> None:
    """Print *text* on its own line to standard error, without styling."""
    stream = sys.stderr
    print(text, file=stream)

def color(text: str, col: str) -> None:
    """Print *text* to standard error, styled with the SGR code *col*.

    The escape sequences (``ESC[<col>m`` before the text, ``ESC[0m`` after)
    are only emitted when standard error is a terminal; otherwise the text
    is printed unmodified.
    """
    if sys.stderr.isatty():
        prefix = '\x1B[' + col + 'm'
        suffix = '\x1B[0m'
    else:
        prefix = suffix = ''
    print(prefix + text + suffix, file=sys.stderr)

def bold(text: str) -> None:
    """Print *text* to standard error using SGR code ``1`` via :func:`color`."""
    color(text, col='1')

def test_result(test: list[str], result: 'bool | None') -> None:
    """Print one bold result line for a test to standard error.

    The line looks like ``:: TEST :: <a> :: <b> ..... PASSED``: the name
    components are joined with `` :: `` and the label is padded with dots up
    to 80 columns (no padding if the label is already longer).

    Args:
        test: Name components of the test.
        result: ``True`` prints ``PASSED`` (SGR 32), ``False`` prints
            ``FAILED`` (SGR 31) and ``None`` prints ``SKIPPED`` uncoloured.

    Raises:
        KeyError: If *result* is not one of ``True``, ``False`` or ``None``.
    """
    label = ':: TEST :: ' + ' :: '.join(test)
    sgr_code = {True: '32', False: '31', None: None}[result]
    verdict = {True: 'PASSED', False: 'FAILED', None: 'SKIPPED'}[result]
    # Emit no escape at all for the uncoloured case: an SGR sequence with an
    # empty parameter (ESC[m) means "reset" and would cancel the bold style
    # that bold() has just switched on for this line.
    if sgr_code is not None and sys.stderr.isatty():
        esc = '\x1B[' + sgr_code + 'm'
    else:
        esc = ''
    bold(f'{label} {"." * (80 - len(label))} {esc}{verdict}')


def _print_report(log_file: str) -> None:
    """Parse the intercept log at *log_file* and print the report to stderr."""
    with open(log_file, 'rb') as file:
        report = intercept.standard.MemoryAllocationParser(file)
        report.parse()
        bold(':: REPORT ::')
        neutral('::')
        if len(report.allocated) > 0:
            test_result(['Memory leaks'], False)
            neutral("::   Not free'd:")
            for ptr, (call, size) in report.allocated.items():
                neutral(f'::     0x{ptr:x}: {size:>6} bytes ({call})')
        else:
            test_result(['Memory leaks'], True)
            neutral("::   All allocated memory blocks were free'd!")
        neutral('::')
        if len(report.invalid_frees) > 0:
            test_result(['Invalid frees'], False)
            neutral('::   Invalid/double frees:')
            for (call, ptr) in report.invalid_frees:
                neutral(f'::     {call.func_name}: 0x{ptr:x} ({call.discriminator})')
        else:
            test_result(['Invalid frees'], True)
            neutral('::   No invalid/double frees occured!')
        neutral('::')
        neutral(f':: #allocs: {report.num_alloc}, #reallocs: {report.num_realloc}, #frees: {report.num_free}')
        neutral(f':: Max dynamically allocated: {report.max_allocated} bytes')


def main() -> None:
    """Run a command under the intercept library and report its memory usage.

    Command line: ``[-i/--stdin FILE] [--] COMMAND [ARGS...]``. The command is
    started with ``LD_PRELOAD`` pointing at ``intercept.so`` and the
    ``INTERCEPT*`` variables directing its log to a per-process file under
    ``/tmp``. Afterwards that log is parsed and a leak / invalid-free report
    is printed to standard error. The log file is always removed.

    If ``--stdin`` is given, that file is used as the command's standard
    input; otherwise this process's standard input is passed through.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--stdin')
    args, extra = parser.parse_known_args()
    if extra and extra[0] == '--':
        extra.pop(0)
    if not extra:
        parser.error("command expected after arguments or '--'")

    log_file = f'/tmp/intercept.memory.{os.getpid()}.log'
    stdin = open(args.stdin) if args.stdin else sys.stdin

    try:
        try:
            # NOTE(review): ``env`` replaces the whole environment, so the
            # child does not inherit PATH, HOME, etc. -- confirm intentional.
            # The LD_PRELOAD path is resolved against the current working
            # directory, not against this script's location.
            subprocess.run(extra, stdin=stdin, env={
                'LD_PRELOAD': os.getcwd() + '/../../intercept/intercept.so',
                'INTERCEPT': 'file:' + log_file,
                'INTERCEPT_VERBOSE': '1',
                'INTERCEPT_FUNCTIONS': ','.join(['malloc', 'calloc', 'realloc', 'reallocarray', 'free', 'getaddrinfo', 'freeaddrinfo', 'getline', 'getdelim']),
                'INTERCEPT_LIBRARIES': ','.join(['*', '-/lib*', '-/usr/lib*']),
            })
        except KeyboardInterrupt:
            # Ctrl-C only stops the traced command; the report is still
            # printed from whatever was logged so far.
            pass
        finally:
            # Close the redirected input file, but never our own stdin.
            if stdin is not sys.stdin:
                stdin.close()
        # Only reached when the command was actually started. If launching
        # it failed (e.g. command not found) that error propagates as-is
        # instead of being masked by a missing-log error from the report.
        _print_report(log_file)
    finally:
        # Remove the log even when parsing/reporting raised; it may not
        # exist at all if the command never started.
        try:
            os.remove(log_file)
        except FileNotFoundError:
            pass


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
