// SPDX-License-Identifier: MIT

use anyhow::{anyhow, Context, Result};
use askama::Template;
use axum::{
    body::Bytes,
    extract::{
        ws::{self, WebSocket, WebSocketUpgrade},
        Extension, Form,
    },
    http::StatusCode,
    middleware,
    response::{IntoResponse, Redirect},
    routing::{get, post},
    Router,
};
use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart};
use futures::{sink::SinkExt, stream::StreamExt};
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
use serde::Deserialize;
use std::collections::HashMap;
use tokio::fs;
use tower_sessions::Session;

// REST API routes come from the `restapi` submodule when the feature is enabled.
#[cfg(feature = "restapi")]
mod restapi;

/// Stand-in used when the `restapi` feature is disabled.
///
/// It exposes the same `routes()` entry point but contributes no routes, so the
/// caller can merge it unconditionally.
#[cfg(not(feature = "restapi"))]
mod restapi {
    /// Returns a router with no routes registered.
    pub fn routes() -> axum::Router {
        axum::Router::default()
    }
}

mod vpn_common;
pub use vpn_common::{vpn_act, vpn_connection_type, VpnConfig, VpnType};

use crate::common::{
    self, check_auth, get_title, stream_command, CommandOpts, Config, HtmlTemplate, LoggedIn, Title,
};
use crate::error::{ErrorStringResult, PageResult};

/// Slot holding the most recently upgraded `/vpn_ws` websocket, if any.
///
/// `routes()` initializes it to `None` (once), `vpn_socket_handler` stores each
/// newly upgraded socket here (replacing any previous one), and `vpn_start`
/// takes the socket out to stream the `vpn_setup.sh` output to the client.
static VPN_WS: OnceCell<Mutex<Option<WebSocket>>> = OnceCell::new();

/// Multipart form submitted to `POST /vpn_setup`.
///
/// Field names must match the form's part names.
#[derive(TryFromMultipart)]
struct VpnParam {
    // Setting name, passed through to `vpn_setup.sh`.
    setting_name: String,
    // Uploaded VPN configuration file; saved via `VpnConfig::save_file`.
    conf: FieldData<Bytes>,
    // `vpn_setup` accepts "userpass" or "cert"; anything else is rejected.
    auth_type: String,
    // Used only when auth_type == "userpass".
    username: String,
    password: String,
    // Used only when auth_type == "cert". Declared as required parts, so
    // presumably the form always sends them (possibly empty) — confirm.
    cert: FieldData<Bytes>,
    key: FieldData<Bytes>,
    // Private key passphrase; forwarded as `--askpass` when non-empty.
    key_pass: String,
}

/// View model for the VPN settings page (`vpn.html`).
#[derive(Template)]
#[template(path = "../src/vpn/templates/vpn.html")]
struct VpnTemplate {
    // Trimmed stdout of `vpn_info.sh`.
    setting_name: String,
    // Currently configured VPN type, from `vpn_connection_type()`.
    connection_type: VpnType,
    // Result of `common::current_address` for the active VPN interface
    // ("tun0" or "wg"); key meanings are defined by that helper.
    vpn_addr: HashMap<&'static str, String>,
    // WireGuard public key, or empty if none has been generated yet.
    wg_publickey: String,
    // Software availability flags from `Config::get().software`.
    is_wg_installed: bool,
    is_ovpn_installed: bool,
}

/// Builds the router for the VPN pages.
///
/// Also initializes [`VPN_WS`]. This panics if it was already set, so the
/// function must only be called once.
///
/// The VPN handlers are registered behind the `check_auth` layer. The REST API
/// routes are merged in afterwards, so that layer does not cover them, because
/// `route_layer` only wraps routes that already exist. The `get_title` layer is
/// then applied to everything.
pub fn routes() -> Router {
    VPN_WS
        .set(Mutex::new(None))
        .expect("VPN_WS once cell already set!");

    let authenticated = Router::new()
        .route("/vpn", get(vpn))
        .route("/vpn_setup", post(vpn_setup))
        .route("/vpn_delete", post(vpn_delete))
        .route("/vpn_connect", post(vpn_connect))
        .route("/vpn_disconnect", post(vpn_disconnect))
        .route("/vpn_ws", get(vpn_ws))
        .route("/vpn_wg_gen_key", get(vpn_wg_gen_key))
        .route("/vpn_wg_setup", post(vpn_wg_setup))
        .route_layer(middleware::from_fn(check_auth));

    authenticated
        .merge(restapi::routes())
        .route_layer(middleware::from_fn(|request, next| {
            get_title(request, next, "./vpn")
        }))
}

/// Handler for `GET /vpn_ws`: upgrades to a websocket and parks it in
/// [`VPN_WS`] so a later `vpn_setup` can stream script output over it.
///
/// In release builds an unauthenticated session gets `401 Unauthorized`.
/// Debug builds skip this in-handler check and never query the session.
async fn vpn_ws(session: Session, ws: WebSocketUpgrade) -> ErrorStringResult {
    let enforce_login = !cfg!(debug_assertions);
    if enforce_login && !session.logged_in().await {
        Err((StatusCode::UNAUTHORIZED, "not logged in"))?;
    }
    let upgraded = ws.on_upgrade(vpn_socket_handler);
    Ok(upgraded.into_response())
}

/// Stores a freshly upgraded websocket in [`VPN_WS`].
///
/// Any socket already stored there is replaced and dropped.
async fn vpn_socket_handler(socket: WebSocket) {
    let slot = VPN_WS.get().expect("VPN_WS once cell not initialized");
    *slot.lock() = Some(socket);
}

/// Handler for `GET /vpn`: renders the VPN settings page.
///
/// Gathers the current setting name (from `vpn_info.sh`), the connection type,
/// the address of the active VPN interface (`wg` for WireGuard, otherwise
/// `tun0`), the WireGuard public key and the installed-software flags.
async fn vpn(Extension(title): Extension<Title>) -> PageResult {
    let connection_type = vpn_connection_type().await;

    let info = common::exec_command(&["vpn_info.sh"]).await?;
    let setting_name = String::from_utf8_lossy(&info.stdout).trim().to_string();

    let interface = if connection_type == VpnType::Wireguard {
        "wg"
    } else {
        "tun0"
    };
    let vpn_addr = common::current_address(interface, interface).await?;
    let wg_publickey = vpn_wg_get_publickey().await;

    let page = VpnTemplate {
        setting_name,
        connection_type,
        vpn_addr,
        wg_publickey,
        is_wg_installed: Config::get().software.is_wg_installed,
        is_ovpn_installed: Config::get().software.is_ovpn_installed,
    };
    Ok(HtmlTemplate::new(title.0, page).into_response())
}

async fn vpn_setup(TypedMultipart(vpn_param): TypedMultipart<VpnParam>) -> ErrorStringResult {
    // Delete current settings before setup.
    let args = &["vpn_delete.sh"];
    common::exec_command(args).await?;

    let conf_contents = vpn_param.conf.contents;
    let conf_name = vpn_param
        .conf
        .metadata
        .file_name
        .context("Config file name is none")?;
    let vpn_config = VpnConfig::new()?;
    vpn_config.save_file(&conf_name, &conf_contents).await?;

    let mut args = vec![
        "vpn_setup.sh".to_string(),
        vpn_config.dir()?.to_string(),
        vpn_param.setting_name,
        vpn_param.auth_type.clone(),
        conf_name,
    ];

    match &*vpn_param.auth_type {
        "userpass" => {
            args.push("--user".to_string());
            args.push(vpn_param.username);
            args.push("--password".to_string());
            args.push(vpn_param.password);
        }
        "cert" => {
            let cert_contents = vpn_param.cert.contents;
            let cert_name = vpn_param
                .cert
                .metadata
                .file_name
                .context("Cert file name is none.")?;
            if !cert_contents.is_empty() && !cert_name.is_empty() {
                vpn_config
                    .save_file(&cert_name, &cert_contents)
                    .await
                    .context("Could not save certificate")?;
                args.push("--cert".to_string());
                args.push(cert_name);
            }

            let key_contents = vpn_param.key.contents;
            let key_name = vpn_param
                .key
                .metadata
                .file_name
                .context("Key file name is none.")?;
            if !key_contents.is_empty() && !key_name.is_empty() {
                vpn_config
                    .save_file(&key_name, &key_contents)
                    .await
                    .context("Could not save key file")?;
                args.push("--key".to_string());
                args.push(key_name);
            }
            if !vpn_param.key_pass.is_empty() {
                args.push("--askpass".to_string());
                args.push(vpn_param.key_pass);
            }
        }
        _ => Err(anyhow!("Bad auth_type {}.", vpn_param.auth_type))?,
    }

    if Config::get().software.is_ovpn_installed {
        common::exec_command(&args).await?;
        Ok(Redirect::to("/vpn").into_response())
    } else {
        vpn_start(&args).await?;
        Ok(().into_response())
    }
}

/// Runs `args` as a command and streams its output over the websocket parked
/// in [`VPN_WS`].
///
/// The socket is taken out of the slot, so each call needs a fresh `/vpn_ws`
/// connection. A `Close` frame is sent whether or not the command succeeds.
///
/// # Errors
/// Fails if no websocket is currently stored, or if `stream_command` fails.
async fn vpn_start(args: &[String]) -> Result<()> {
    // The lock guard is a temporary of this statement, so it is released
    // before the first `.await` below.
    let parked = VPN_WS
        .get()
        .expect("VPN_WS once cell not initialized")
        .lock()
        .take();
    let mut socket = parked.ok_or_else(|| anyhow!("websocket does not exist."))?;

    let (mut sink, _) = (&mut socket).split();
    let outcome = stream_command(args, &CommandOpts::default(), &mut sink).await;
    // Best-effort close; a failure to send it is deliberately ignored.
    let _ = sink.send(ws::Message::Close(None)).await;

    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(anyhow!("Could not start vpn: {}", e)),
    }
}

/// Handler for `POST /vpn_delete`: removes the current VPN settings with
/// `vpn_delete.sh`, then redirects back to the VPN page.
async fn vpn_delete() -> PageResult {
    common::exec_command(&["vpn_delete.sh"]).await?;
    let back_to_page = Redirect::to("/vpn");
    Ok(back_to_page.into_response())
}

/// Handler for `POST /vpn_connect`: calls `vpn_act(true)`, then redirects
/// back to the VPN page.
async fn vpn_connect() -> PageResult {
    let bring_up = true;
    vpn_act(bring_up).await?;
    let back_to_page = Redirect::to("/vpn");
    Ok(back_to_page.into_response())
}

/// Handler for `POST /vpn_disconnect`: calls `vpn_act(false)`, then
/// redirects back to the VPN page.
async fn vpn_disconnect() -> PageResult {
    let bring_up = false;
    vpn_act(bring_up).await?;
    let back_to_page = Redirect::to("/vpn");
    Ok(back_to_page.into_response())
}

/// Form submitted to `POST /vpn_wg_setup`.
///
/// Each field is passed, in this order, as a positional argument to
/// `wireguard_setup.sh`. Values are forwarded as raw strings; any format
/// validation is presumably done by the script — confirm.
#[derive(Deserialize)]
struct VpnWgParam {
    setting_name: String,
    endpoint: String,
    // Peer public key (as opposed to the locally generated key read by
    // `vpn_wg_get_publickey`) — NOTE(review): inferred from usage, confirm.
    publickey: String,
    self_address: String,
    dns_address: String,
    keepalive: String,
}

async fn vpn_wg_setup(Form(wg_param): Form<VpnWgParam>) -> PageResult {
    let args = vec![
        "wireguard_setup.sh",
        &wg_param.setting_name,
        &wg_param.endpoint,
        &wg_param.publickey,
        &wg_param.self_address,
        &wg_param.dns_address,
        &wg_param.keepalive,
    ];

    common::exec_command(&args).await?;

    Ok(Redirect::to("/vpn").into_response())
}

/// Handler for `GET /vpn_wg_gen_key`: generates a WireGuard key pair with
/// `wireguard_gen_key.sh` and responds with the resulting public key.
async fn vpn_wg_gen_key() -> ErrorStringResult {
    let generator = ["wireguard_gen_key.sh"];
    common::exec_command(&generator).await?;
    let publickey = vpn_wg_get_publickey().await;
    Ok(publickey.into_response())
}

/// Reads the stored WireGuard public key.
///
/// Surrounding whitespace is trimmed. If the key file cannot be read (for
/// example, no key has been generated yet), an empty string is returned.
async fn vpn_wg_get_publickey() -> String {
    fs::read_to_string("/etc/atmark/abos_web/wireguard/publickey.txt")
        .await
        .map(|contents| contents.trim().to_string())
        .unwrap_or_default()
}
