//! Provides the [`SputnikParts`] and [`SputnikBody`] request extension traits,
//! the [`Flash`] message type, and the error types they return.

use cookie::Cookie;
use mime::Mime;
use serde::de::DeserializeOwned;
use hyper::{HeaderMap, body::Bytes, header, http::request::Parts};
use time::Duration;
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;

use crate::response::{SputnikHeaders, delete_cookie};

/// Adds convenience methods to [`http::request::Parts`](Parts).
pub trait SputnikParts {
    /// Parses the query string of the request into a given struct.
    ///
    /// # Errors
    ///
    /// Returns a [`QueryError`] if the query string cannot be deserialized
    /// into `T`.
    fn query<T: DeserializeOwned>(&self) -> Result<T, QueryError>;

    /// Parses the cookies of the request into a map keyed by cookie name.
    ///
    /// Takes `&mut self` so the parsed map can be cached on the request: the
    /// implementation for [`Parts`] stores it in the request extensions and
    /// returns the cached map on subsequent calls.
    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>>;

    /// Enforces a specific Content-Type.
    ///
    /// # Errors
    ///
    /// Returns a [`WrongContentTypeError`] when the request's `Content-Type`
    /// header is missing or does not match `mime`. The error carries the
    /// expected type and the received header value, when there is one.
    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>;

    /// A map of response headers to allow methods of this trait to set response
    /// headers without needing to take a [`Response`](hyper::http::response::Response) as an argument.
    ///
    /// You need to take care to append these headers to the response yourself.
    /// This is intended to be done after your routing logic so that your
    /// individual request handlers don't have to worry about it.
    fn response_headers(&mut self) -> &mut HeaderMap;
}

impl SputnikParts for Parts {
    fn query<T: DeserializeOwned>(&self) -> Result<T, QueryError> {
        // A request without a query string is treated like an empty one, so
        // structs whose fields are all optional/defaulted still deserialize.
        serde_urlencoded::from_str::<T>(self.uri.query().unwrap_or("")).map_err(QueryError)
    }

    fn response_headers(&mut self) -> &mut HeaderMap {
        // Lazily create the map on first use, then hand out a mutable borrow.
        if self.extensions.get::<HeaderMap>().is_none() {
            self.extensions.insert(HeaderMap::new());
        }
        self.extensions
            .get_mut::<HeaderMap>()
            .expect("HeaderMap was just inserted into the extensions")
    }

    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>> {
        // Cookies are parsed once per request and cached in the extensions;
        // later calls only bump the reference count.
        if let Some(cached) = self.extensions.get::<Arc<HashMap<String, Cookie<'static>>>>() {
            return Arc::clone(cached);
        }

        let mut cookies = HashMap::new();
        for value in self.headers.get_all(header::COOKIE) {
            // Non-UTF-8 Cookie headers are skipped instead of failing the request.
            let raw_str = match std::str::from_utf8(value.as_bytes()) {
                Ok(string) => string,
                Err(_) => continue,
            };

            // Unparsable individual cookies are skipped; if a name appears more
            // than once, the later occurrence overwrites the earlier one.
            for cookie_str in raw_str.split(';').map(str::trim) {
                if let Ok(cookie) = Cookie::parse_encoded(cookie_str) {
                    cookies.insert(cookie.name().to_string(), cookie.into_owned());
                }
            }
        }
        let cookies = Arc::new(cookies);
        self.extensions.insert(Arc::clone(&cookies));
        cookies
    }

    fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> {
        // `None` if the header is missing or not representable as a string.
        let received = self
            .headers
            .get(header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok());

        let matches = match received {
            // When only type/subtype are expected, ignore parameters such as
            // `charset` on the received value (e.g. `application/json;
            // charset=utf-8`) and compare case-insensitively, because media
            // types are case-insensitive (RFC 7231 section 3.1.1.1).
            Some(value) if mime.params().next().is_none() => value
                .parse::<Mime>()
                .map(|parsed| parsed.essence_str().eq_ignore_ascii_case(mime.essence_str()))
                .unwrap_or(false),
            // The expected type carries parameters: require an exact match.
            Some(value) => value == mime.to_string(),
            None => false,
        };

        if matches {
            Ok(())
        } else {
            Err(WrongContentTypeError {
                expected: mime,
                received: received.map(str::to_owned),
            })
        }
    }
}

/// Name of the cookie that carries a serialized [`Flash`] message.
const FLASH_COOKIE_NAME: &str = "flash";

/// Show the user a message after redirecting them.
///
/// `name` is a category (e.g. `"success"`) and `message` the text to display.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Flash {
    name: String,
    message: String,
}

impl From<Flash> for Cookie<'_> {
    fn from(flash: Flash) -> Self {
        Cookie::build(FLASH_COOKIE_NAME, format!("{}:{}", flash.name, flash.message))
        .max_age(Duration::minutes(5)).finish()
    }
}

impl Flash {
    /// If the request has a flash cookie retrieve it and append a set-cookie
    /// header to delete the cookie to [`SputnikParts::response_headers`].
    ///
    /// The delete header is appended whenever the flash cookie is present, even
    /// if its value is malformed, so the cookie is consumed either way. Returns
    /// `None` if there is no flash cookie or its value has no `:` separator.
    pub fn from_request(req: &mut Parts) -> Option<Self> {
        let jar = req.cookies();
        let cookie = jar.get(FLASH_COOKIE_NAME)?;
        req.response_headers().set_cookie(delete_cookie(FLASH_COOKIE_NAME));

        // Split on the first colon only, so the message itself may contain colons.
        let mut fields = cookie.value().splitn(2, ':');
        let name = fields.next()?;
        let message = fields.next()?;
        Some(Flash {
            name: name.to_owned(),
            message: message.to_owned(),
        })
    }

    /// Constructs a new Flash message. The name must not contain a colon (`:`),
    /// because [`Flash::from_request`] splits the stored value on the first colon.
    pub fn new(name: String, message: String) -> Self {
        Self { name, message }
    }

    /// Constructs a new "success" Flash message.
    pub fn success(message: String) -> Self {
        Self::new("success".to_owned(), message)
    }

    /// Constructs a new "warning" Flash message.
    pub fn warning(message: String) -> Self {
        Self::new("warning".to_owned(), message)
    }

    /// Constructs a new "error" Flash message.
    pub fn error(message: String) -> Self {
        Self::new("error".to_owned(), message)
    }

    /// Returns the name (category) of the Flash message.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the text of the Flash message.
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

/// Adds body-reading and deserialization methods to [`hyper::Body`].
///
/// A body carries no headers, so none of these methods check the request's
/// `Content-Type`; use [`SputnikParts::enforce_content_type`] for that.
#[async_trait]
pub trait SputnikBody {
    /// Reads the entire body into memory.
    ///
    /// The implementation for [`hyper::Body`] applies no size limit.
    ///
    /// # Errors
    ///
    /// Returns a [`BodyError`] if reading the body fails.
    async fn into_bytes(self) -> Result<Bytes, BodyError>;

    /// Parses an `application/x-www-form-urlencoded` request body into a given struct.
    ///
    /// # Errors
    ///
    /// Returns [`FormError::Body`] if the body cannot be read, or
    /// [`FormError::Deserialize`] if it cannot be deserialized into `T`.
    async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError>;

    /// Attempts to deserialize the request body as JSON.
    ///
    /// # Errors
    ///
    /// Returns [`JsonError::Body`] if the body cannot be read, or
    /// [`JsonError::Deserialize`] if it is not valid JSON for `T`.
    #[cfg(feature = "json")]
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError>;
}

#[async_trait]
impl SputnikBody for hyper::Body {
    async fn into_bytes(self) -> Result<Bytes, BodyError> {
        hyper::body::to_bytes(self).await.map_err(BodyError)
    }

    async fn into_form<T: DeserializeOwned>(self) -> Result<T, FormError> {
        let full_body = self.into_bytes().await?;
        Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
    }

    #[cfg(feature = "json")]
    #[cfg_attr(docsrs, doc(cfg(feature = "json")))]
    async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError> {
        let full_body = self.into_bytes().await?;
        Ok(serde_json::from_slice::<T>(&full_body)?)
    }
}

#[derive(thiserror::Error, Debug)]
#[error("query deserialize error: {0}")]
pub struct QueryError(pub serde_urlencoded::de::Error);

#[derive(thiserror::Error, Debug)]
#[error("failed to read body")]
pub struct BodyError(pub hyper::Error);

#[derive(thiserror::Error, Debug)]
#[error("expected Content-Type {expected} but received {}", received.as_ref().unwrap_or(&"nothing".to_owned()))]
pub struct WrongContentTypeError {
    pub expected: Mime,
    pub received: Option<String>,
}

#[derive(thiserror::Error, Debug)]
pub enum FormError {
    #[error("{0}")]
    Body(#[from] BodyError),

    #[error("form deserialize error: {0}")]
    Deserialize(#[from] serde_urlencoded::de::Error),
}

/// Error returned by [`SputnikBody::into_json`].
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[derive(Debug, thiserror::Error)]
pub enum JsonError {
    /// Reading the request body failed.
    #[error("{0}")]
    Body(#[from] BodyError),

    /// The body was read but could not be deserialized as JSON.
    #[error("json deserialize error: {0}")]
    Deserialize(#[from] serde_json::Error),
}

#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use hyper::{header, Request};

    use super::SputnikParts;

    #[test]
    fn test_enforce_content_type() {
        let request = Request::new(hyper::Body::empty());
        let (mut parts, _body) = request.into_parts();

        // A request without any Content-Type header is rejected.
        let missing = parts.enforce_content_type(mime::APPLICATION_JSON);
        assert!(missing.is_err());

        // An exactly matching header is accepted.
        let json: header::HeaderValue = "application/json".try_into().unwrap();
        parts.headers.append(header::CONTENT_TYPE, json);
        assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_ok());

        // Replacing it with a different media type is rejected again.
        let html: header::HeaderValue = "text/html".try_into().unwrap();
        parts.headers.insert(header::CONTENT_TYPE, html);
        assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_err());
    }
}