// SPDX-License-Identifier: MIT

use axum::{
    http::request::Parts,
    http::{header, HeaderValue},
};
use base64::engine::general_purpose::STANDARD as B64_ENGINE;
use base64::Engine;

/// Iterator over the `Authorization` headers of a request, decoded into [`AuthHeader`] values.
pub struct AuthHeaders<'a> {
    /// Raw `Authorization` header values, in the order they appear in the request.
    headers: header::ValueIter<'a, HeaderValue>,
}

impl<'a> AuthHeaders<'a> {
    /// Creates an iterator over every `Authorization` header present in `parts`.
    ///
    /// The iterator borrows from `parts`, so yielded bearer tokens can reference the header
    /// values without copying.
    pub fn from_parts(parts: &'a Parts) -> Self {
        Self {
            // Typed constant instead of a string literal: no runtime name parsing, no typo risk.
            headers: parts.headers.get_all(header::AUTHORIZATION).iter(),
        }
    }
}

/// Credentials carried by a single `Authorization` header value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthHeader<'a> {
    /// `Basic` scheme: the base64-decoded payload split at its first `:` into
    /// `(user, password)`.
    Basic(String, String),
    /// `Bearer` scheme: the token, borrowed directly from the header value.
    Bearer(&'a str),
}

impl<'a> AuthHeader<'a> {
    /// Parses one `Authorization` header value of the form `<scheme> <credentials>`.
    ///
    /// The scheme name is matched case-insensitively and may be followed by one or more
    /// spaces (RFC 7235 §2.1: `auth-scheme [ 1*SP credentials ]`).
    ///
    /// Returns `None` when:
    /// - the value has no space after the scheme;
    /// - the scheme is neither `Basic` nor `Bearer`;
    /// - a `Basic` payload is not valid base64, or does not decode to UTF-8;
    /// - a `Basic` payload contains no `:` separator.
    fn parse(auth: &'a str) -> Option<Self> {
        let (scheme, credentials) = auth.split_once(' ')?;
        // `1*SP` is permitted between the scheme and its credentials.
        let credentials = credentials.trim_start_matches(' ');

        if scheme.eq_ignore_ascii_case("Basic") {
            let decoded = String::from_utf8(B64_ENGINE.decode(credentials).ok()?).ok()?;
            // A user-id cannot contain ':' (RFC 7617), so the first ':' is the separator and
            // the password may itself contain ':'.
            let (user, pass) = decoded.split_once(':')?;
            Some(AuthHeader::Basic(user.into(), pass.into()))
        } else if scheme.eq_ignore_ascii_case("Bearer") {
            Some(AuthHeader::Bearer(credentials))
        } else {
            None
        }
    }
}

impl<'a> Iterator for AuthHeaders<'a> {
    type Item = AuthHeader<'a>;

    /// Returns the next `Authorization` header that parses as a supported scheme.
    ///
    /// Header values that are not visible-ASCII text, or that `AuthHeader::parse` rejects, are
    /// skipped rather than ending iteration. One malformed header therefore cannot hide valid
    /// credentials sent after it.
    fn next(&mut self) -> Option<Self::Item> {
        self.headers
            .find_map(|value| value.to_str().ok().and_then(AuthHeader::parse))
    }
}

#[cfg(test)]
mod tests {
    use super::{AuthHeader, AuthHeaders};

    #[test]
    fn test_parse_auth() {
        assert_eq!(
            AuthHeader::parse("Basic dGVzdDpwYXNz"),
            Some(AuthHeader::Basic("test".into(), "pass".into()))
        );
        assert_eq!(
            AuthHeader::parse("Bearer abcd123"),
            Some(AuthHeader::Bearer("abcd123"))
        );
    }

    #[test]
    fn test_parse_basic_edge_cases() {
        // "user:pa:ss" — only the first ':' separates user from password.
        assert_eq!(
            AuthHeader::parse("Basic dXNlcjpwYTpzcw=="),
            Some(AuthHeader::Basic("user".into(), "pa:ss".into()))
        );
        // "user:" — an empty password is still a valid credential pair.
        assert_eq!(
            AuthHeader::parse("Basic dXNlcjo="),
            Some(AuthHeader::Basic("user".into(), "".into()))
        );
    }

    #[test]
    fn test_parse_auth_rejects_malformed() {
        // Not valid base64.
        assert_eq!(AuthHeader::parse("Basic !!!"), None);
        // Decodes to "test", which has no ':' separator.
        assert_eq!(AuthHeader::parse("Basic dGVzdA=="), None);
        // Unsupported scheme.
        assert_eq!(AuthHeader::parse("Digest abc"), None);
        // Scheme with no credentials at all.
        assert_eq!(AuthHeader::parse("Bearer"), None);
        assert_eq!(AuthHeader::parse(""), None);
    }

    #[test]
    fn test_auth_headers_yields_parsable_values_in_order() {
        let (parts, ()) = axum::http::Request::builder()
            .header("Authorization", "Digest abc")
            .header("Authorization", "Bearer abcd123")
            .header("Authorization", "Basic dGVzdDpwYXNz")
            .body(())
            .unwrap()
            .into_parts();

        let found: Vec<_> = AuthHeaders::from_parts(&parts).collect();
        assert_eq!(
            found,
            vec![
                AuthHeader::Bearer("abcd123"),
                AuthHeader::Basic("test".into(), "pass".into()),
            ]
        );
    }
}
