// SPDX-License-Identifier: MIT

use anyhow::{anyhow, bail, Context, Result};
use axum::{extract::ConnectInfo, http::Request, middleware::AddExtension, Router};
use hyper::{
    body::Incoming,
    rt::{Read, Write},
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use nix::unistd;
use std::convert::Infallible;
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tower::Service;
use tracing::{trace, warn};

use crate::args::args;

mod http_upgrade;
use http_upgrade::{is_https, redirect_to_https};

// Note: if both TLS features are enabled, rustls is used instead of openssl;
// the choice does not matter functionally.
#[cfg(all(feature = "tls-openssl", not(feature = "tls-rustls")))]
mod tls_openssl;
#[cfg(all(feature = "tls-openssl", not(feature = "tls-rustls")))]
use tls_openssl as tls;

#[cfg(feature = "tls-rustls")]
mod tls_rustls;
#[cfg(feature = "tls-rustls")]
use tls_rustls as tls;

// Fallback stub used when neither TLS feature is enabled.
#[cfg(not(any(feature = "tls-openssl", feature = "tls-rustls")))]
mod tls {
    //! Stand-in for the TLS backends when no TLS feature is compiled in.

    use anyhow::{Error, Result};
    use tokio::net::TcpStream;

    /// Without a TLS backend there is no acceptor state to carry around.
    pub type Acceptor = ();

    /// Always fails: TLS support was not compiled in, so no acceptor is ever produced.
    pub async fn tls_init() -> Result<Acceptor> {
        let disabled = Error::msg("feature disabled");
        Err(disabled)
    }

    /// Never reached in practice: `tls_init` above never yields an acceptor to pass in.
    pub async fn tls_accept(_acceptor: Acceptor, _cnx: TcpStream) -> Result<TcpStream> {
        panic!("Should never be called");
    }
}

/// Extracts the success value from a `Result` whose error type is uninhabited.
///
/// `Infallible` has no values, so the error branch can never run; the empty `match`
/// proves that to the compiler without introducing a panic path.
fn unwrap_infallible<T>(result: Result<T, Infallible>) -> T {
    result.unwrap_or_else(|never| match never {})
}

async fn serve_connection<I>(
    tower_service: AddExtension<Router, ConnectInfo<SocketAddr>>,
    stream: I,
) where
    I: Read + Write + Unpin + Send + 'static,
{
    // Hyper also has its own `Service` trait and doesn't use tower. We can use
    // `hyper::service::service_fn` to create a hyper `Service` that calls our app through
    // `tower::Service::call`.
    let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
        // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
        // tower's `Service` requires `&mut self`.
        //
        // We don't need to call `poll_ready` since `Router` is always ready.
        tower_service.clone().call(request)
    });

    // `server::conn::auto::Builder` supports both http1 and http2.
    //
    // `TokioExecutor` tells hyper to use `tokio::spawn` to spawn tasks.
    let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
    builder.http2().max_header_list_size(4096);
    if let Err(e) = builder
        .serve_connection_with_upgrades(stream, hyper_service)
        .await
    {
        // error is Box<dyn StdError + Send + Sync> and apparently can't be
        // casted down to anyhow error?
        trace!("Got error serving client: {:?}", e)
    }
}

/// Handles one accepted TCP connection.
///
/// Without a TLS acceptor the connection is served as plain HTTP. With one, `is_https`
/// decides whether to perform the TLS handshake and serve, or to call `redirect_to_https`.
///
/// # Errors
/// Propagates failures from `is_https` and from the TLS handshake.
async fn serve_one(
    tower_service: AddExtension<Router, ConnectInfo<SocketAddr>>,
    cnx: TcpStream,
    tls_acceptor: Option<tls::Acceptor>,
) -> Result<()> {
    let acceptor = match tls_acceptor {
        Some(acceptor) => acceptor,
        None => {
            // hyper uses its own `Read`/`Write` traits rather than tokio's; `TokioIo` adapts them.
            serve_connection(tower_service, TokioIo::new(cnx)).await;
            return Ok(());
        }
    };

    if !is_https(&cnx).await? {
        redirect_to_https(cnx).await;
        return Ok(());
    }

    // The TLS stream is a different type from `TcpStream`, hence its own call site.
    let tls_stream = tls::tls_accept(acceptor, cnx).await?;
    serve_connection(tower_service, TokioIo::new(tls_stream)).await;
    Ok(())
}

/// Creates the TCP listener for `addr`, honouring `--free-bind`.
///
/// The socket is built by hand with `socket2`, because `TcpListener` binds immediately and
/// leaves no chance to set options such as freebind first. It is made non-blocking (needed
/// before handing it to tokio) and has address reuse enabled.
///
/// # Errors
/// Returns an error if any socket call fails, notably the bind when the address is not
/// configured on the system and `--free-bind` was not given.
fn create_listener(addr: std::net::SocketAddr) -> Result<tokio::net::TcpListener> {
    let sock_addr = socket2::SockAddr::from(addr);
    let domain = sock_addr.domain();

    // Non-blocking setup follows tokio/src/net/tcp/socket.rs.
    let socket = socket2::Socket::new(
        domain,
        socket2::Type::STREAM.nonblocking(),
        Some(socket2::Protocol::TCP),
    )
    .context("Could not create socket")?;
    socket
        .set_nonblocking(true)
        .context("could not set non-blocking")?;
    socket
        .set_reuse_address(true)
        .context("could not set reuseaddr")?;

    if args().free_bind {
        let freebind = match domain {
            socket2::Domain::IPV4 => socket.set_freebind_v4(true),
            socket2::Domain::IPV6 => socket.set_freebind_v6(true),
            _ => bail!("invalid domain"),
        };
        freebind.context("could not set freebind")?;
    }

    // Without free-bind, this fails when the address is not available on the system.
    socket
        .bind(&sock_addr)
        .context("Could not bind to address -- not setup or missing --free-bind?")?;
    // Keep the accept backlog small to limit memory usage; std's TcpListener picks a much
    // larger default.
    socket.listen(16).context("Could not listen to address")?;

    // socket2 -> std -> tokio.
    let std_listener = std::net::TcpListener::from(socket);
    tokio::net::TcpListener::try_from(std_listener).context("Could not convert to tokio listener")
}

/// Switches the process to the unprivileged user named by `--user`, if running as root.
///
/// Supplementary groups are replaced by the user's primary group. The group IDs are changed
/// before the user ID, because once root is given up the group changes would be refused.
///
/// # Errors
/// Fails if the user lookup fails, the user does not exist, or any of the
/// `setgroups`/`setgid`/`setuid` calls is rejected.
fn drop_privileges() -> Result<()> {
    // Privileges come from the *effective* UID. A set-user-ID-root process launched by a
    // normal user has a non-root real UID, yet still runs with root rights and must drop them.
    if !unistd::Uid::effective().is_root() {
        return Ok(());
    }

    let name = &args().user;
    let user = unistd::User::from_name(name)
        .with_context(|| format!("Could not look up user '{name}'"))?
        .with_context(|| format!("User '{name}' not found"))?;
    unistd::setgroups(&[user.gid]).context("Could not set supplementary groups")?;
    unistd::setgid(user.gid).context("Could not set group id")?;
    unistd::setuid(user.uid).context("Could not set user id")?;
    Ok(())
}

pub async fn run_server(app: Router) -> Result<()> {
    let addr = args().listen_addr;
    let listener =
        create_listener(addr).with_context(|| format!("Check --listen address ({addr})"))?;
    let mut make_service = app.into_make_service_with_connect_info::<SocketAddr>();
    let tls_acceptor = match tls::tls_init().await {
        Ok(acceptor) => Some(acceptor),
        Err(e) => {
            warn!("Could not setup TLS: {:?}", e);
            if let Ok(io_error) = e.downcast::<std::io::Error>() {
                if io_error.kind() == std::io::ErrorKind::PermissionDenied {
                    panic!(
                        "Permission denied while reading TLS certificate, aborting as this probably
means you should run abos-web as root (privileges are dropped after reading
TLS key)"
                    )
                }
            }
            None
        }
    };

    // drop privileges after loading TLS key
    drop_privileges()?;

    loop {
        let (cnx, addr) = listener
            .accept()
            .await
            .context("Could not accept client!")?;
        let tower_service = unwrap_infallible(make_service.call(addr).await);

        // In non-tls mode acceptor is Option<()> which is copy and does not
        // require an explicit clone, but we need it for tls.
        #[allow(clippy::clone_on_copy)]
        let tls_acceptor = tls_acceptor.clone();

        tokio::spawn(async move {
            if let Err(e) = serve_one(tower_service, cnx, tls_acceptor).await {
                trace!("Got error serving {}: {:?}", addr, e);
            }
        });
    }
}
