// SPDX-License-Identifier: MIT

use axum::extract::rejection::JsonRejection;
use axum::{
    extract::{FromRequest, Json, Request},
    http::header::CONTENT_TYPE,
    response::IntoResponse,
    RequestExt,
};
use axum_extra::extract::{Form, FormRejection};
use serde::Serialize;

use crate::error::{ErrorString, ErrorStringResult};

/// Serializes `data` into a JSON response and wraps it as a successful [`ErrorStringResult`].
///
/// The conversion is delegated to axum's [`Json`] responder; this helper only
/// adapts the result to the crate's `ErrorStringResult` return type.
pub fn json_response<T>(data: &T) -> ErrorStringResult
where
    T: Serialize,
{
    let response = Json(data).into_response();
    Ok(response)
}

#[allow(dead_code)]
pub fn json_response_raw(s: String) -> ErrorStringResult {
    let mut response = s.into_response();
    response
        .headers_mut()
        .insert("content-type", "application/json".parse().unwrap());
    Ok(response)
}

impl From<FormRejection> for ErrorString {
    fn from(rejection: FormRejection) -> Self {
        let body = match rejection {
            FormRejection::RawFormRejection(inner) => inner.body_text(),
            FormRejection::FailedToDeserializeForm(inner) => inner.to_string(),
            _ => rejection.to_string(), // should never happen, but enum is marked non_exhaustive
        };
        body.into()
    }
}

impl From<JsonRejection> for ErrorString {
    /// Converts a JSON-extraction failure into its rejection body text.
    fn from(rejection: JsonRejection) -> Self {
        let message: String = rejection.body_text();
        message.into()
    }
}

/// Optional request body, decoded as JSON or as URL-encoded form data.
///
/// The decoder is chosen from the `Content-Type` header, compared
/// case-insensitively by prefix:
/// - `application/json...` is decoded with [`Json`].
/// - `application/x-www-form-urlencoded...` is decoded with [`Form`].
///
/// The payload is `None` in three cases: the header is absent, its value is not
/// convertible to a `&str`, or it names any other media type.
pub struct JsonOrFormOption<T>(pub Option<T>);

impl<S, T> FromRequest<S> for JsonOrFormOption<T>
where
    Json<T>: FromRequest<(), Rejection = JsonRejection>,
    Form<T>: FromRequest<(), Rejection = FormRejection>,
    T: 'static,
    S: Send + Sync,
{
    type Rejection = ErrorString;

    /// Dispatches on `Content-Type`. Decoding failures from the chosen
    /// extractor are converted into [`ErrorString`] and returned as errors.
    async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
        // Own the lowercased value so the borrow of `req` ends before it is consumed.
        let mime = match req.headers().get(CONTENT_TYPE).map(|raw| raw.to_str()) {
            Some(Ok(text)) => text.to_lowercase(),
            _ => return Ok(Self(None)),
        };

        let payload = if mime.starts_with("application/json") {
            let Json(body) = req.extract::<Json<T>, _>().await?;
            Some(body)
        } else if mime.starts_with("application/x-www-form-urlencoded") {
            let Form(body) = req.extract::<Form<T>, _>().await?;
            Some(body)
        } else {
            None
        };

        Ok(Self(payload))
    }
}

/// Required request body, decoded as JSON or as URL-encoded form data.
///
/// Behaves like [`JsonOrFormOption`], except that a request it would map to
/// `None` is rejected with an error instead.
pub struct JsonOrForm<T>(pub T);

impl<S, T> FromRequest<S> for JsonOrForm<T>
where
    Json<T>: FromRequest<(), Rejection = JsonRejection>,
    Form<T>: FromRequest<(), Rejection = FormRejection>,
    T: 'static,
    S: Send + Sync,
{
    type Rejection = ErrorString;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let JsonOrFormOption(maybe_payload) =
            JsonOrFormOption::<T>::from_request(req, state).await?;
        maybe_payload.map(Self).ok_or_else(|| {
            "`content-type` must be set (application/json or application/x-www-form-urlencoded)"
                .into()
        })
    }
}
