#!/usr/bin/env python
# SPDX-License-Identifier: GPL-2.0-or-later

import argparse
import os
import os.path
import subprocess
import sys
import typing

from collections import defaultdict
from pathlib import Path

from elftools.common.exceptions import ELFError
from elftools.elf.elffile import ELFFile

from portage import create_trees, _encodings
from portage.versions import cpv_getkey


def get_ldpaths(prefix: Path) -> typing.Set[Path]:
    """Grab library paths from ldconfig

    Runs ``ldconfig -r <prefix> -p`` and collects the parent directory of
    every library path it prints, keeping only directories whose final
    component is ``lib``.

    The paths are returned as printed by ldconfig; ``prefix`` is not
    prepended to them here.

    Raises subprocess.CalledProcessError if ldconfig exits unsuccessfully.
    """

    ldconf = subprocess.run(["ldconfig", "-r", str(prefix), "-p"],
                            check=True,
                            capture_output=True)

    out = set()
    for line in ldconf.stdout.splitlines():
        # only tab-indented lines are processed; everything else is skipped
        if not line.startswith(b"\t"):
            continue

        # tolerate an entry that lacks the " => <path>" part instead of
        # crashing on the unpack
        _, sep, path = line.partition(b" => ")
        if not sep:
            continue

        # fsdecode round-trips arbitrary filename bytes, unlike a strict
        # UTF-8 decode
        p_path = Path(os.fsdecode(path)).parent
        if p_path.name == "lib":
            out.add(p_path)

    return out


def samefile_if_exists(p1: Path, p2: typing.Union[Path, str]) -> bool:
    """Tell whether p1 and p2 refer to the same file

    Returns True only if both paths exist and point to the same file.
    A path that does not exist yields False rather than an exception; this
    includes a path that cannot exist because one of its parent components
    is not a directory.  Other OS errors are propagated.
    """
    try:
        return p1.samefile(p2)
    except (FileNotFoundError, NotADirectoryError):
        return False


def _symlink_retarget(path: Path,
                      prefixed_libdirs: typing.Tuple[Path, ...]
                      ) -> typing.Optional[typing.Tuple[Path, Path]]:
    """Return (current, new) targets if the symlink at path needs rewriting"""
    target = path.readlink()
    # a bare filename target has no directory part to rewrite
    if str(target.parent) == ".":
        return None
    # NB: joining with an absolute target discards path.parent
    abs_target_parent = path.parent / target.parent
    if not any(samefile_if_exists(abs_target_parent, libdir)
               for libdir in prefixed_libdirs):
        return None
    return target, target.parent.parent / "libt32" / target.name


def _elf_runpath_state(path: Path
                       ) -> typing.Tuple[bool, typing.Optional[str]]:
    """Return (needs RUNPATH, current RPATH/RUNPATH or None) for path

    Missing files and files that are not valid ELF yield (False, None).
    """
    try:
        with open(path, "rb") as f:
            elf_file = ELFFile(f)
            # process only 32-bit ELF files
            if elf_file.elfclass != 32:
                return False, None

            # skip files without .dynamic section
            dyn = elf_file.get_section_by_name(".dynamic")
            if dyn is None:
                return False, None

            tags = list(dyn.iter_tags())
    except (FileNotFoundError, ELFError):
        return False, None

    # check for existing RPATH/RUNPATH
    runpaths = [e for e in tags
                if e.entry["d_tag"] in ("DT_RUNPATH", "DT_RPATH")]
    # sanity check: the code below handles a single entry only
    assert len(runpaths) < 2
    runpath = None
    if runpaths:
        if runpaths[0].entry["d_tag"] == "DT_RPATH":
            runpath = runpaths[0].rpath
        else:
            runpath = runpaths[0].runpath

    # set RUNPATH only if there is at least one DT_NEEDED entry
    # or if one is set already
    # (this should reduce the risk of breaking programs)
    has_needed = any(e.entry["d_tag"] == "DT_NEEDED" for e in tags)

    return runpath is not None or has_needed, runpath


def _new_runpath(path: Path,
                 current_runpath: typing.Optional[str],
                 prefix: Path,
                 old_new_prefixed_libdirs: typing.Tuple[Path, ...],
                 new_libdirs: typing.List[str]) -> str:
    """Compute the RUNPATH to set on path: kept entries plus new libdirs"""
    runpath_dirs = []
    if current_runpath is not None:
        # remove time32 libdir from current_runpath
        # NB: this causes $ORIGIN to be always expanded; this is
        # expected since the file in question may be moved afterwards
        origin = str(path.root / path.parent.relative_to(prefix))
        resolved_runpath = (current_runpath
                            .replace("${ORIGIN}", "$ORIGIN")
                            .replace("$ORIGIN", origin)
                            ).split(":")
        runpath_dirs = [x for x in resolved_runpath
                        if not any(samefile_if_exists(libdir,
                                                      prefix /
                                                      x.lstrip(os.path.sep))
                                   for libdir in old_new_prefixed_libdirs)]
    runpath_dirs += new_libdirs
    return ":".join(runpath_dirs)


def _patch_runpath(path: Path, new_runpath: str) -> bool:
    """Set new_runpath on path via a patched copy; return True on success"""
    path_s = str(path)
    tmp_s = path_s + ".time32.tmp"
    # work on a copy and rename it over the original at the end
    commands = (
        ["cp", "-a", path_s, tmp_s],
        ["patchelf", "--set-rpath", new_runpath, tmp_s],
        ["mv", tmp_s, path_s],
    )
    try:
        for command in commands:
            if subprocess.run(command).returncode != 0:
                return False
        return True
    finally:
        # on success mv has already consumed the copy; otherwise do not
        # leave a stale temporary file next to the original (best effort)
        try:
            os.unlink(tmp_s)
        except OSError:
            pass


def migrate(prefix: Path, update: bool, mv_shared_libs: bool) -> int:
    """Perform or pretend the migration in prefix

    Walks the contents of all installed packages (except sys-libs/glibc)
    and:

    1. points RUNPATH of 32-bit dynamic ELF files at the "libt32" sibling
       of every libdir reported by get_ldpaths(),
    2. copies or moves "lib*" files found directly in those libdirs into
       "../libt32",
    3. rewrites symlinks whose target names a libdir explicitly so that
       they point into "libt32" instead.

    prefix -- root directory of the system to migrate
    update -- if False, only print what would be done
    mv_shared_libs -- move files with ".so" in their name instead of
                      copying them (other files are always moved)

    The work is done by running cp, patchelf, mv and ln as subprocesses.
    Returns 0 on success and 1 if any of them failed.
    """

    print(f"working in prefix: {prefix!s}")

    trees = create_trees(config_root=prefix, target_root=prefix)
    vartree = trees[max(trees)]["vartree"]
    vardb = vartree.dbapi

    libdirs = get_ldpaths(prefix)
    prefixed_libdirs = tuple(prefix / libdir.relative_to(libdir.root)
                             for libdir in libdirs)
    to_set_runpath = []
    to_move = defaultdict(list)
    # symlink path -> (current target, new target)
    to_rewrite_symlinks = {}

    print(f"libdirs determined to be: {' '.join(str(x) for x in libdirs)}")

    new_libdirs = [str(libdir.parent / "libt32") for libdir in libdirs]
    prefixed_new_libdirs = tuple(prefix / libdir.lstrip(os.path.sep)
                                 for libdir in new_libdirs)
    old_new_prefixed_libdirs = prefixed_libdirs + prefixed_new_libdirs

    # iterate over all installed files
    for cpv in vardb.cpv_all():
        # skip glibc, it's ABI-agnostic
        if cpv_getkey(cpv) == "sys-libs/glibc":
            continue
        for path_s, details in vardb._dblink(cpv).getcontents().items():
            # require a file or a symlink
            if details[0] not in ("obj", "sym"):
                continue

            path = Path(path_s)
            # skip directory entries
            if path.is_dir():
                continue

            # move only files directly in libdir, whose names start with "lib"
            if path.parent in prefixed_libdirs and path.name.startswith("lib"):
                if path.is_symlink():
                    retarget = _symlink_retarget(path, prefixed_libdirs)
                    if retarget is not None:
                        to_rewrite_symlinks[path] = retarget
                to_move[path.parent].append(path.name)

            # set RUNPATH only on regular files
            if details[0] != "obj":
                continue
            wanted, runpath = _elf_runpath_state(path)
            if wanted:
                to_set_runpath.append((path, runpath))

    if not update:
        print("RUNPATH would be set on the following files:")
    else:
        print("setting RUNPATHs:")

    failed = False
    for path, current_runpath in to_set_runpath:
        new_runpath = _new_runpath(path, current_runpath, prefix,
                                   old_new_prefixed_libdirs, new_libdirs)
        if current_runpath == new_runpath:
            print(f"skipping {path!s} (already {new_runpath})")
            continue
        print(f"{path!s} ({current_runpath} -> {new_runpath})")
        if update and not _patch_runpath(path, new_runpath):
            failed = True

    # do not start moving libraries if any RUNPATH could not be set
    if failed:
        print("at least one patchelf call failed, aborting")
        return 1

    print()
    if not update:
        print("the following files would be copied or moved:")
    else:
        print("copying or moving files:")

    for libdir, files in to_move.items():
        copy = []
        move = []
        for filename in files:
            path = libdir / filename
            if mv_shared_libs or ".so" not in filename:
                move.append(filename)
                print(f"[mv] {path} -> ../libt32/")
            else:
                copy.append(filename)
                print(f"[cp] {path} -> ../libt32/")

        if update:
            (libdir.parent / "libt32").mkdir(exist_ok=True)
            if copy:
                ret = subprocess.run(["cp", "-a"] + copy + ["../libt32/"],
                                     cwd=libdir)
                if ret.returncode != 0:
                    failed = True
            if move:
                ret = subprocess.run(["mv"] + move + ["../libt32/"],
                                     cwd=libdir)
                if ret.returncode != 0:
                    failed = True

        for filename in files:
            retarget = to_rewrite_symlinks.get(libdir / filename)
            if retarget is None:
                continue
            # the current target was recorded while scanning, so there is
            # no need to read the link again after it has been relocated
            old_target, new_target = retarget
            path = libdir / "../libt32" / filename
            if update and not path.is_symlink():
                # the copy/move above did not deliver this link
                print(f"cannot rewrite symlink: {path} (not found)")
                failed = True
                continue
            print(f"rewrite symlink: {path} "
                  f"({old_target} -> {new_target})")
            if update:
                ret = subprocess.run(["ln", "-s", "-f",
                                      str(new_target), str(path)])
                if ret.returncode != 0:
                    failed = True

    if failed:
        print("at least one operation failed")
        return 1

    if update:
        print("success! now you can start rebuilding. when done, remember to:")
        print(f"  rm -rf {' '.join(str(x.parent / 'libt32') for x in to_move)}")

    return 0


def main() -> int:
    """Parse the command line and run the migration

    Returns the status reported by migrate(), which the caller passes on
    as the process exit code.
    """
    # clear the group/other write bits on everything created from here on
    os.umask(0o22)

    argp = argparse.ArgumentParser()
    argp.add_argument("-m", "--mv-shared-libs",
                      action="store_true",
                      help="Move rather than copy shared libraries")
    argp.add_argument("-u", "--update",
                      action="store_true",
                      help="Actually perform the migration")
    argp.add_argument("-P", "--prefix",
                      type=Path,
                      default=Path("/"),
                      help="Root directory of the system to migrate")
    args = argp.parse_args()

    return migrate(args.prefix, args.update, args.mv_shared_libs)


# when run as a script, exit with the status returned by main()
if __name__ == "__main__":
    raise SystemExit(main())
