/**
 * Copyright (C) 2023-2024 Atmark Techno, Inc. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
#include <fcntl.h>
#include <sys/param.h>
#include <sys/syscall.h>
#include <unistd.h>

#include <cerrno>

#include "execute_command.h"
#include "agent_time.h"
#include "agent_log.h"

// Raw pidfd_open(2) system call with no flags.
// Returns a file descriptor referring to process `pid`, or -1 with errno set
// (e.g. invalid pid, or a kernel that does not implement the syscall).
static int pidfd_open(pid_t pid)
{
    long fd = syscall(SYS_pidfd_open, pid, 0);
    return static_cast<int>(fd);
}

// Switch `fd` to non-blocking mode while keeping its other status flags.
// Returns the new flag word on success; on failure logs a warning with errno
// and returns -1.
static int set_nonblock(int fd) {
    const int current = fcntl(fd, F_GETFL);
    if (current < 0)
    {
        AGENT_LOG_WARN("Could not get fd flags: %d", errno);
        return -1;
    }

    const int wanted = current | O_NONBLOCK;
    if (fcntl(fd, F_SETFL, wanted) >= 0)
        return wanted;

    AGENT_LOG_WARN("Could not set O_NONBLOCK: %d", errno);
    return -1;
}

// Drain the non-blocking descriptor `fd`, appending every byte read to `str`.
// Returns 0 when the descriptor has no more data for now (EAGAIN) or has
// reached end-of-file. Returns -1, after logging errno, on any other read
// error. Data read before an error is kept in `str`.
static int read_into(String &str, int fd)
{
    char buf[4096];
    while (true) {
        ssize_t n = read(fd, buf, sizeof(buf));
        if (n > 0)
        {
            // append() mutates str in place; re-assigning its result is redundant
            str.append(buf, static_cast<size_t>(n));
            continue;
        }
        if (n == 0)
            return 0; // end of file: every write end has been closed
        if (errno == EINTR)
            continue; // interrupted by a signal handler: retry
        if (errno == EAGAIN)
            return 0; // nothing more to read right now
        AGENT_LOG_WARN("read failed: %d", errno);
        return -1;
    }
}

static bool createPidFile(int pid, String filePath)
{
    std::ofstream pidFile(filePath.c_str());
    if (!pidFile.is_open())
    {
        AGENT_LOG_WARN("Error: Unable to create pid file: %s", filePath.c_str());
        return false;
    }

    pidFile << pid;
    pidFile.close();
    return true;
}

// vfork() and exec "/bin/sh -c command" in the child. The child's stdout and
// stderr are redirected to the write ends fdout[1] / fderr[1], and `envp` is
// its environment. Returns the child's pid in the parent, or -1 if vfork fails.
//
// Keep vfork in its own function to avoid clobbering local variables.
//
// Moving this code directly in executeCommand yield the following warnings with -O >=1:
// error: variable 'start_time' might be clobbered by 'longjmp' or 'vfork' [-Werror=clobbered]
// error: argument 'timeout_ms' might be clobbered by 'longjmp' or 'vfork' [-Werror=clobbered]
// error: variable 'code' might be clobbered by 'longjmp' or 'vfork' [-Werror=clobbered]
//
// This separates the stack for this subfunction and thus prevents the problem.
// For the same reason the environment vector lives in spawn_child, not here.
static pid_t vfork_exec(const char *command, int *fdout, int *fderr, char **envp)
{
    pid_t pid = vfork();

    if (pid == 0)
    {
        // Child: until execle()/_exit() it runs on the parent's memory, so only
        // plain system calls are made here -- no heap allocation and no stdio.
        close(fdout[0]);
        close(fderr[0]);
        dup2(fdout[1], STDOUT_FILENO);
        dup2(fderr[1], STDERR_FILENO);
        close(fdout[1]);
        close(fderr[1]);

        execle("/bin/sh", "sh", "-c", command, nullptr, envp);
        _exit(1);
    }
    return pid;
}

// Start `command` through /bin/sh with stdout/stderr redirected into the
// given pipes, and with OPENSSL_CONF removed from the inherited environment.
// Returns the child's pid, or -1 if the fork failed.
static pid_t spawn_child(const char *command, int *fdout, int *fderr)
{
    // Build the filtered environment in the parent, *before* vfork().
    // A vfork child shares the parent's address space: allocating there is
    // undefined per POSIX. Other parent threads keep running, so this can
    // corrupt or deadlock malloc, and it leaks the allocation into the
    // parent's heap. The parent is suspended until the child has exec'd or
    // exited, so exec_env stays valid for as long as the child uses it.
    extern char **environ;
    Vector<char*> exec_env;

    for (char **env = environ; *env != nullptr; ++env)
    {
        if (strncmp(*env, "OPENSSL_CONF=", 13 /* length without \0 */) != 0)
            exec_env.push_back(*env);
    }
    exec_env.push_back(nullptr);

    return vfork_exec(command, fdout, fderr, exec_env.data());
}

// Error-path helper: close the parent's pipe read ends and the pidfd (when one
// was opened, i.e. pidfd >= 0), then SIGKILL and reap the child.
// Only call this while the child has not been reaped yet; until then its pid
// cannot have been recycled, so kill()/waitpid() target the right process.
static void abort_child(pid_t pid, int outfd, int errfd, int pidfd)
{
    close(outfd);
    close(errfd);
    if (pidfd >= 0)
        close(pidfd);
    kill(pid, SIGKILL);
    waitpid(pid, NULL, 0);
}

// run command with timeout
// - standard output/errors are appended to arguments
// - sends command SIGTERM after timeout and kill 5 seconds later
// (timeout_ms <= 0 disables the timeout). Only the direct child (the sh
// process) is signalled, not its process group.
// - *is_timeout (when non-null) is set to true if the timeout fired
// - when jobId is not empty, the child's pid is written to PID_FILE_PATH(jobId)
//   while the command runs, and that file is removed afterwards
// - returns exit code, or signal + 128 like shell; -1 on internal error
//   (in which case the child is killed and reaped before returning)
int executeCommand(
    String command,
    String &stdOutput,
    String &stdError,
    int timeout_ms,
    bool *is_timeout,
    String jobId)
{
    int fdout[2], fderr[2];
    int pid, pidfd;
    int status = 0;
    int code = -1;
    bool reaped = false; // true once waitpid() has collected the child
    auto start_time = Monotonic::now();
    int terminate_sig = SIGTERM;
    const bool use_pid_file = jobId != "";

    if (is_timeout != nullptr)
        *is_timeout = false;
    String pidFilePath = PID_FILE_PATH(jobId);

    if (pipe(fdout) < 0)
        return -1;
    if (pipe(fderr) < 0)
    {
        close(fdout[0]);
        close(fdout[1]);
        return -1;
    }

    pid = spawn_child(command.c_str(), fdout, fderr);
    if (pid < 0) // error
    {
        close(fdout[0]);
        close(fdout[1]);
        close(fderr[0]);
        close(fderr[1]);
        return -1;
    }

    // parent: only the read ends are used from here on
    close(fdout[1]);
    close(fderr[1]);
    pidfd = pidfd_open(pid);
    if (pidfd < 0)
    {
        AGENT_LOG_WARN("pidfd open failed: %d", errno);
        abort_child(pid, fdout[0], fderr[0], -1);
        return -1;
    }

    int fdmax = MAX(pidfd, MAX(fdout[0], fderr[0]));
    if (set_nonblock(fdout[0]) < 0 || set_nonblock(fderr[0]) < 0) {
        abort_child(pid, fdout[0], fderr[0], pidfd);
        return -1;
    }

    if (use_pid_file && !createPidFile(pid, pidFilePath))
    {
        abort_child(pid, fdout[0], fderr[0], pidfd);
        return -1;
    }

    while (true)
    {
        fd_set readfds;
        FD_ZERO(&readfds);
        FD_SET(fdout[0], &readfds);
        FD_SET(fderr[0], &readfds);
        FD_SET(pidfd, &readfds);

        struct timeval tv;
        struct timeval *timeout = time_left(start_time, timeout_ms, &tv);
        int ret = select(fdmax + 1, &readfds, NULL, NULL, timeout);
        // select() reports an interrupting signal as EINTR: just retry
        if (ret < 0 && (errno == EINTR || errno == EAGAIN))
            continue;
        if (ret < 0)
        {
            AGENT_LOG_WARN("select failed: %d", errno);
            goto loop_err;
        }
        // Linux select() updates *timeout with the time left; zero means the
        // deadline has passed (even if some fd also became readable).
        if (timeout && timeout->tv_sec == 0 && timeout->tv_usec == 0) {
            // timeout -- send sigterm and wait 5 more seconds before sigkill
            kill(pid, terminate_sig);
            terminate_sig = SIGKILL;
            timeout_ms = 5000;
            if (is_timeout != nullptr)
                *is_timeout = true;
            start_time = Monotonic::now();
        }
        if (FD_ISSET(fdout[0], &readfds))
        {
            if (read_into(stdOutput, fdout[0]) < 0)
                goto loop_err;
        }
        if (FD_ISSET(fderr[0], &readfds))
        {
            if (read_into(stdError, fderr[0]) < 0)
                goto loop_err;
        }
        if (FD_ISSET(pidfd, &readfds))
        {
            // check child really terminated
            ret = waitpid(pid, &status, WNOHANG);
            if (ret < 0) {
                AGENT_LOG_WARN("wait failed: %d", errno);
                goto loop_err;
            }
            if (ret > 0)
            {
                reaped = true;
                break;
            }
        }
    }

    // check for any trailing output
    if (read_into(stdOutput, fdout[0]) < 0)
        goto loop_err;
    if (read_into(stdError, fderr[0]) < 0)
        goto loop_err;

    close(fdout[0]);
    close(fderr[0]);
    close(pidfd);

    if (WIFEXITED(status))
    {
        code = WEXITSTATUS(status);
    }
    else if (WIFSIGNALED(status))
    {
        // act like shells: return 128 + signo on signal
        code = 128 + WTERMSIG(status);
    }
    else
    {
        AGENT_LOG_WARN("child process neither exited nor signaled?");
    }

    if (use_pid_file)
        std::remove(pidFilePath.c_str());

    return code;

loop_err:
    if (reaped)
    {
        // Child already collected: its pid may have been recycled, so it
        // must not be signalled or waited for again.
        close(fdout[0]);
        close(fderr[0]);
        close(pidfd);
    }
    else
    {
        // Child may still be running: kill it so waitpid() cannot block
        // indefinitely (the timeout is no longer being enforced here).
        abort_child(pid, fdout[0], fderr[0], pidfd);
    }
    if (use_pid_file)
        std::remove(pidFilePath.c_str());
    return -1;
}
