#!/usr/bin/env python3

import argparse
import os
import re
import sys
from collections import defaultdict

from elftools.elf.dynamic import DynamicSegment
from elftools.elf.elffile import ELFFile


def log(message):
    """Write *message*, terminated by a newline, to standard error."""
    line = message + "\n"
    sys.stderr.write(line)


def ignore_toolchains(component):
    """Component filter: accept every component whose name does not start
    with "toolchain".

    Returns False for toolchain components and True for everything else, so
    it can be passed as a ``component_filter`` predicate.
    """
    if component.startswith("toolchain"):
        return False
    return True


def read_file(path):
    """Return the lines of the text file at *path*, each stripped of leading
    and trailing whitespace (blank lines are kept as empty strings)."""
    stripped_lines = []
    with open(path, "r") as input_file:
        for raw_line in input_file:
            stripped_lines.append(raw_line.strip())
    return stripped_lines


def is_executable(path):
    """Return True if the current process has execute permission on *path*."""
    can_execute = os.access(path, os.X_OK)
    return can_execute


def is_elf(path):
    """Return True if the file at *path* begins with the 4-byte ELF magic."""
    with open(path, "rb") as input_file:
        magic = input_file.read(4)
    return magic == b"\x7FELF"


def unique_or_none(list):
    """Return the only element of *list*, or None unless it holds exactly one.

    Both the empty case and the "more than one element" case yield None.
    """
    # NOTE: the parameter name shadows the builtin; it is kept because it is
    # part of the function's signature.
    return list[0] if len(list) == 1 else None


def get_dynamic(elf):
    """Return the dynamic segment of *elf*, or None unless there is exactly one.

    Only segments whose type is exactly DynamicSegment are considered.
    """
    dynamic_segments = []
    for segment in elf.iter_segments():
        if type(segment) is DynamicSegment:
            dynamic_segments.append(segment)
    return unique_or_none(dynamic_segments)


class Root:
    """Consistency checker for an Orchestra root directory.

    The package index lives in ``<root>/share/orchestra``: every file in that
    tree whose name does not end in ``.json`` is read as a list of installed
    paths, one per line, relative to the root.  The path of the index file
    relative to ``share/orchestra`` is used as the name of the owning
    component in every report.

    Typical usage is ``load_package_files()`` and ``collect_installed_files()``
    followed by ``report_duplicates()``, ``check_installed_files()`` and
    ``verify_elfs()``; each of the latter three logs its findings and returns
    True when it found a problem.
    """

    def __init__(self, root_path):
        """Create an empty model for the root located at *root_path*."""
        # Index file (relative to package_files_path) -> lines read from it.
        self.file_map = dict()
        # Installed path -> list of index files that list it.
        self.reverse_file_map = defaultdict(list)
        self.root_path = root_path
        self.package_files_path = os.path.join(self.root_path, "share", "orchestra")
        # Every path listed by some index file.  "lib" is always considered
        # listed (presumably a symlink that no package owns -- TODO confirm).
        self.all_files = set(["lib"])
        # Paths found on disk; filled in by collect_installed_files().  It is
        # initialised here so that the check methods do not fail with
        # AttributeError when they are called before the collection step.
        self.installed_files = set()

    def load_file(self, path):
        """Load the index file at *path* and record the paths it lists.

        A leading "./" is removed from each listed path.  Blank entries are
        skipped: an empty path can never match an installed file and would
        otherwise always be reported as missing.
        """
        files = read_file(path)
        path = os.path.relpath(path, self.package_files_path)
        for entry in files:
            if entry.startswith("./"):
                entry = entry[2:]
            if not entry:
                continue
            self.reverse_file_map[entry].append(path)
            self.all_files.add(entry)
        self.file_map[path] = files

    def load_package_files(self):
        """Load every index file found under ``share/orchestra``."""
        # Recursively walk the index directory and load every file list
        for directory, subdirectories, files in os.walk(self.package_files_path):
            for file_name in files:
                # Skip metadata files
                if file_name.endswith(".json"):
                    continue
                self.load_file(os.path.join(directory, file_name))

    def report_duplicates(self):
        """Log every path listed more than once; return True if any exists."""
        header = False
        for file_path, packages in self.reverse_file_map.items():
            if len(packages) <= 1:
                continue
            if not header:
                header = True
                log("Files in multiple packages:")
            log("  {}:".format(file_path))
            for package in packages:
                log("    {}".format(package))

        return header

    def collect_installed_files(self):
        """Scan the root and rebuild ``installed_files``.

        Every file is recorded relative to the root.  Directories are not
        recorded, except for symbolic links to directories.
        """
        self.installed_files = set()
        for directory, subdirectories, files in os.walk(self.root_path):
            for subdirectory in subdirectories:
                path = os.path.join(directory, subdirectory)
                if os.path.islink(path):
                    self.installed_files.add(os.path.relpath(path, self.root_path))
            for file_name in files:
                path = os.path.join(directory, file_name)
                self.installed_files.add(os.path.relpath(path, self.root_path))

    def check_installed_files(self):
        """Compare listed and installed paths, logging both differences.

        Returns True if some listed path is not installed or some installed
        path is not listed by any index file.
        """
        missing_files = self.all_files - self.installed_files
        if missing_files:
            log("The following files are listed as installed but are not"
                " present in root:")
            for missing_file in sorted(missing_files):
                log("  {}".format(missing_file))

        extra_files = self.installed_files - self.all_files
        if extra_files:
            log("The following files are present in root but do not belong to"
                " any component:")
            for extra_file in sorted(extra_files):
                log("  {}".format(extra_file))

        return bool(missing_files or extra_files)

    def is_for_host(self, path, elf):
        """Return True if *elf* targets x86-64 (*path* is currently unused)."""
        return elf.header.e_machine == "EM_X86_64"

    def prepare_file_list(self, files, prefix="", component_filter=lambda component: True):
        """Format *files* grouped by the component(s) that list them.

        Components are emitted in sorted order, each as a ``<component>:``
        line followed by its files indented two further spaces; every line
        starts with *prefix*.  A file listed by several components appears
        under each of them, and a file listed by none appears under
        "(orphan)".  Components rejected by *component_filter* are omitted.
        Returns the empty string when nothing is left to print.
        """
        by_component = defaultdict(list)
        for file_path in files:
            if file_path in self.reverse_file_map:
                for component in self.reverse_file_map[file_path]:
                    by_component[component].append(file_path)
            else:
                by_component["(orphan)"].append(file_path)

        lines = []
        for component in sorted(by_component):
            if not component_filter(component):
                continue
            lines.append("{}{}:\n".format(prefix, component))
            for file_path in by_component[component]:
                lines.append("{}  {}\n".format(prefix, file_path))

        return "".join(lines)

    def print_file_list(self, files, prefix=""):
        """Log *files* grouped by component (see prepare_file_list)."""
        log(self.prepare_file_list(files, prefix))

    @staticmethod
    def _group_dynamic_tags(dynamic_segment):
        """Return the tags of *dynamic_segment* grouped by their ``d_tag``."""
        tags = defaultdict(list)
        for tag in dynamic_segment.iter_tags():
            tags[tag.entry.d_tag].append(tag)
        return tags

    @staticmethod
    def _read_dynamic_string_table(elf, elf_file, tags):
        """Read the dynamic string table of *elf* from *elf_file*.

        Returns None when DT_STRTAB or DT_STRSZ is not present exactly once,
        or when the table address does not map to exactly one file offset.
        """
        strtab_tag = unique_or_none(tags.get("DT_STRTAB", []))
        strsz_tag = unique_or_none(tags.get("DT_STRSZ", []))
        if strtab_tag is None or strsz_tag is None:
            return None

        # DT_STRTAB holds a virtual address; translate it to a file offset.
        offset = unique_or_none(list(elf.address_offsets(strtab_tag.entry.d_val)))
        if offset is None:
            return None

        elf_file.seek(offset)
        return elf_file.read(strsz_tag.entry.d_val)

    @staticmethod
    def _string_at(string_table, offset):
        """Decode the NUL-terminated ASCII string at *offset*."""
        return string_table[offset:].split(b"\x00")[0].decode("ascii")

    def verify_elfs(self):
        """Check the dynamically linked x86-64 ELF executables in the root.

        Only installed regular files that are executable, start with the ELF
        magic, target x86-64 and have exactly one dynamic segment are
        inspected.  Three things are examined:

        * every DT_RUNPATH entry must resolve to a directory (or symlink)
          inside the root;
        * DT_NEEDED libraries that are not found through the runpaths but
          whose name matches an inspected file in the root are reported
          (informational only);
        * the ``GLIBC_*`` version strings used by a file must also appear in
          some file whose path contains "link-only".

        Reports of the last two kinds omit users that belong to toolchain
        components.  A file whose dynamic string table cannot be located is
        logged and skipped.

        Returns True if an invalid runpath or a reportable unallowed glibc
        version was found.
        """
        missing_libraries = defaultdict(list)
        libraries_in_root = defaultdict(list)
        allowed_glibc_versions = set()
        used_glibc_versions = dict()
        invalid_runpaths = defaultdict(list)

        for installed_file in sorted(self.installed_files):
            path = os.path.join(self.root_path, installed_file)
            if not (os.path.isfile(path) and is_executable(path) and is_elf(path)):
                continue

            with open(path, "rb") as elf_file:
                elf = ELFFile(elf_file)
                dynamic_segment = get_dynamic(elf)
                if not (self.is_for_host(installed_file, elf) and dynamic_segment):
                    continue

                is_link_only = "link-only" in installed_file
                if not is_link_only:
                    libraries_in_root[os.path.basename(installed_file)].append(installed_file)

                # Scan the dynamic section once instead of once per tag kind
                tags = self._group_dynamic_tags(dynamic_segment)

                # Get the string table
                string_table = self._read_dynamic_string_table(elf, elf_file, tags)
                if string_table is None:
                    log("Cannot locate the dynamic string table of {},"
                        " skipping it".format(installed_file))
                    continue

                glibc_versions = set([version.strip(b"\x00").decode("ascii")
                                      for version in
                                      re.findall(b"GLIBC_[0-9.]*\x00", string_table)])
                if is_link_only:
                    allowed_glibc_versions |= glibc_versions
                else:
                    used_glibc_versions[installed_file] = glibc_versions

                runpaths = []
                runpath_tag = unique_or_none(tags.get("DT_RUNPATH", []))
                if runpath_tag is not None:
                    raw_runpath = self._string_at(string_table, runpath_tag.entry.d_val)
                    # The dynamic linker accepts both spellings of the token
                    origin = os.path.dirname(os.path.realpath(path))
                    raw_runpath = raw_runpath.replace("${ORIGIN}", origin)
                    raw_runpath = raw_runpath.replace("$ORIGIN", origin)
                    # Sorted so that lookups and reports are deterministic
                    runpaths = sorted({os.path.relpath(os.path.realpath(entry),
                                                       self.root_path)
                                       for entry in raw_runpath.split(":")})

                    for runpath in runpaths:
                        runpath_in_root = os.path.join(self.root_path, runpath)
                        if not (os.path.isdir(runpath_in_root)
                                or os.path.islink(runpath_in_root)):
                            invalid_runpaths[runpath].append(installed_file)

                # Collect DT_NEEDED
                for needed_tag in tags.get("DT_NEEDED", []):
                    lib_name = self._string_at(string_table, needed_tag.entry.d_val)
                    found = any(
                        os.path.relpath(os.path.join(self.root_path, runpath, lib_name),
                                        self.root_path) in self.all_files
                        for runpath in runpaths
                    )
                    if not found:
                        missing_libraries[lib_name].append(installed_file)

        if invalid_runpaths:
            log("The following runpaths are invalid:")
            for runpath, users in invalid_runpaths.items():
                log("  {}".format(runpath))
                self.print_file_list(users, "    ")

        system_libraries = []
        for missing_library, users in missing_libraries.items():
            if missing_library in libraries_in_root:
                file_list = self.prepare_file_list(users, "    ", ignore_toolchains)
                if file_list:
                    log("{} is available in root".format(missing_library))
                    log("  These are the instances:")
                    self.print_file_list(libraries_in_root[missing_library], "    ")
                    log("  These are the users:")
                    log(file_list)
            else:
                system_libraries.append((missing_library, users))

        # NOTE(review): this report is switched off by the trailing
        # `and False`; presumably on purpose -- confirm before enabling.
        if system_libraries and False:
            log("The following libraries are not provided in root:")
            for system_library, users in system_libraries:
                log("  {}:".format(system_library))
                self.print_file_list(users, "    ")

        by_version = defaultdict(list)
        for installed_file, versions in used_glibc_versions.items():
            for unallowed_version in versions - allowed_glibc_versions:
                by_version[unallowed_version].append(installed_file)

        to_print = list()
        for version, users in sorted(by_version.items()):
            file_list = self.prepare_file_list(users, "    ", ignore_toolchains)
            if file_list:
                to_print.append((version, users, file_list))

        if to_print:
            log("The following unallowed glibc versions are being used:")
            for version, users, file_list in to_print:
                log("  {}".format(version))
                log(file_list)

        return bool(invalid_runpaths) or bool(to_print)


def main():
    """Run every consistency check on the root given on the command line.

    The root path is an optional positional argument defaulting to the
    current directory.  All checks are always run, even if an earlier one
    already failed, so that a single invocation reports everything.

    Returns the process exit status: 1 if any check found a problem, 0
    otherwise.
    """
    parser = argparse.ArgumentParser(description="Verify integrity of an orchestra root.")
    # nargs="?" is what makes the positional optional: without it argparse
    # still requires the argument and `default` is never used.
    parser.add_argument("root_path",
                        metavar="ROOT_PATH",
                        nargs="?",
                        default=".",
                        help="Path to Orchestra root (default: %(default)s)")
    args = parser.parse_args()

    root = Root(args.root_path)

    root.load_package_files()
    duplicates_found = root.report_duplicates()
    root.collect_installed_files()
    orphans_found = root.check_installed_files()
    errors_in_elfs = root.verify_elfs()

    if duplicates_found or orphans_found or errors_in_elfs:
        log("[!] Inconsistencies found in the root directory!")
        return 1

    log("Root directory consistency checks passed!")
    return 0


# Script entry point: run the checks and use main()'s result as exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)