//! Provides the [`SputnikParts`] and [`SputnikBody`] traits.

use cookie::Cookie;
use mime::Mime;
use serde::{Deserialize, de::DeserializeOwned};
use hyper::{HeaderMap, body::Bytes, header, http::request::Parts};
use time::Duration;
use std::{collections::HashMap, sync::Arc};
use rand::{Rng, distributions::Alphanumeric};
use async_trait::async_trait;

use crate::response::SputnikHeaders;

/// Name of the cookie that holds the CSRF token issued by
/// [`SputnikParts::csrf_token`] and checked by [`SputnikBody::into_form_csrf`].
const CSRF_COOKIE_NAME : &str = "csrf";

/// Convenience methods for the head ([`Parts`]) of an incoming request.
///
/// Methods taking `&mut self` may store state in the request extensions: the
/// implementation for [`Parts`] uses them to cache the parsed cookies, the
/// CSRF token and the pending response headers.
pub trait SputnikParts {
    /// Parses the query string of the request into a given struct.
    ///
    /// A request without a query string is parsed as if its query string were
    /// empty.
    ///
    /// # Errors
    ///
    /// Returns a [`QueryError`] if the query string cannot be deserialized
    /// into `X`.
    fn query<X: DeserializeOwned>(&self) -> Result<X,QueryError>;

    /// Parses the cookies of the request.
    ///
    /// Returns a map from cookie name to cookie. `Cookie` headers that are not
    /// valid UTF-8, and individual pairs that fail to parse, are skipped. The
    /// parsed map is cached, so repeated calls do not re-parse the headers.
    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>>;

    /// Enforces a specific Content-Type.
    ///
    /// # Errors
    ///
    /// Returns a [`WrongContentTypeError`] if the `Content-Type` header is
    /// missing or does not match `mime`.
    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>;

    /// A map of response headers to allow methods of this trait to set response
    /// headers without needing to take a [`Response`](hyper::http::response::Response) as an argument.
    ///
    /// You need to take care to append these headers to the response yourself.
    /// This is intended to be done after your routing logic so that your
    /// individual request handlers don't have to worry about it.
    fn response_headers(&mut self) -> &mut HeaderMap;

    /// Returns a CSRF token, either extracted from the `csrf` cookie or newly
    /// generated if the cookie wasn't sent (in which case a set-cookie header is
    /// appended to [`Self::response_headers`]).
    ///
    /// A newly set cookie is marked `Secure` and has a max-age of one hour.
    ///
    /// If there is no cookie, calling this method multiple times only generates
    /// a new token on the first call, further calls return the previously
    /// generated token.
    fn csrf_token(&mut self) -> CsrfToken;
}

impl SputnikParts for hyper::http::request::Parts {
    fn query<T: DeserializeOwned>(&self) -> Result<T,QueryError> {
        serde_urlencoded::from_str::<T>(self.uri.query().unwrap_or("")).map_err(QueryError)
    }

    fn response_headers(&mut self) -> &mut HeaderMap {
        if self.extensions.get::<HeaderMap>().is_none() {
            self.extensions.insert(HeaderMap::new());
        }
        self.extensions.get_mut::<HeaderMap>().unwrap()
    }

    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>> {
        let cookies: Option<&Arc<HashMap<String, Cookie>>> = self.extensions.get();
        if let Some(cookies) = cookies {
            return cookies.clone();
        }

        let mut cookies = HashMap::new();
        for header in self.headers.get_all(header::COOKIE) {
            let raw_str = match std::str::from_utf8(header.as_bytes()) {
                Ok(string) => string,
                Err(_) => continue
            };

            for cookie_str in raw_str.split(';').map(|s| s.trim()) {
                if let Ok(cookie) = Cookie::parse_encoded(cookie_str) {
                    cookies.insert(cookie.name().to_string(), cookie.into_owned());
                }
            }
        }
        let cookies = Arc::new(cookies);
        self.extensions.insert(cookies.clone());
        cookies
    }

    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> {
        if let Some(content_type) = self.headers.get(header::CONTENT_TYPE) {
            if *content_type == mime.to_string() {
                return Ok(())
            }
        }
        Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))})
    }

    fn csrf_token(&mut self) -> CsrfToken {
        if let Some(token) = self.extensions.get::<CsrfToken>() {
            return token.clone()
        }
        csrf_token_from_cookies(self)
        .unwrap_or_else(|| {
            let token: String = rand::thread_rng().sample_iter(
                Alphanumeric
                // must be HTML-safe because we embed it in CsrfToken::html_input
            ).take(16).collect();
            let mut c = Cookie::new(CSRF_COOKIE_NAME, token.clone());
            c.set_secure(Some(true));
            c.set_max_age(Some(Duration::hours(1)));

            self.response_headers().set_cookie(c);
            let token = CsrfToken(token);
            self.extensions.insert(token.clone());
            token
        })
    }
}

/// A CSRF token retrievable with [`SputnikParts::csrf_token`].
///
/// [`Display`](std::fmt::Display) yields the raw, unescaped token value.
#[derive(Clone)]
pub struct CsrfToken(String);

impl std::fmt::Display for CsrfToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl CsrfToken {
    /// Returns a hidden HTML input to be embedded in forms that are received
    /// with [`crate::request::SputnikBody::into_form_csrf`].
    ///
    /// The token is HTML-escaped before being placed in the `value` attribute.
    /// Tokens generated by this crate are alphanumeric and are therefore
    /// unchanged. A token read back from a client-supplied cookie could
    /// otherwise contain characters (`"`, `<`, ...) that break out of the
    /// attribute.
    pub fn html_input(&self) -> String {
        let mut escaped = String::with_capacity(self.0.len());
        for c in self.0.chars() {
            match c {
                '&' => escaped.push_str("&amp;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                other => escaped.push(other),
            }
        }
        format!("<input name=csrf type=hidden value=\"{}\">", escaped)
    }
}

/// Methods for consuming the body of an incoming request.
///
/// Each method takes `self` by value, so a body can be consumed only once.
#[async_trait]
pub trait SputnikBody {
    /// Reads the entire request body into memory.
    ///
    /// NOTE(review): the implementation for `hyper::Body` buffers the whole body
    /// via `hyper::body::to_bytes` without a size limit; consider bounding
    /// request sizes upstream.
    async fn into_bytes(self) -> Result<Bytes, BodyError>;

    /// Parses a `application/x-www-form-urlencoded` request body into a given struct.
    ///
    /// This provides no CSRF protection, so you normally want to use
    /// [`SputnikBody::into_form_csrf()`] instead. The `Content-Type` header is
    /// not checked; use [`SputnikParts::enforce_content_type`] for that.
    async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError>;

    /// Parses a `application/x-www-form-urlencoded` request body into a given struct.
    /// Protects from CSRF by checking that the `csrf` field of the body equals
    /// the token stored in the `csrf` cookie.
    ///
    /// The HTML form must embed a hidden input generated with [`CsrfToken::html_input`].
    ///
    /// # Errors
    ///
    /// Fails with [`CsrfProtectedFormError::NoCsrf`] if the body has no
    /// `csrf` field, [`CsrfProtectedFormError::NoCookie`] if there is no CSRF
    /// cookie, [`CsrfProtectedFormError::Mismatch`] if the two differ, and
    /// otherwise with the body-read or deserialization error.
    async fn into_form_csrf<T: DeserializeOwned>(self, req: &mut Parts) -> Result<T, CsrfProtectedFormError>;

    /// Attempts to deserialize the request body as JSON.
    ///
    /// The `Content-Type` header is not checked.
    #[cfg(feature = "json")]
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError>;
}

/// Looks up the `csrf` cookie of the request. If it holds a well-formed token,
/// the token is cached in the request extensions and returned.
///
/// Only non-empty, ASCII-alphanumeric values are accepted, which is the
/// alphabet [`SputnikParts::csrf_token`] generates. Anything else is treated
/// as if no cookie had been sent. This covers an empty value, or one
/// containing characters such as `"` or `<` that a cookie-injecting attacker
/// could plant. In that case `csrf_token` issues a fresh token, and tokens
/// stay safe to embed in HTML.
fn csrf_token_from_cookies(req: &mut Parts) -> Option<CsrfToken> {
    let cookies = req.cookies();
    let value = cookies.get(CSRF_COOKIE_NAME)?.value();

    let well_formed = !value.is_empty() && value.bytes().all(|b| b.is_ascii_alphanumeric());
    if !well_formed {
        return None;
    }

    let token = CsrfToken(value.to_owned());
    req.extensions.insert(token.clone());
    Some(token)
}

#[async_trait]
impl SputnikBody for hyper::Body {
    async fn into_bytes(self) -> Result<Bytes, BodyError> {
        // Buffers the whole body; no size limit is applied here.
        hyper::body::to_bytes(self).await.map_err(BodyError)
    }

    async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError> {
        let full_body = self.into_bytes().await?;
        Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
    }

    async fn into_form_csrf<T: DeserializeOwned>(self, req: &mut Parts) -> Result<T, CsrfProtectedFormError> {
        // Compares two byte strings without short-circuiting on the first
        // differing byte, so response timing does not reveal how much of a
        // guessed token was correct. The length is not treated as secret:
        // tokens generated by `csrf_token` are always 16 characters long.
        fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
            a.len() == b.len() && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
        }

        let full_body = self.into_bytes().await?;
        // Any failure to extract the `csrf` field is reported as a missing token.
        let submitted = serde_urlencoded::from_bytes::<CsrfData>(&full_body)
            .map_err(|_| CsrfProtectedFormError::NoCsrf)?;
        let expected = csrf_token_from_cookies(req).ok_or(CsrfProtectedFormError::NoCookie)?;
        if !constant_time_eq(expected.0.as_bytes(), submitted.csrf.as_bytes()) {
            return Err(CsrfProtectedFormError::Mismatch);
        }
        // Only after the token check is the full body deserialized into the
        // caller's type.
        Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
    }

    #[cfg(feature = "json")]
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError> {
        let full_body = self.into_bytes().await?;
        Ok(serde_json::from_slice::<T>(&full_body)?)
    }
}

/// Extracts just the `csrf` field from a urlencoded form body. Other fields
/// are ignored, since serde does not reject unknown fields unless
/// `deny_unknown_fields` is set.
#[derive(Deserialize)]
struct CsrfData {
    // Value of the hidden input produced by `CsrfToken::html_input`.
    csrf: String,
}

#[derive(thiserror::Error, Debug)]
#[error("query deserialize error: {0}")]
pub struct QueryError(pub serde_urlencoded::de::Error);

#[derive(thiserror::Error, Debug)]
#[error("failed to read body")]
pub struct BodyError(pub hyper::Error);

#[derive(thiserror::Error, Debug)]
#[error("expected Content-Type {expected} but received {}", received.as_ref().unwrap_or(&"nothing".to_owned()))]
pub struct WrongContentTypeError {
    pub expected: Mime,
    pub received: Option<String>,
}

#[derive(thiserror::Error, Debug)]
pub enum FormError {
    #[error("{0}")]
    Body(#[from] BodyError),

    #[error("form deserialize error: {0}")]
    Deserialize(#[from] serde_urlencoded::de::Error),
}

/// Error returned by [`SputnikBody::into_json`].
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[derive(thiserror::Error, Debug)]
pub enum JsonError {
    /// The request body could not be read. Display and `source()` are
    /// forwarded to the wrapped [`BodyError`], so the message is not repeated
    /// when an error chain is printed.
    #[error(transparent)]
    Body(#[from] BodyError),

    /// The body was read but is not valid JSON for the requested type.
    #[error("json deserialize error: {0}")]
    Deserialize(#[from] serde_json::Error),
}

#[derive(thiserror::Error, Debug)]
pub enum CsrfProtectedFormError {
    #[error("{0}")]
    Body(#[from] BodyError),

    #[error("form deserialize error: {0}")]
    Deserialize(#[from] serde_urlencoded::de::Error),

    #[error("no csrf token in form data")]
    NoCsrf,

    #[error("no csrf cookie")]
    NoCookie,

    #[error("csrf parameter doesn't match csrf cookie")]
    Mismatch,
}

#[cfg(test)]
mod tests {
    use hyper::Request;

    use super::*;

    #[test]
    fn test_csrf_token() {
        // Without a csrf cookie, a token is generated once and announced via a
        // single Set-Cookie header; later calls reuse it.
        let mut parts = Request::new(hyper::Body::empty()).into_parts().0;
        let tok1 = parts.csrf_token();
        let tok2 = parts.csrf_token();
        assert_eq!(tok1.to_string(), tok2.to_string());
        assert_eq!(parts.response_headers().len(), 1);
    }

    #[test]
    fn test_csrf_token_from_cookie() {
        let mut parts = Request::builder()
            .header(hyper::header::COOKIE, "foo=bar; csrf=abc123")
            .body(hyper::Body::empty())
            .unwrap()
            .into_parts()
            .0;
        assert_eq!(parts.csrf_token().to_string(), "abc123");
        // The token came from the request, so no Set-Cookie header is needed.
        assert!(parts.response_headers().is_empty());
    }

    #[test]
    fn test_cookies() {
        // Builder::header appends, so this request carries two Cookie headers.
        let mut parts = Request::builder()
            .header(hyper::header::COOKIE, "a=1; b=2")
            .header(hyper::header::COOKIE, "c=3")
            .body(hyper::Body::empty())
            .unwrap()
            .into_parts()
            .0;
        let cookies = parts.cookies();
        assert_eq!(cookies.len(), 3);
        assert_eq!(cookies.get("b").map(|c| c.value()), Some("2"));
        assert_eq!(cookies.get("c").map(|c| c.value()), Some("3"));
    }

    #[test]
    fn test_query() {
        let parts = Request::builder()
            .uri("/path?a=1&b=two")
            .body(hyper::Body::empty())
            .unwrap()
            .into_parts()
            .0;
        let query: HashMap<String, String> = parts.query().unwrap();
        assert_eq!(query.get("a").map(String::as_str), Some("1"));
        assert_eq!(query.get("b").map(String::as_str), Some("two"));

        // A missing query string parses like an empty one.
        let parts = Request::new(hyper::Body::empty()).into_parts().0;
        let query: HashMap<String, String> = parts.query().unwrap();
        assert!(query.is_empty());
    }

    #[test]
    fn test_enforce_content_type() {
        let parts = Request::builder()
            .header(hyper::header::CONTENT_TYPE, "application/json")
            .body(hyper::Body::empty())
            .unwrap()
            .into_parts()
            .0;
        assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_ok());

        let err = parts.enforce_content_type(mime::TEXT_PLAIN).unwrap_err();
        assert_eq!(err.received.as_deref(), Some("application/json"));

        let parts = Request::new(hyper::Body::empty()).into_parts().0;
        let err = parts.enforce_content_type(mime::APPLICATION_JSON).unwrap_err();
        assert!(err.received.is_none());
    }
}