/**
 * Copyright (C) 2023-2024 Atmark Techno, Inc. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

#include <cstring>

#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>
#include <netinet/icmp6.h>
#include <netinet/ip6.h>
#include <netinet/ip_icmp.h>
#include <stdio.h>
#include <sys/random.h>
#include <sys/socket.h>

#include "ping.h"
#include "execute_command.h"
#include "agent_utils.h"
#include "agent_time.h"
#include "agent_log.h"

// netinet/icmp6.h defines ICMP6_ECHO_REQUEST and friends, but there are no
// identically named constants for ICMPv4; netinet/ip_icmp.h provides these:
//#define ICMP_ECHO_RESPONSE 0 -> ICMP_ECHOREPLY
//#define ICMP_DESTINATION_UNREACHABLE 3 -> ICMP_DEST_UNREACH
//#define ICMP_ECHO_REQUEST 8 -> ICMP_ECHO
//#define ICMP_TIME_EXCEEDED 11 -> ICMP_TIME_EXCEEDED

// Socket address holder for either address family: resolve_hostname() fills
// s4 or s6, and ping4()/ping6() read ss.ss_family and pass &sa to sendto().
// The sockaddr_storage member makes the union large enough for any family.
union sockaddr46 {
    struct sockaddr         sa;   // generic view for socket API calls
    struct sockaddr_in      s4;   // AF_INET address
    struct sockaddr_in6     s6;   // AF_INET6 address
    struct sockaddr_storage ss;   // sizes the union; ss_family names the active member
};

// Error codes returned by ping4()/ping6(). They are negative so they can
// never be mistaken for a measured round-trip time; ping() checks for
// ERR_NOROOT to decide whether to fall back to the external ping command.
static constexpr long ERR_OTHER = -1;   // timeout or any other failure
static constexpr long ERR_NOROOT = -2;  // raw ICMP socket could not be created
static constexpr long ERR_UNREACH = -3; // network/destination unreachable

/*
 * Compute the Internet checksum (RFC 1071) of an ICMPv4 message of len bytes.
 *
 * The checksum field itself (bytes 2-3 of the ICMP header) is treated as
 * zero, so the result is the same whether or not that field has already
 * been filled in. Words are summed in memory (network) order, so the result
 * can be stored into icmp_cksum as-is.
 *
 * An odd trailing byte is padded with a zero byte, as RFC 1071 specifies.
 * Words are read with memcpy so the buffer needs no particular alignment.
 */
static uint16_t icmp_cksum(const char *packet, size_t len) {
    // Offset of the 16-bit checksum field inside the ICMP header.
    const size_t cksum_offset = 2;

    // 64-bit accumulator: cannot overflow for any realistic packet length.
    uint64_t sum = 0;
    size_t i = 0;
    for (; i + 1 < len; i += 2) {
        if (i == cksum_offset)
            continue;
        uint16_t word;
        std::memcpy(&word, packet + i, sizeof(word));
        sum += word;
    }
    if (i < len && i != cksum_offset) {
        // Odd length: the last byte is the first byte of a zero-padded word.
        uint16_t word = 0;
        std::memcpy(&word, packet + i, 1);
        sum += word;
    }

    // Fold carries back into the low 16 bits (one's complement addition).
    while (sum >> 16)
        sum = (sum & 0xffff) + (sum >> 16);
    return static_cast<uint16_t>(~sum);
}

/*
 * Ping an IPv4 address once over a raw ICMP socket.
 *
 * Sends one 64-byte ICMP echo request carrying a random identifier and waits
 * for the echo reply with the same identifier. The deadline is the 10000
 * passed to time_left() (presumably milliseconds -- confirm against
 * agent_time).
 *
 * Returns the round-trip time from monotonic_us_diff() on success, or:
 *  - ERR_NOROOT  if the raw socket cannot be created,
 *  - ERR_UNREACH if sendto() fails with ENETUNREACH, or a destination
 *                unreachable message quoting our echo request arrives,
 *  - ERR_OTHER   on timeout or any other error.
 */
static long ping4(union sockaddr46 &addr) {
    int sock = socket(addr.ss.ss_family, SOCK_RAW, IPPROTO_ICMP);
    if (sock < 0)
        return ERR_NOROOT;
    // Close the socket on every return path below.
    struct SocketCloser {
        int fd;
        ~SocketCloser() { close(fd); }
    } closer { sock };

    char request[64] = { 0 };
    struct icmp *packet = (struct icmp*)request;
    packet->icmp_type = ICMP_ECHO;
    // Random identifier used to recognise our reply among all ICMP traffic
    // the raw socket receives.
    getrandom(&packet->icmp_id, sizeof(packet->icmp_id), 0);
    packet->icmp_cksum = icmp_cksum(request, sizeof(request));

    auto start_time = Monotonic::now();

    // Pass the exact IPv4 address length rather than the size of the union.
    // errno is read here, before the closer runs.
    if (sendto(sock, &request, sizeof(request), 0, &addr.sa, sizeof(addr.s4)) < 0)
        return errno == ENETUNREACH ? ERR_UNREACH : ERR_OTHER;

    char reply[128];
    while (true) {
        fd_set readfds;
        FD_ZERO(&readfds);
        FD_SET(sock, &readfds);
        struct timeval tv;
        struct timeval *timeout = time_left(start_time, 10000, &tv);
        int ret = select(sock + 1, &readfds, NULL, NULL, timeout);
        if (ret < 0) {
            // A signal arriving during select() must not fail the check:
            // retry with whatever time is left.
            if (errno == EINTR || errno == EAGAIN)
                continue;
            AGENT_LOG_WARN("ping check: select failed: %d", errno);
            return ERR_OTHER;
        }
        if (ret == 0)
            return ERR_OTHER; // timeout

        ssize_t size = recv(sock, &reply, sizeof(reply), 0);
        if (size < 0) {
            AGENT_LOG_WARN("ping check: Error receiving reply: %d", errno);
            return ERR_OTHER;
        }
        // The raw IPv4 socket returns the IP header in front of the ICMP
        // message; ihl is that header's length in 32-bit words.
        struct iphdr *iphdr = (struct iphdr*)reply;
        if (size < ICMP_MINLEN || size < (iphdr->ihl << 2) + ICMP_MINLEN)
            continue; // too small
        size_t icmp_off = iphdr->ihl << 2;
        struct icmp *reply_packet = (struct icmp*)(reply + icmp_off);

        if (reply_packet->icmp_type == ICMP_DEST_UNREACH) {
            // The error quotes the offending IP header plus the start of its
            // payload after its own 8-byte ICMP header. Only give up if the
            // quoted payload is our echo request; otherwise keep waiting.
            size_t quoted_ip = icmp_off + ICMP_MINLEN;
            if ((size_t)size >= quoted_ip + sizeof(struct ip)) {
                const struct ip *orig_ip = &reply_packet->icmp_ip;
                size_t quoted_icmp = quoted_ip + (orig_ip->ip_hl << 2);
                if (orig_ip->ip_p == IPPROTO_ICMP
                        && (size_t)size >= quoted_icmp + ICMP_MINLEN) {
                    const struct icmp *orig = (const struct icmp*)(reply + quoted_icmp);
                    if (orig->icmp_type == ICMP_ECHO && orig->icmp_id == packet->icmp_id)
                        return ERR_UNREACH;
                }
            }
            continue; // not about our request
        }
        if (reply_packet->icmp_type == ICMP_ECHOREPLY
                && reply_packet->icmp_id == packet->icmp_id)
            break;
    }

    auto now = Monotonic::now();
    long long rtt = monotonic_us_diff(now, start_time);
    return rtt;
}

/*
 * Ping an IPv6 address once over a raw ICMPv6 socket.
 *
 * Sends one 64-byte ICMPv6 echo request carrying a random identifier and
 * waits for the echo reply with the same identifier. The deadline is the
 * 10000 passed to time_left() (presumably milliseconds -- confirm against
 * agent_time).
 *
 * Returns the round-trip time from monotonic_us_diff() on success, or:
 *  - ERR_NOROOT  if the raw socket cannot be created,
 *  - ERR_UNREACH if sendto() fails with ENETUNREACH, or a destination
 *                unreachable message quoting our echo request arrives,
 *  - ERR_OTHER   on timeout or any other error.
 */
static long ping6(union sockaddr46 &addr) {
    int sock = socket(addr.ss.ss_family, SOCK_RAW, IPPROTO_ICMPV6);
    if (sock < 0)
        return ERR_NOROOT;
    // Close the socket on every return path below.
    struct SocketCloser {
        int fd;
        ~SocketCloser() { close(fd); }
    } closer { sock };

    char request[64] = { 0 };
    struct icmp6_hdr *packet = (struct icmp6_hdr*)request;
    packet->icmp6_type = ICMP6_ECHO_REQUEST;
    // Random identifier used to recognise our reply among all ICMPv6
    // traffic the raw socket receives.
    getrandom(&packet->icmp6_id, sizeof(packet->icmp6_id), 0);
    // The ICMPv6 checksum is computed by the kernel (it covers an IPv6
    // pseudo-header we do not build here).

    auto start_time = Monotonic::now();

    // Pass the exact IPv6 address length rather than the size of the union.
    // errno is read here, before the closer runs.
    if (sendto(sock, &request, sizeof(request), 0, &addr.sa, sizeof(addr.s6)) < 0)
        return errno == ENETUNREACH ? ERR_UNREACH : ERR_OTHER;

    char reply[128];
    while (true) {
        fd_set readfds;
        FD_ZERO(&readfds);
        FD_SET(sock, &readfds);
        struct timeval tv;
        struct timeval *timeout = time_left(start_time, 10000, &tv);
        int ret = select(sock + 1, &readfds, NULL, NULL, timeout);
        if (ret < 0) {
            // A signal arriving during select() must not fail the check:
            // retry with whatever time is left.
            if (errno == EINTR || errno == EAGAIN)
                continue;
            AGENT_LOG_WARN("ping check: select failed: %d", errno);
            return ERR_OTHER;
        }
        if (ret == 0)
            return ERR_OTHER; // timeout

        ssize_t size = recv(sock, &reply, sizeof(reply), 0);
        if (size < 0) {
            AGENT_LOG_WARN("ping check: Error receiving reply: %d", errno);
            return ERR_OTHER;
        }
        // The raw ICMPv6 socket returns the ICMPv6 message without the IPv6
        // header in front of it.
        if ((size_t)size < sizeof(struct icmp6_hdr))
            continue; // too small, ignore
        struct icmp6_hdr *reply_packet = (struct icmp6_hdr*)reply;

        if (reply_packet->icmp6_type == ICMP6_DST_UNREACH) {
            // The error quotes the offending packet right after its 8-byte
            // header. Only give up if that is our echo request (no extension
            // headers expected); otherwise keep waiting.
            const size_t quoted_ip = sizeof(struct icmp6_hdr);
            const size_t quoted_icmp = quoted_ip + sizeof(struct ip6_hdr);
            if ((size_t)size >= quoted_icmp + sizeof(struct icmp6_hdr)) {
                const struct ip6_hdr *orig_ip = (const struct ip6_hdr*)(reply + quoted_ip);
                const struct icmp6_hdr *orig = (const struct icmp6_hdr*)(reply + quoted_icmp);
                if (orig_ip->ip6_nxt == IPPROTO_ICMPV6
                        && orig->icmp6_type == ICMP6_ECHO_REQUEST
                        && orig->icmp6_id == packet->icmp6_id)
                    return ERR_UNREACH;
            }
            continue; // not about our request
        }
        if (reply_packet->icmp6_type == ICMP6_ECHO_REPLY
                && reply_packet->icmp6_id == packet->icmp6_id)
            break;
    }

    auto now = Monotonic::now();
    long long rtt = monotonic_us_diff(now, start_time);
    return rtt;
}

/*
 * Fallback used when raw ICMP sockets are unavailable: run the system `ping`
 * once and extract the average RTT from its summary line with awk.
 *
 * Both busybox ("round-trip min/avg/max = A/B/C ms") and iputils
 * ("rtt min/avg/max/mdev = A/B/C/D ms") output are handled; the average in
 * ms is multiplied by 1000 and printed as an integer.
 *
 * Returns the value stoi_safe() parses from awk's output, or -1 if dst is
 * not a plausible host name / address literal or no summary line was found.
 */
long ping_command(String dst)
{
    // agent should never run as non-root, but this doesn't hurt to keep...

    // dst is pasted into a shell command line, so only accept characters
    // found in host names and IPv4/IPv6 literals (including "%scope"), and
    // refuse a leading '-' that ping would parse as an option.
    auto allowed = [](char c) {
        return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
            || (c >= '0' && c <= '9')
            || c == '.' || c == '-' || c == ':' || c == '%' || c == '_';
    };
    const char *host = dst.c_str();
    bool valid = !dst.empty() && host[0] != '-';
    for (const char *p = host; valid && *p != '\0'; p++)
        valid = allowed(*p);
    if (!valid)
    {
        AGENT_LOG_WARN("ping check: refusing to ping invalid host: %s", host);
        return -1;
    }

    String out, err;
    executeCommand(
        // printf "%d" rather than print: awk's print formats non-integral
        // numbers with OFMT ("%.6g"), which turns RTTs of 1 s or more into
        // e.g. "1.23457e+06".
        "ping -c 1 " + dst + " | awk -F '/' '"
            // busybox ping
            "/^round-trip / { printf \"%d\\n\", $4 * 1000 }"
            // iputils ping
            "/^rtt / { printf \"%d\\n\", $5 * 1000 }"
            "'",
        out,
        err
    );
    if (out.empty())
    {
        AGENT_LOG_WARN("ping check: Failed to exec ping: %s", err.c_str());
        return -1;
    }

    return stoi_safe(out);
}

/*
 * Resolve hostname for address family af (AF_INET or AF_INET6) and store the
 * first returned address in addr (as s4 or s6).
 *
 * Returns true on success. Returns false, leaving addr untouched, when the
 * lookup fails or yields a family other than AF_INET/AF_INET6. Failures that
 * only mean "no address of this family" are not logged, because ping()
 * probes IPv6 first and then IPv4, and one of them is expected to miss.
 */
bool resolve_hostname(int af, const char *hostname, union sockaddr46 &addr)
{
    struct addrinfo hints = {};
    hints.ai_family = af;
    // Fixing a socket type only avoids one duplicate result per type.
    hints.ai_socktype = SOCK_STREAM;

    struct addrinfo *result = nullptr;
    int status = getaddrinfo(hostname, nullptr, &hints, &result);
    if (status != 0)
    {
        // suppress messages about the absence of the target host IPv4/6 address
        bool family_miss = (status == EAI_NODATA);
#ifdef EAI_ADDRFAMILY
        // Reported e.g. when a numeric address of the other family is looked
        // up (an IPv4 literal with AF_INET6).
        family_miss = family_miss || (status == EAI_ADDRFAMILY);
#endif
        if (!family_miss)
            AGENT_LOG_WARN("ping check: getaddrinfo: %s", gai_strerror(status));
        return false;
    }

    bool ok = true;
    switch (result->ai_family)
    {
    case AF_INET:
        addr.s4 = *((struct sockaddr_in*)(result->ai_addr));
        break;
    case AF_INET6:
        addr.s6 = *((struct sockaddr_in6*)(result->ai_addr));
        break;
    default:
        AGENT_LOG_WARN("ping check: failed to resolve hostname: ai_family = %d", result->ai_family);
        ok = false;
        break;
    }

    freeaddrinfo(result);
    return ok;
}

long ping(const char *hostname) {
    union sockaddr46 addr = { };
    long rc = ERR_OTHER;

    // try ipv6 first
    if (resolve_hostname(AF_INET6, hostname, addr))
    {
        rc = ping6(addr);
        if (rc > 0)
            return rc;
        else if (rc == ERR_NOROOT)
            return ping_command(String(hostname)); // fall back to ping command
    }

    // fall back to ipv4
    memset(&addr, 0, sizeof(addr));
    if (resolve_hostname(AF_INET, hostname, addr))
    {
        rc = ping4(addr);
        if (rc > 0)
            return rc;
        else if (rc == ERR_NOROOT)
            return ping_command(String(hostname)); // fall back to ping command
    }

    return rc;
}
