Skip to content
compare-functions.py 8.39 KiB
Newer Older
#!/usr/bin/env python3

import argparse
import json

from binascii import unhexlify
from elftools.elf.elffile import ELFFile
from rangeset import RangeSet

# An empty RangeSet: RangeSet(0, 0) lies entirely within RangeSet(-1, 1), so
# their difference contains no addresses. It serves as the starting value
# when unions of ranges are accumulated below.
empty = (RangeSet(0,0)).difference(RangeSet(-1,1))

def strrange(cov):
    """Format the non-degenerate ranges of *cov* as ``"start-end"`` hex pairs.

    *cov* is any iterable whose items are indexable pairs ``(start, end)``
    (here, the items obtained by iterating a RangeSet). Pairs with
    ``start == end`` are skipped; the rest are joined with single spaces,
    e.g. ``"0x10-0x20 0x40-0x48"``.
    """
    return " ".join(hex(pair[0]) + "-" + hex(pair[1]) for pair in cov if pair[0] != pair[1])

def get_change(current, previous):
    """Return how much *current* differs from *previous*, in percent of *previous*.

    The result is always non-negative: ``|current - previous| / |previous| * 100``.

    Returns:
        0.0 when the two values are equal; 100.0 when *previous* is zero and
        *current* is not (the relative change is undefined there).
    """
    if current == previous:
        return 0.0
    if previous == 0:
        return 100.0
    # Divide by |previous| so a negative baseline cannot yield a negative
    # "percentage" that would slip under any upper-bound threshold.
    return (abs(current - previous) / abs(previous)) * 100.0

# Byte patterns of the NOP instructions to search for in loadable segments,
# keyed by the ELF header's ``e_machine`` name.
#
# The values are tuples rather than ``map`` objects: in Python 3 ``map`` is a
# one-shot iterator, and collect_nops() walks these patterns once per PT_LOAD
# segment, so an iterator would be exhausted after the first segment.
nops_encoding = {
    # NOTE(review): these bytes are the big-endian spelling of the ARM NOP
    # 0xe320f000; a little-endian ARM image stores it as 00 f0 20 e3.
    # Confirm the byte order of the analyzed binaries.
    "EM_ARM": (unhexlify("e320f000"),),
    "EM_X86_64": (unhexlify("90"),),
    # No MIPS pattern is searched. Presumably because the MIPS NOP is all
    # zero bytes, which would also match padding/data — TODO confirm.
    "EM_MIPS": (),
}

def section_by_name(elf, name):
    """Return the one section of *elf* whose name is *name*.

    Raises:
        ValueError: if the number of sections with that name is not exactly one.
            (An explicit check rather than ``assert``, which ``python -O`` strips.)
    """
    matches = [section for section in elf.iter_sections() if section.name == name]
    if len(matches) != 1:
        raise ValueError("expected exactly one %r section, found %d" % (name, len(matches)))
    return matches[0]

def collect_symbols_coverage(symtab, ignore):
    """Return ``(name, coverage)`` pairs for the symbols of *symtab*.

    Symbols whose ``st_value`` is 0 are skipped. For every other symbol the
    coverage is ``RangeSet(st_value, st_value + st_size)`` with the addresses
    in *ignore* removed. Pairs are returned in symbol-table order.
    """
    symbols = []
    for symbol in symtab.iter_symbols():
        start = symbol.entry.st_value
        if start == 0:
            continue
        # Named ``span`` (not ``range``) to avoid shadowing the builtin.
        span = RangeSet(start, start + symbol.entry.st_size)
        span = span.difference(ignore)
        symbols.append((str(symbol.name), span))
    return symbols

def strings_to_ranges(results, ignore):
    """Convert each function's textual coverage into a RangeSet, in place.

    Every element of *results* must have a ``"coverage"`` list of dicts with
    hex-string ``"start"`` and ``"end"`` keys. That list is replaced by the
    union of ``RangeSet(start, end)`` over its entries, minus *ignore*.
    Nothing is returned.
    """
    for entry in results:
        merged = empty
        for bounds in entry["coverage"]:
            lo = int(bounds["start"], 16)
            hi = int(bounds["end"], 16)
            merged = merged | RangeSet(lo, hi)
        entry["coverage"] = merged.difference(ignore)

def collect_nops(elf):
    """Return a RangeSet of the addresses occupied by NOP byte patterns.

    The bytes returned by ``segment.data()`` for every PT_LOAD segment are
    searched for each pattern in ``nops_encoding[elf.header.e_machine]``.
    Each occurrence (overlapping ones included) at offset ``off`` adds
    ``RangeSet(p_vaddr + off, p_vaddr + off + len(pattern))``.

    Raises:
        KeyError: if ``nops_encoding`` has no entry for the ELF's machine.
    """
    # Materialize the patterns once: they are scanned for in every segment,
    # and a one-shot iterator would be empty after the first segment.
    # Empty patterns are dropped; they would "match" at every offset.
    patterns = [nop for nop in nops_encoding[elf.header.e_machine] if nop]

    segments = [seg for seg in elf.iter_segments() if seg.header.p_type == "PT_LOAD"]
    nops = empty
    for segment in segments:
        base_addr = segment.header.p_vaddr
        content = segment.data()
        for nop in patterns:
            match = content.find(nop)
            while match != -1:
                # Rebind rather than ``|=`` so the shared module-level ``empty``
                # can never be modified in place, whatever RangeSet's in-place
                # operator does.
                nops = nops | RangeSet(base_addr + match, base_addr + match + len(nop))
                match = content.find(nop, match + 1)
    return nops

def collect_constant_pools(elf, symtab):
    """Return a RangeSet covering data embedded in code (e.g. ARM literal pools).

    Mapping points are gathered from *symtab*:

    * sized (``st_size != 0``), non-zero-valued ``STT_FUNC`` symbols, and
      local zero-size ``STT_NOTYPE`` symbols named ``$a...``, mark code;
    * local zero-size ``STT_NOTYPE`` symbols named ``$d...`` mark data, unless
      they sit at the start address of such a function.

    Within each section the points are visited in address order. Every span
    from a data point to the next code point is collected, and a data span
    still open at the end of the section runs to ``sh_addr + sh_size``.

    Raises:
        ValueError: for a local zero-size ``STT_NOTYPE`` symbol whose name
            starts with ``$`` but is neither ``$a...`` nor ``$d...``.
    """
    MAPPING_DATA = 0
    MAPPING_CODE = 1

    def is_sized_function(entry):
        # Shared predicate for the blacklist and for the code mapping points.
        return (entry.st_size != 0 and entry.st_value != 0
                and entry.st_info["type"] == "STT_FUNC")

    # Start addresses of sized functions; a "$d" symbol at one of these
    # addresses is ignored. Built in a first pass because a "$d" symbol may
    # precede its function in the symbol table.
    blacklist = set(symbol.entry.st_value for symbol in symtab.iter_symbols()
                    if is_sized_function(symbol.entry))

    # (mapping type, address, section index) for every mapping point.
    points = []
    for symbol in symtab.iter_symbols():
        entry = symbol.entry
        if is_sized_function(entry):
            points.append((MAPPING_CODE, entry.st_value, entry.st_shndx))

        if (entry.st_size == 0 and entry.st_info.bind == "STB_LOCAL"
                and entry.st_info.type == "STT_NOTYPE"
                and symbol.name.startswith("$")):
            # Slice rather than index so a bare "$" reaches the error below
            # instead of raising IndexError.
            tag = symbol.name[1:2]
            if tag == "a":
                points.append((MAPPING_CODE, entry.st_value, entry.st_shndx))
            elif tag == "d":
                if entry.st_value not in blacklist:
                    points.append((MAPPING_DATA, entry.st_value, entry.st_shndx))
            else:
                # Previously ``raise "..."``, which is itself a TypeError in
                # Python 3. NOTE(review): the ARM ELF ABI also defines "$t"
                # (Thumb code) and, for AArch64, "$x" (A64 code); both land
                # here. Confirm whether they should be treated as code.
                raise ValueError("Unexpected symbol: " + symbol.name)

    # Stable sort by address, then bucket per section in a single pass;
    # each bucket therefore stays in address order.
    points.sort(key=lambda point: point[1])
    by_section = {}
    for mapping_type, address, shndx in points:
        by_section.setdefault(shndx, []).append((mapping_type, address))

    constant_pools = empty
    for index in range(elf.num_sections()):
        # Leading code points leave last_type at MAPPING_CODE, so nothing is
        # collected before the first data point of the section.
        last_type = MAPPING_CODE
        last_start = 0
        for mapping_type, start in by_section.get(index, ()):
            if last_type == MAPPING_DATA and mapping_type == MAPPING_CODE:
                # Rebind rather than ``|=`` so the shared ``empty`` is never
                # modified in place.
                constant_pools = constant_pools | RangeSet(last_start, start)
            if mapping_type != last_type:
                last_type, last_start = mapping_type, start

        if last_type == MAPPING_DATA:
            header = elf.get_section(index).header
            constant_pools = constant_pools | RangeSet(last_start, header.sh_addr + header.sh_size)

    return constant_pools

def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description='Compare the function coverage computed by IDA and rev.ng '
                    'against the symbols of an ELF binary.')
    # NOTE(review): --only-start is accepted but never consulted by main();
    # functions are always matched on identical entry addresses. Confirm
    # whether it should be wired up or removed.
    parser.add_argument('--only-start', action='store_true',
                        help="Match only functions starting at the same exact address.")
    parser.add_argument('ida', metavar='IDAFILE', help='IDA created file')
    parser.add_argument('revng', metavar='REVNGFILE', help='rev.ng created file.')
    parser.add_argument('elf', metavar='ELF',
                        help='ELF binary analyzed by IDA and rev.ng.')
    parser.add_argument('outputjson', metavar='OUTPUTFILEJSON',
                        help='File where to write output info in JSON format.')
    parser.add_argument('outputmatching', metavar='OUTPUTFILEMATCH',
                        help='File where to write the functions that match.')
    return parser.parse_args()


def main():
    """Match functions across IDA, rev.ng and the ELF symbols and report coverage.

    A rev.ng function, an IDA function and a symbol form a match when both
    functions have the same ``entry_point_address`` and the rev.ng
    ``entry_point`` with every ``'bb.'`` removed equals the symbol name.
    Every match is recorded in OUTPUTFILEJSON. Matches whose rev.ng coverage
    measure differs from IDA's by at most 10% (per get_change) are also
    written to OUTPUTFILEMATCH. Finally the fraction of matching functions is
    printed, or "No matching function" when there were no matches at all.
    """
    args = _parse_args()

    with open(args.revng, "r") as revng_file, \
            open(args.ida, "r") as ida_file, \
            open(args.elf, "rb") as elf_file, \
            open(args.outputjson, "w") as output_json, \
            open(args.outputmatching, "w") as output_matching:

        # Load the ELF and compute the addresses excluded from every
        # coverage comparison: NOP patterns and constant pools.
        elf = ELFFile(elf_file)
        symtab = section_by_name(elf, ".symtab")
        nops = collect_nops(elf)
        constant_pools = collect_constant_pools(elf, symtab)
        to_ignore = nops | constant_pools

        symbols = collect_symbols_coverage(symtab, to_ignore)

        # Load the files produced by IDA and rev.ng and convert each
        # function's textual coverage into a RangeSet (in place).
        revng = json.load(revng_file)
        ida = json.load(ida_file)
        strings_to_ranges(revng, to_ignore)
        strings_to_ranges(ida, to_ignore)

        # Index IDA functions by entry address and symbols by name, so each
        # rev.ng function is matched with lookups instead of a scan over
        # both lists. Buckets keep input order, so matches come out in the
        # same order as a nested revng/ida/symbol scan would produce them.
        ida_by_address = {}
        for function_ida in ida:
            address = int(function_ida['entry_point_address'], 16)
            ida_by_address.setdefault(address, []).append(function_ida)
        symbols_by_name = {}
        for symbol_name, symbol_coverage in symbols:
            symbols_by_name.setdefault(symbol_name, []).append(symbol_coverage)

        # JSON output records and matching statistics.
        results = []
        total = 0
        matching = 0

        for function_revng in revng:
            candidates = ida_by_address.get(int(function_revng['entry_point_address'], 16))
            if not candidates:
                continue
            # NOTE(review): replace() removes every 'bb.' occurrence, not
            # only a leading prefix.
            name = function_revng['entry_point'].replace('bb.', '')
            for function_ida in candidates:
                for coverage_symbol in symbols_by_name.get(name, ()):
                    coverage_revng = function_revng["coverage"]
                    coverage_ida = function_ida["coverage"]
                    total += 1

                    # The verdict compares rev.ng against IDA only; the
                    # symbol coverage is reported but not used for matching.
                    percentage_change = get_change(coverage_revng.measure(), coverage_ida.measure())
                    match = percentage_change <= 10.0

                    if not match:
                        print()
                        print("We found a non-matching function")
                        print(function_ida['entry_point'] + "/" + function_revng['entry_point'])
                        print("Percentage discrepancy was: " + str(percentage_change))

                    results.append({
                        'function_name': function_revng['entry_point'],
                        'symbol_coverage': str(coverage_symbol),
                        'symbol_coverage_measure': coverage_symbol.measure(),
                        'revng_coverage': str(coverage_revng),
                        'revng_coverage_measure': coverage_revng.measure(),
                        'ida_coverage': str(coverage_ida),
                        'ida_coverage_measure': coverage_ida.measure(),
                        'match': str(match),
                        'revng-ida': strrange(coverage_revng - coverage_ida),
                        'ida-revng': strrange(coverage_ida - coverage_revng),
                    })

                    # Only matching functions go to the .matching file.
                    if match:
                        output_matching.write(function_revng['entry_point'] + '\n')
                        matching += 1

        json.dump(results, output_json, indent=2)

        if total == 0:
            print("No matching function")
        else:
            print(matching / total)

# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()