// SPDX-License-Identifier: MIT

use anyhow::{anyhow, Result};
use axum::{
    extract::{ConnectInfo, Request},
    http::StatusCode,
    middleware::Next,
    response::{IntoResponse, Response},
};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::time::Instant;
use tokio::process::Command;
use tracing::info;

use crate::args::args;

/// A cached access verdict for a single client address.
#[derive(Clone, Copy, Debug)]
struct Allowed {
    /// Whether the route check accepted the address when it was recorded.
    allowed: bool,
    /// When the verdict was recorded; compared against a TTL to expire it.
    ts: Instant,
}

/// Per-address cache of route-check verdicts, read and updated by `ip_allowed`.
/// Wrapped in `Option` because `HashMap::new` is not `const`; the map is created
/// lazily on the first insertion.
static ALLOWED_IPS: RwLock<Option<HashMap<IpAddr, Allowed>>> = RwLock::new(None);

/// Normalizes IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) into plain IPv4,
/// e.g. peers as reported by a dual-stack IPv6 listener.
trait SimplifyIpv4 {
    fn simplify_ipv4(self) -> Self;
}

impl SimplifyIpv4 for IpAddr {
    /// Returns the embedded IPv4 address for `::ffff:a.b.c.d`; any other
    /// address (IPv4, or IPv6 that is not IPv4-mapped) is returned unchanged.
    fn simplify_ipv4(self) -> IpAddr {
        match self {
            IpAddr::V6(v6) => v6.to_ipv4_mapped().map_or(IpAddr::V6(v6), IpAddr::V4),
            plain @ IpAddr::V4(_) => plain,
        }
    }
}

/// Checks, without any caching, whether `ip` is reachable without a gateway.
///
/// Runs `ip route get <ip>` and accepts the address only when the reported
/// route contains no ` via ` hop.
///
/// # Errors
/// Fails if the command cannot be run, exits unsuccessfully, prints non-UTF-8
/// output, or reports that traffic to `ip` goes through a gateway.
async fn ip_allowed_uncached(ip: &IpAddr) -> Result<()> {
    let target = ip.to_string();
    let output = Command::new("ip")
        .args(["route", "get", &target])
        .output()
        .await?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(anyhow!("ip r command failed: {}", stderr));
    }

    let route = std::str::from_utf8(&output.stdout)?;
    // A "via" hop means the address is reached through a gateway.
    match route.split_once(" via ") {
        None => Ok(()),
        Some((_, after_via)) => {
            // The gateway is the first whitespace-separated token after "via".
            let gateway = match after_via.split_once(' ') {
                Some((first, _)) => first,
                None => after_via,
            };
            info!("{ip} routed (via {gateway})");
            Err(anyhow!("routed via {}", gateway))
        }
    }
}

async fn ip_allowed(ip: IpAddr) -> Result<()> {
    // use arguments if set, otherwise check IP is in local segment
    let args_allowed = &args().allowed_subnets;
    if !args_allowed.is_empty() {
        if args_allowed
            .iter()
            .flatten()
            .any(|subnet| subnet.contains(&ip))
        {
            return Ok(());
        }
        info!("{ip} not in {args_allowed:?}");
        return Err(anyhow!("not in allowed subnets"));
    }
    if let Some(cached) = ALLOWED_IPS.read().as_ref().and_then(|map| map.get(&ip)) {
        // allows are cached 60s, deny 2s:
        // we only cache deny to prevent death by a thousands cuts
        // running too many commands but it should expire fast in case of
        // IP change.
        let elapsed = cached.ts.elapsed().as_secs();
        match cached.allowed {
            true => {
                if elapsed < 60 {
                    return Ok(());
                }
            }
            false => {
                if elapsed < 2 {
                    info!("cached deny for {ip}");
                    return Err(anyhow!("Cached, retry in 2s"));
                }
            }
        }
    }

    let allowed = ip_allowed_uncached(&ip).await;
    let mut guard = ALLOWED_IPS.write();
    if guard.is_none() {
        *guard = Some(HashMap::new());
    }
    // just made sure it's Some..
    let _ = guard.as_mut().unwrap().insert(
        ip,
        Allowed {
            allowed: allowed.is_ok(),
            ts: Instant::now(),
        },
    );
    allowed
}

/// Axum middleware that only lets allowed client addresses through.
///
/// The peer address is normalized (IPv4-mapped IPv6 becomes IPv4) and checked
/// with `ip_allowed`. Rejected requests get `403 Forbidden` with a short
/// explanation; accepted requests are forwarded to `next`.
pub async fn ipaddr_filter(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    req: Request,
    next: Next,
) -> Response {
    let client = addr.ip().simplify_ipv4();
    match ip_allowed(client).await {
        Ok(()) => next.run(req).await,
        Err(reason) => {
            let body = format!(
                "Please connect from a private address (connected from {}, {})",
                client, reason
            );
            (StatusCode::FORBIDDEN, body).into_response()
        }
    }
}
