/**
 * Copyright (C) 2023-2024 Atmark Techno, Inc. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
#include "command_line_utils.h"
#include "agent_config.h"
#include "agent_log.h"
#include "version.h"
#include <aws/crt/Api.h>
#include <aws/crt/Types.h>
#include <aws/crt/auth/Credentials.h>
#include <aws/crt/io/Pkcs11.h>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <utility>

namespace Utils
{
    /**
     * Creates the command registry. The help flag is always available, so it is
     * registered here unconditionally rather than in InitCommands().
     */
    CommandLineUtils::CommandLineUtils()
    {
        this->RegisterCommand(m_cmd_help, "", "Prints this message");
    }

    /**
     * Adds @p option to the registry, keyed by its command name.
     * The first registration of a name wins: a duplicate is not inserted and only
     * a debug message is logged.
     */
    void CommandLineUtils::RegisterCommand(CommandLineOption option)
    {
        const bool alreadyRegistered =
            m_registeredCommands.find(option.m_commandName) != m_registeredCommands.end();
        if (!alreadyRegistered)
        {
            m_registeredCommands.insert({option.m_commandName, option});
            return;
        }
        AGENT_LOG_DEBUG("Cannot register command: %s: Command already registered!",
                        option.m_commandName.c_str());
    }

    /**
     * Convenience overload: bundles the flag name, its example input (shown in the
     * usage synopsis) and its help text into a CommandLineOption and registers it.
     */
    void CommandLineUtils::RegisterCommand(
        Aws::Crt::String commandName,
        Aws::Crt::String exampleInput,
        Aws::Crt::String helpOutput)
    {
        CommandLineOption option(std::move(commandName), std::move(exampleInput), std::move(helpOutput));
        RegisterCommand(std::move(option));
    }

    /**
     * Unregisters @p commandName. Removing a name that is not registered is a no-op.
     */
    void CommandLineUtils::RemoveCommand(Aws::Crt::String commandName)
    {
        // Erase-by-key simply does nothing when the key is absent, so no prior lookup is needed.
        m_registeredCommands.erase(commandName);
    }

    /**
     * Replaces the help text of an already-registered command.
     * Unknown command names are ignored.
     */
    void CommandLineUtils::UpdateCommandHelp(Aws::Crt::String commandName, Aws::Crt::String newCommandHelp)
    {
        auto entry = m_registeredCommands.find(commandName);
        if (entry == m_registeredCommands.end())
        {
            return;
        }
        entry->second.m_helpOutput = std::move(newCommandHelp);
    }

    /**
     * Records the argument range [argv, argc) used by every later lookup, then handles
     * the flags that end the program immediately:
     *  - "--version": prints the version and exits with status 0;
     *  - the help flag (m_cmd_help): prints the usage text and exits with status -1.
     * "--version" is checked first, so it wins when both are given.
     * If a non-null range has already been recorded, the call only logs a debug message.
     *
     * @param argv pointer to the first argument to scan.
     * @param argc one-past-the-end pointer of the argument range. Despite its name,
     *             this is an end pointer, not an argument count.
     */
    void CommandLineUtils::SendArguments(const char **argv, const char **argc)
    {
        const bool alreadySent = (m_beginPosition != nullptr) || (m_endPosition != nullptr);
        if (alreadySent)
        {
            AGENT_LOG_DEBUG("Arguments already sent!");
            return;
        }

        m_beginPosition = argv;
        m_endPosition = argc;

        if (HasCommand("version"))
        {
            PrintVersion();
            exit(0);
        }
        if (HasCommand(m_cmd_help))
        {
            PrintHelp();
            exit(-1);
        }
    }

    /**
     * @return true if the exact token "--<command>" appears in the recorded argument range.
     *         With an empty range (e.g. both positions still nullptr before SendArguments(),
     *         assuming that is how they start out) the result is false.
     */
    bool CommandLineUtils::HasCommand(Aws::Crt::String command)
    {
        const Aws::Crt::String flag = "--" + command;
        const char **const found = std::find(m_beginPosition, m_endPosition, flag);
        return found != m_endPosition;
    }

    /**
     * @return the token immediately following "--<command>", or "" when the flag is
     *         absent or is the last argument. The following token is returned as-is,
     *         even if it is itself another flag.
     */
    Aws::Crt::String CommandLineUtils::GetCommand(Aws::Crt::String command)
    {
        const Aws::Crt::String flag = "--" + command;
        const char **position = std::find(m_beginPosition, m_endPosition, flag);
        if (position == m_endPosition)
        {
            return "";
        }

        const char **value = position + 1;
        if (value == m_endPosition)
        {
            return "";
        }
        return Aws::Crt::String(*value);
    }

    /**
     * @return @p commandDefault when "--<command>" is absent; otherwise the value that
     *         follows the flag (which is "" if the flag is the last argument).
     */
    Aws::Crt::String CommandLineUtils::GetCommandOrDefault(Aws::Crt::String command, Aws::Crt::String commandDefault)
    {
        if (!HasCommand(command))
        {
            return commandDefault;
        }
        return GetCommand(command);
    }

    /**
     * Returns the value that follows "--<command>", terminating the program if none is available.
     *
     * The argument counts as missing both when the flag is absent and when the flag is the
     * last argument (nothing follows it). In either case the usage text is printed, an error
     * is logged (plus @p optionalAdditionalMessage when non-empty) and the process exits
     * with status -1.
     *
     * @param command                   flag name without the leading "--".
     * @param optionalAdditionalMessage extra hint logged after the error; ignored when empty.
     * @return the token following "--<command>" (may be "" if the user passed an empty string).
     */
    Aws::Crt::String CommandLineUtils::GetCommandRequired(
        Aws::Crt::String command,
        Aws::Crt::String optionalAdditionalMessage)
    {
        const char **flagPosition = std::find(m_beginPosition, m_endPosition, "--" + command);
        const bool flagPresent = flagPosition != m_endPosition;
        if (flagPresent && flagPosition + 1 != m_endPosition)
        {
            return Aws::Crt::String(*(flagPosition + 1));
        }

        PrintHelp();
        if (flagPresent)
        {
            // A trailing "--<command>" has no value. Report it here instead of returning ""
            // and letting a later consumer (file open, PKCS#11 load, ...) fail obscurely.
            AGENT_LOG_ERROR("Missing value for required argument: --%s", command.c_str());
        }
        else
        {
            AGENT_LOG_ERROR("Missing required argument: --%s", command.c_str());
        }
        if (!optionalAdditionalMessage.empty())
        {
            AGENT_LOG_ERROR("%s", optionalAdditionalMessage.c_str());
        }
        exit(-1);
    }

    /**
     * Writes the usage text to stdout:
     *  - a synopsis line: the program name followed by " --<flag> <example input>" for
     *    every registered command;
     *  - one "* <flag>:<TAB><TAB><help>" line per registered command.
     * Commands appear in the iteration order of m_registeredCommands.
     */
    void CommandLineUtils::PrintHelp()
    {
        printf("Usage:\n");
        printf("%s", m_programName.c_str());
        for (const auto &entry : m_registeredCommands)
        {
            printf(" --%s %s", entry.first.c_str(), entry.second.m_exampleInput.c_str());
        }
        printf("\n\n");

        for (const auto &entry : m_registeredCommands)
        {
            printf("* %s:\t\t%s\n", entry.first.c_str(), entry.second.m_helpOutput.c_str());
        }
        printf("\n");
    }

    /**
     * Writes "Version: <VERSION>" to stdout; VERSION is a string macro (presumably
     * provided by version.h).
     */
    void CommandLineUtils::PrintVersion()
    {
        printf("Version: %s\n", VERSION);
    }

    void CommandLineUtils::InitCommands()
    {
        RegisterCommand("key", "<path>", "Path to your key in PEM format.");
        RegisterCommand("cert", "<path>", "Path to your client certificate in PEM format.");
        RegisterCommand("thing_name", "<str>", "The name of your IOT thing.");
        RegisterCommand("pkcs11_lib", "<path>", "Path to PKCS#11 library.");
        RegisterCommand("key_label", "<str>", "Label of private key on the PKCS#11 token (optional).");
        RegisterCommand("pin", "<str>", "User PIN for logging into PKCS#11 token.");
        RegisterCommand(
            m_cmd_ca_file, "<path>", "Path to AmazonRootCA1.pem (optional, system trust store used by default).");
        RegisterCommand(
            m_cmd_verbosity,
            "<log level>",
            "The logging level to use. Choices are 'Trace', 'Debug', 'Info', 'Warn', 'Error', 'Fatal', and 'None'. "
            "(optional, default='none')");
        RegisterCommand(
            m_cmd_log_file,
            "<str>",
            "File to write logs to. If not provided, logs will be written to stdout. "
            "(optional, default='none')");
        RegisterCommand("version", "", "Print version of this software.");
    }

    /**
     * Starts CRT logging when the verbosity flag (m_cmd_verbosity) is present; otherwise does nothing.
     *
     * The level name is matched case-sensitively against None, Fatal, Error, Warn, Info, Debug
     * and Trace. Any other value (including a missing one) is reported and treated as None.
     * Output goes to the file named by the log-file flag (m_cmd_log_file) when present,
     * otherwise to stderr.
     *
     * @param apiHandle CRT API handle used to initialise logging; dereferenced without a null check.
     */
    void CommandLineUtils::StartLoggingBasedOnCommand(Aws::Crt::ApiHandle *apiHandle)
    {
        // Presence checks and value lookups both go through the m_cmd_* constants so they cannot disagree.
        if (!HasCommand(m_cmd_verbosity))
        {
            return;
        }

        static const struct
        {
            const char *name;
            Aws::Crt::LogLevel level;
        } kLogLevels[] = {
            {"None", Aws::Crt::LogLevel::None},
            {"Fatal", Aws::Crt::LogLevel::Fatal},
            {"Error", Aws::Crt::LogLevel::Error},
            {"Warn", Aws::Crt::LogLevel::Warn},
            {"Info", Aws::Crt::LogLevel::Info},
            {"Debug", Aws::Crt::LogLevel::Debug},
            {"Trace", Aws::Crt::LogLevel::Trace},
        };

        const Aws::Crt::String verbosity = GetCommand(m_cmd_verbosity);
        Aws::Crt::LogLevel logLevel = Aws::Crt::LogLevel::None;
        bool recognized = false;
        for (const auto &entry : kLogLevels)
        {
            if (verbosity == entry.name)
            {
                logLevel = entry.level;
                recognized = true;
                break;
            }
        }
        if (!recognized)
        {
            // Previously a typo (e.g. "debug") silently disabled logging; make that visible.
            AGENT_LOG_ERROR("Unknown verbosity '%s'; logging disabled (level None)", verbosity.c_str());
        }

        if (HasCommand(m_cmd_log_file))
        {
            apiHandle->InitializeLogging(logLevel, GetCommand(m_cmd_log_file).c_str());
        }
        else
        {
            apiHandle->InitializeLogging(logLevel, stderr);
        }
    }

    /**
     * Stores the endpoint that BuildPKCS11MQTTConnection() / BuildDirectMQTTConnection()
     * pass to the connection config builder.
     */
    void CommandLineUtils::SetEndpoint(Aws::Crt::String endpoint)
    {
        m_endpoint = std::move(endpoint);
    }

    /**
     * Builds an MQTT connection whose TLS private key lives on a PKCS#11 token.
     *
     * Required flags (read via GetCommandRequired(), which prints help and exits if missing):
     * m_cmd_pkcs11_lib, m_cmd_cert_file, m_cmd_pkcs11_pin, read in that order.
     * Optional flags: m_cmd_pkcs11_key (private-key object label) and m_cmd_ca_file (custom CA).
     * The endpoint is the one stored by SetEndpoint(). Failure to load the PKCS#11 library or
     * to create the config builder is logged and terminates the process with exit(-1).
     *
     * @param client MQTT client used to create the connection.
     */
    std::shared_ptr<Aws::Crt::Mqtt::MqttConnection> CommandLineUtils::BuildPKCS11MQTTConnection(
        Aws::Iot::MqttClient *client)
    {
        auto library = Aws::Crt::Io::Pkcs11Lib::Create(GetCommandRequired(m_cmd_pkcs11_lib));
        if (library == nullptr)
        {
            AGENT_LOG_ERROR("Pkcs11Lib failed: %s", Aws::Crt::ErrorDebugString(Aws::Crt::LastError()));
            exit(-1);
        }

        Aws::Crt::Io::TlsContextPkcs11Options tlsOptions(library);
        tlsOptions.SetCertificateFilePath(GetCommandRequired(m_cmd_cert_file));
        tlsOptions.SetUserPin(GetCommandRequired(m_cmd_pkcs11_pin));
        if (HasCommand(m_cmd_pkcs11_key))
        {
            tlsOptions.SetPrivateKeyObjectLabel(GetCommand(m_cmd_pkcs11_key));
        }

        Aws::Iot::MqttClientConnectionConfigBuilder builder(tlsOptions);
        if (!builder)
        {
            AGENT_LOG_ERROR(
                "MqttClientConnectionConfigBuilder failed: %s",
                Aws::Crt::ErrorDebugString(Aws::Crt::LastError()));
            exit(-1);
        }

        if (HasCommand(m_cmd_ca_file))
        {
            builder.WithCertificateAuthority(GetCommand(m_cmd_ca_file).c_str());
        }
        builder.WithEndpoint(m_endpoint);

        return GetClientConnectionForMQTTConnection(client, &builder);
    }

    /**
     * Builds an MQTT connection authenticated with a certificate / private-key pair read from files.
     *
     * Required flags (via GetCommandRequired(), which prints help and exits if missing):
     * m_cmd_cert_file, then m_cmd_key_file. Optional flag: m_cmd_ca_file (custom CA).
     * The endpoint is the one stored by SetEndpoint(). If the config builder cannot be
     * created from the given files, the error is logged and the process exits with -1.
     *
     * @param client MQTT client used to create the connection.
     */
    std::shared_ptr<Aws::Crt::Mqtt::MqttConnection> CommandLineUtils::BuildDirectMQTTConnection(
        Aws::Iot::MqttClient *client)
    {
        const Aws::Crt::String certificatePath = GetCommandRequired(m_cmd_cert_file);
        const Aws::Crt::String keyPath = GetCommandRequired(m_cmd_key_file);

        Aws::Iot::MqttClientConnectionConfigBuilder clientConfigBuilder(certificatePath.c_str(), keyPath.c_str());
        // Check the builder right away (as the PKCS#11 path does) so an unreadable certificate
        // or key is reported with a specific message rather than relying on Build() later.
        if (!clientConfigBuilder)
        {
            AGENT_LOG_ERROR(
                "MqttClientConnectionConfigBuilder failed: %s",
                Aws::Crt::ErrorDebugString(Aws::Crt::LastError()));
            exit(-1);
        }
        clientConfigBuilder.WithEndpoint(m_endpoint);

        if (HasCommand(m_cmd_ca_file))
        {
            clientConfigBuilder.WithCertificateAuthority(GetCommand(m_cmd_ca_file).c_str());
        }

        return GetClientConnectionForMQTTConnection(client, &clientConfigBuilder);
    }

    /**
     * Lazily (re)creates the internal MQTT client when it is not valid, then builds a
     * connection with it: PKCS#11-based when the m_cmd_pkcs11_lib flag is present,
     * file-based certificate/key otherwise. Exits with -1 if the client cannot be created.
     */
    std::shared_ptr<Aws::Crt::Mqtt::MqttConnection> CommandLineUtils::BuildMQTTConnection()
    {
        if (!m_internal_client)
        {
            m_internal_client = Aws::Iot::MqttClient();
        }
        if (!m_internal_client)
        {
            AGENT_LOG_ERROR(
                "MQTT Client Creation failed with error %s",
                Aws::Crt::ErrorDebugString(m_internal_client.LastError()));
            exit(-1);
        }

        const bool usePkcs11 = HasCommand(m_cmd_pkcs11_lib);
        return usePkcs11 ? BuildPKCS11MQTTConnection(&m_internal_client)
                         : BuildDirectMQTTConnection(&m_internal_client);
    }

    /**
     * Finalises @p clientConfigBuilder and asks @p client for a new connection.
     *
     * Both steps are fatal on failure: the error is logged and the process exits with -1.
     * On return the connection is therefore non-null and reports itself as valid.
     *
     * @param client              MQTT client that creates the connection.
     * @param clientConfigBuilder fully configured builder; Build() is called on it here.
     * @return the newly created connection.
     */
    std::shared_ptr<Aws::Crt::Mqtt::MqttConnection> CommandLineUtils::GetClientConnectionForMQTTConnection(
        Aws::Iot::MqttClient *client,
        Aws::Iot::MqttClientConnectionConfigBuilder *clientConfigBuilder)
    {
        auto clientConfig = clientConfigBuilder->Build();
        if (!clientConfig)
        {
            AGENT_LOG_ERROR(
                "Client Configuration initialization failed with error %s",
                Aws::Crt::ErrorDebugString(clientConfig.LastError()));
            exit(-1);
        }

        auto connection = client->NewConnection(clientConfig);
        // NewConnection() may hand back a null shared_ptr on failure (the error is then
        // recorded on the client), so it must be checked before being dereferenced.
        if (!connection || !*connection)
        {
            const int errorCode = connection ? connection->LastError() : client->LastError();
            AGENT_LOG_ERROR(
                "MQTT Connection Creation failed with error %s",
                Aws::Crt::ErrorDebugString(errorCode));
            exit(-1);
        }
        return connection;
    }

} // namespace Utils
