// SPDX-License-Identifier: MIT

use anyhow::{Error, Result};
use axum::{
    handler::HandlerWithoutStateExt,
    http::{Request, StatusCode, Uri},
    response::{IntoResponse, Redirect},
    BoxError,
};
use axum_extra::extract::Host;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::TcpStream;
use tower::Service;
use tracing::{info, trace};

use crate::server::unwrap_infallible;

pub async fn is_https(cnx: &TcpStream) -> Result<bool> {
    let mut byte = [0u8; 1];
    let n = cnx.peek(&mut byte).await?;
    if n == 0 {
        return Err(Error::msg("Connection closed before sending anything"));
    }
    // The first byte in the TLS protocol is always 0x16
    Ok(byte[0] == 0x16)
}

/// Rewrites `uri` into an absolute `https://` URI whose authority is `host`.
///
/// A request with no path is normalised to `/`. Requests for paths under
/// `/api/` are refused rather than redirected. The error text tells the
/// client that any token it sent over plain HTTP should be considered exposed.
///
/// # Errors
///
/// Fails for `/api/` paths, when `host` is not a valid URI authority, or when
/// the assembled parts do not form a valid URI.
fn make_https(host: String, uri: Uri) -> Result<Uri, BoxError> {
    const API_REFUSAL: &str =
        "Refusing http /api url; consider changing token that was sent in plain text\n";

    let mut parts = uri.into_parts();
    parts.scheme = Some(axum::http::uri::Scheme::HTTPS);

    if let Some(path_and_query) = &parts.path_and_query {
        if path_and_query.path().starts_with("/api/") {
            return Err(API_REFUSAL.into());
        }
    } else {
        parts.path_and_query = Some("/".parse()?);
    }

    // Keep whatever port the client put in `Host`; the redirect targets the
    // same host:port the request arrived on.
    parts.authority = Some(host.parse()?);

    Uri::from_parts(parts).map_err(Into::into)
}

/// Axum handler that answers a plain-HTTP request with a permanent redirect
/// (`Redirect::permanent`) to the equivalent `https://` URL for the same host.
///
/// If the URL cannot be converted (see `make_https`), the reason is logged at
/// `info` level and sent back to the client as a `400 Bad Request` body.
async fn redirect_request(Host(host): Host, uri: Uri) -> impl IntoResponse {
    make_https(host, uri)
        .map(|target| Redirect::permanent(&target.to_string()))
        .map_err(|e| {
            info!("failed to convert URI to HTTPS: {:?}", e);
            (StatusCode::BAD_REQUEST, e.to_string())
        })
}

pub async fn redirect_to_https(cnx: TcpStream) {
    let mut make_service = redirect_request.into_make_service();
    let tower_service = unwrap_infallible(make_service.call(()).await);
    let hyper_service = hyper::service::service_fn(move |request: Request<Incoming>| {
        tower_service.clone().call(request)
    });
    let stream = TokioIo::new(cnx);
    let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
    builder.http2().max_header_list_size(4096);
    if let Err(e) = builder.serve_connection(stream, hyper_service).await {
        trace!("Got error redirecting to https: {:?}", e)
    }
}
