#!/usr/bin/python3
# amqp-consume(1) replacement with gpuenv-utils support
#
# Author: Christian Kastner <ckk@kvr.at>
# License: MIT


"""amqp-consume(1) replacement with gpuenv-utils support.

This is not a 1:1 replacement of the original utility. It only implements the
functionality used by debci-worker, as a drop-in.

Furthermore, this utility accesses the debci configuration, via
$debci_config_dir, in order to determine which GPUs this particular worker
uses.
"""


import argparse
import os
import signal
import socket
import ssl
import subprocess
import sys
import logging
import time
import types
from urllib.parse import urlparse

import pika


def debci_config_get_string(option: str) -> str:
    """Return the value of a debci config option as a single string.

    Runs ``debci config -v <option>`` and keeps only the first line of its
    output, without the line terminator.
    """
    output = subprocess.check_output(["debci", "config", "-v", option], text=True)
    first_line, _, _ = output.partition("\n")
    return first_line


def debci_config_get_list(option: str) -> list[str]:
    """Return the value of a debci config option, split on whitespace."""
    value = debci_config_get_string(option)
    return value.split()


class CustomAdapter(logging.LoggerAdapter):
    """Logger adapter that prepends ``[<name>]`` to every message.

    The name is read from the ``"name"`` key of the ``extra`` mapping passed
    at construction time.
    """

    def process(self, msg, kwargs):
        prefix = self.extra["name"]
        prefixed = "[{}] {}".format(prefix, msg)
        return prefixed, kwargs


class AMQPConsumer:
    """Consume messages from an AMQP queue when resources become ready.

    Connects to a debci queue. When a message is received, communicates with a
    gpuenv-server to check if the GPUs are ready for use, and executes the job.
    If they are not ready, the message is returned, and the GPUs are waited
    for.

    Parameters
    ----------
    name : str
        The name of the worker, to use as a prefix during logging
    args : list[str]
        The command argument list to execute when receiving a message. The
        AMQP message will be supplied as the command's stdin.
    params : pika.ConnectionParameters
        Connection parameter object for connecting to the server.
    queue : str
        Name of the queue to which to connect.
    socket_path : str
        Path to the socket on which the gpuenv-server is listening.
    slot_ids : list[str]
        PCI slot IDs of the GPUs needed to process commands. If empty, the
        ``--gpu`` options of the debci backend arguments are used; if there
        are none either, the gpuenv-server is queried and all GPUs it lists
        are used.
    """

    def __init__(
        self,
        name: str,
        args: list[str],
        params: pika.ConnectionParameters,
        queue: str,
        socket_path: str = "/run/gpuenv.socket",
        slot_ids: list[str] | None = None,
    ):
        self._name = name
        self._logger = CustomAdapter(logging.getLogger(__name__), {"name": name})
        self._socket_path = socket_path
        self._args = args
        self._params = params
        self._connection = None
        self._channel = None
        self._queue = queue
        self._devices_locked = False
        # _find_gpus() may talk to the gpuenv-server, so the logger and the
        # socket path must already be set at this point.
        if slot_ids:
            self._slot_ids = slot_ids
        else:
            self._slot_ids = self._find_gpus()
        self._logger.info("Using the following devices: %s", ",".join(self._slot_ids))

    def _find_gpus(self) -> list[str]:
        """Return a list of all GPUs listed in a debci config.

        GPUs are listed as --gpu in the podman- or qemu-based backend args. If
        the config has no --gpu options, the gpuenv-server is asked for all of
        its devices instead. If no GPUs are present, an empty list is returned.

        Raises
        ------
        NotImplementedError
            If the configured backend is neither qemu- nor podman-based.
        """
        backend = debci_config_get_string("backend")
        if backend.startswith("qemu"):
            list_command = "list_all --vfio"
        elif backend.startswith("podman"):
            list_command = "list_all --no-vfio"
        else:
            raise NotImplementedError(f"Unsupported backend: {backend}")

        # Backend names may contain a +, the arg variable names may not
        backend_args = f"autopkgtest_args_{backend.replace('+', '')}"
        autopkgtest_args = debci_config_get_list(backend_args)

        # add_help=False: a stray -h/--help among the backend args must not
        # make argparse print usage and exit this process.
        parser = argparse.ArgumentParser(add_help=False)
        parser.add_argument("--gpu", action="append")
        known, _ = parser.parse_known_args(autopkgtest_args)
        if known.gpu:
            return known.gpu

        # config did not have --gpu options, so get them all
        text = self._gpuenv_send(list_command)
        # "".split(",") would yield [""], so drop empty entries to really
        # return an empty list when the server reports no devices.
        return [slot_id for slot_id in text.split(",") if slot_id]

    def _gpuenv_send(self, message: str) -> str:
        """Trivial message exchange with a gpuenv-server.

        Sends ``message`` over a fresh connection to the server socket and
        reads a single reply of the form ``<3-digit code> <text>``.

        Returns
        -------
        str
            The text part of the reply.

        Raises
        ------
        RuntimeError
            If the reply is malformed or its code is not 200.
        SystemExit
            If the socket times out.
        """
        # The context manager closes the socket on every exit path.
        with socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) as sock:
            # Set the timeout before connecting so a stuck server cannot
            # block us indefinitely in connect() either.
            sock.settimeout(3.0)
            try:
                sock.connect(self._socket_path)
                sock.sendall(message.encode())
                reply = sock.recv(1024).decode()
            except socket.timeout:
                self._logger.critical("Socket timed out. Is gpuenv-server running?")
                sys.exit(1)

        try:
            code = int(reply[0:3])
        except ValueError:
            # E.g. the server closed the connection without replying.
            raise RuntimeError(f"Malformed reply: {reply!r}") from None
        text = reply[4:]
        if code != 200:
            raise RuntimeError(f"Unexpected reply: {code} {text}")
        return text

    def _lock_devices(self) -> bool:
        """Try to get a lock on all necessary GPUs.

        Returns True if the lock is (or already was) held, False otherwise.
        """
        if self._devices_locked:
            return True
        reply = self._gpuenv_send(f"lock_multi {','.join(self._slot_ids)}")
        self._devices_locked = reply == "True"
        if self._devices_locked:
            self._logger.debug("Acquired lock on all needed devices.")
        else:
            self._logger.debug("Could not acquire lock on all needed devices.")
        return self._devices_locked

    def _unlock_devices(self) -> None:
        """Release locked GPUs."""
        self._gpuenv_send(f"release_multi {','.join(self._slot_ids)}")
        self._devices_locked = False
        self._logger.debug("Released lock on devices.")

    def _devices_are_free(self) -> bool:
        """Check if all necessary devices can be locked."""
        reply = self._gpuenv_send(f"is_free_multi {','.join(self._slot_ids)}")
        return reply == "True"

    @staticmethod
    def _feed_stdin(proc: subprocess.Popen, body: bytes) -> None:
        """Write body to the job's stdin and send EOF, without waiting.

        A job that exits or closes its stdin early is tolerated, the same way
        Popen.communicate() tolerates it.
        """
        try:
            proc.stdin.write(body)
        except BrokenPipeError:
            pass
        try:
            proc.stdin.close()  # send EOF
        except BrokenPipeError:
            pass

    def consume(self, count: int | None = None) -> None:
        """Emulate amqp-consume(1).

        For every message received, locks the needed devices, runs the
        configured command with the message body on its stdin, acknowledges
        the message once the command has finished, and releases the devices.

        Parameters
        ----------
        count : int | None
            Return after this many messages have been processed. If None,
            keep consuming.

        Raises
        ------
        OSError
            If the command could not be started.
        """
        self._connection = pika.BlockingConnection(self._params)
        self._channel = self._connection.channel()
        self._channel.basic_qos(prefetch_count=1)

        processed = 0
        for method_frame, _properties, body in self._channel.consume(self._queue):
            if not self._lock_devices():
                self._channel.basic_nack(method_frame.delivery_tag)
                self._logger.debug("Waiting for all needed devices to become free...")
                while not self._devices_are_free():
                    time.sleep(60)
                    # Sends heartbeat
                    self._connection.process_data_events()
                continue

            try:
                # pylint: disable-next=consider-using-with
                proc = subprocess.Popen(self._args, stdin=subprocess.PIPE)
            except OSError as err:
                # Popen reports a command that cannot be started as OSError
                # (it never raises CalledProcessError). Hand the message back
                # and let the error propagate: retrying immediately would
                # only fail the same way again.
                self._logger.error("Could not start job: %s", err)
                self._channel.basic_nack(method_frame.delivery_tag)
                raise

            # We don't want to use Popen.communicate() because it wait()s,
            # which we want to do ourselves
            self._feed_stdin(proc, body)

            while True:
                try:
                    returncode = proc.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    # Sends heartbeat
                    self._connection.process_data_events()
                    continue
                break
            if returncode != 0:
                self._logger.error("Job failed with exit status %s", returncode)
            # The message is acknowledged regardless of the job's exit status.
            self._channel.basic_ack(method_frame.delivery_tag)

            self._unlock_devices()
            processed += 1
            if count is not None and processed >= count:
                break
        self._channel.cancel()

    def stop(self) -> None:
        """Release devices. Useful hook for cleanup phase."""
        if self._devices_locked:
            self._unlock_devices()
        # Note that channel/connection can be closed server-side
        if self._channel:
            if not self._channel.is_closed:
                self._channel.close()
            self._channel = None
        if self._connection:
            if not self._connection.is_closed:
                self._connection.close()
            self._connection = None

    def stop_by_signal(self, signum: int, frame: types.FrameType | None = None) -> None:
        """Callback for signal.

        Cleans up via stop() and exits the process with status 0. The
        parameters are those passed to every signal handler; they are unused.
        """
        self._logger.info("Received signal to stop")
        self.stop()
        sys.exit(0)


def main():
    """Parse the amqp-consume(1)-style command line and run the consumer.

    Setting the environment variable GPUENV_AMQP_DEBUG to a non-empty value
    enables debug logging. The worker name used as logging prefix is the
    last path component of debci's config_dir.

    Without --count, messages are consumed until a signal is received, and
    the connection is re-established whenever it is lost. With --count, the
    program ends once that many messages have been processed.
    """
    # Debug output only when explicitly requested.
    if os.getenv("GPUENV_AMQP_DEBUG"):
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level, format="%(message)s")
    # pika produces a ton of debug messages
    logging.getLogger("pika").propagate = False

    parser = argparse.ArgumentParser("amqp-consume-enhanced")
    parser.add_argument("-u", "--url")
    parser.add_argument("-q", "--queue")
    # type=int: the value is compared against an integer message counter
    parser.add_argument("-c", "--count", type=int, default=None)
    parser.add_argument("--ssl", action="store_true")
    parser.add_argument("--cacert")
    parser.add_argument("--cert")
    parser.add_argument("--key")
    parser.add_argument("--gpuenv-socket", default="/run/gpuenv.socket")
    parser.add_argument("command", nargs=argparse.REMAINDER)
    # We are passed this option so we need to be able to handle it, but its
    # value is hard-coded to 1, so we don't use it
    parser.add_argument("-p", "--prefetch-count", default=None)
    args = parser.parse_args()

    if not args.url:
        parser.error("the -u/--url option is required")

    # A subtle artefact of how we use argparse.REMAINDER
    if args.command and args.command[0] == "--":
        args.command = args.command[1:]
    if not args.command:
        parser.error("no command given")

    # Build up ConnectionParameters
    urlp = urlparse(args.url)
    params_kwargs = {
        "host": urlp.hostname,
        # An explicit port in the URL wins over the scheme's registered port
        "port": urlp.port or socket.getservbyname(urlp.scheme),
    }
    if urlp.username:
        params_kwargs["credentials"] = pika.PlainCredentials(
            urlp.username, urlp.password
        )
    if args.ssl:
        context = ssl.create_default_context(cafile=args.cacert)
        # A client certificate is only loaded when one was given
        if args.cert:
            context.load_cert_chain(args.cert, args.key)
        params_kwargs["ssl_options"] = pika.SSLOptions(context, urlp.hostname)
    params = pika.ConnectionParameters(**params_kwargs)

    config_dir = debci_config_get_string("config_dir")
    # Strip trailing slashes first, basename() of "/a/b/" would be ""
    config_name = os.path.basename(config_dir.rstrip("/"))

    consumer = AMQPConsumer(
        name=config_name,
        args=args.command,
        params=params,
        queue=args.queue,
        socket_path=args.gpuenv_socket,
    )

    signal.signal(signal.SIGINT, consumer.stop_by_signal)
    signal.signal(signal.SIGTERM, consumer.stop_by_signal)

    while True:
        try:
            consumer.consume(args.count)
        except KeyboardInterrupt:
            consumer.stop()
            sys.exit(0)
        except (
            pika.exceptions.AMQPConnectionError,
            pika.exceptions.ConnectionClosedByBroker,
        ):
            # When the server closes on us, we simply pause and retry
            time.sleep(60)
            continue

        # consume() returned normally: close the channel and connection it
        # opened, so that a following consume() does not leak them.
        consumer.stop()
        if args.count is not None:
            # The requested number of messages has been processed
            return

# Run main() only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
