//! Provides the [`SputnikParts`] and [`SputnikBody`] traits. use cookie::Cookie; use mime::Mime; use serde::de::DeserializeOwned; use hyper::{HeaderMap, body::Bytes, header, http::request::Parts}; use time::Duration; use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; use crate::response::{SputnikHeaders, delete_cookie}; /// Adds convenience methods to [`http::request::Parts`](Parts). pub trait SputnikParts { /// Parses the query string of the request into a given struct. fn query(&self) -> Result; /// Parses the cookies of the request. fn cookies(&mut self) -> Arc>>; /// Enforces a specific Content-Type. fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>; /// A map of response headers to allow methods of this trait to set response /// headers without needing to take a [`Response`](hyper::http::response::Response) as an argument. /// /// You need to take care to append these headers to the response yourself. /// This is intended to be done after your routing logic so that your /// individual request handlers don't have to worry about it. fn response_headers(&mut self) -> &mut HeaderMap; } impl SputnikParts for Parts { fn query(&self) -> Result { serde_urlencoded::from_str::(self.uri.query().unwrap_or("")).map_err(QueryError) } fn response_headers(&mut self) -> &mut HeaderMap { if self.extensions.get::().is_none() { self.extensions.insert(HeaderMap::new()); } self.extensions.get_mut::().unwrap() } fn cookies(&mut self) -> Arc>> { let cookies: Option<&Arc>> = self.extensions.get(); if let Some(cookies) = cookies { return cookies.clone(); } let mut cookies = HashMap::new(); for header in self.headers.get_all(header::COOKIE) { let raw_str = match std::str::from_utf8(header.as_bytes()) { Ok(string) => string, Err(_) => continue }; for cookie_str in raw_str.split(';').map(|s| s.trim()) { if let Ok(cookie) = Cookie::parse_encoded(cookie_str) { cookies.insert(cookie.name().to_string(), cookie.into_owned()); } } } let cookies = Arc::new(cookies); self.extensions.insert(cookies.clone()); cookies } fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> { if let Some(content_type) = self.headers.get(header::CONTENT_TYPE) { if *content_type == mime.to_string() { return Ok(()) } } Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))}) } } const FLASH_COOKIE_NAME: &str = "flash"; /// Show the user a message after redirecting them. pub struct Flash { name: String, message: String, } impl From for Cookie<'_> { fn from(flash: Flash) -> Self { Cookie::build(FLASH_COOKIE_NAME, format!("{}:{}", flash.name, flash.message)) .max_age(Duration::minutes(5)).finish() } } impl Flash { /// If the request has a flash cookie retrieve it and append a set-cookie /// header to delete the cookie to [`SputnikParts::response_headers`]. pub fn from_request(req: &mut Parts) -> Option { req.cookies().get(FLASH_COOKIE_NAME) .and_then(|cookie| { req.response_headers().set_cookie(delete_cookie(FLASH_COOKIE_NAME)); let mut iter = cookie.value().splitn(2, ':'); if let (Some(name), Some(message)) = (iter.next(), iter.next()) { return Some(Flash{name: name.to_owned(), message: message.to_owned()}) } None }) } /// Constructs a new Flash message. The name must not contain a colon (`:`). pub fn new(name: String, message: String) -> Self { Flash{name, message} } /// Constructs a new "success" Flash message. pub fn success(message: String) -> Self { Flash{name: "success".to_owned(), message} } /// Constructs a new "warning" Flash message. pub fn warning(message: String) -> Self { Flash{name: "warning".to_owned(), message} } /// Constructs a new "error" Flash message. pub fn error(message: String) -> Self { Flash{name: "error".to_owned(), message} } /// Returns the name of the Flash message. pub fn name(&self) -> &str { &self.name } /// Returns the message of the Flash message. pub fn message(&self) -> &str { &self.message } } /// Adds deserialization methods to [`hyper::Body`]. #[async_trait] pub trait SputnikBody { async fn into_bytes(self) -> Result; /// Parses a `application/x-www-form-urlencoded` request body into a given struct. async fn into_form(self) -> Result; /// Attempts to deserialize the request body as JSON. #[cfg(feature = "json")] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] async fn into_json(self) -> Result; } #[async_trait] impl SputnikBody for hyper::Body { async fn into_bytes(self) -> Result { hyper::body::to_bytes(self).await.map_err(BodyError) } async fn into_form(self) -> Result { let full_body = self.into_bytes().await?; Ok(serde_urlencoded::from_bytes::(&full_body)?) } #[cfg(feature = "json")] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] async fn into_json(self) -> Result { let full_body = self.into_bytes().await?; Ok(serde_json::from_slice::(&full_body)?) } } #[derive(thiserror::Error, Debug)] #[error("query deserialize error: {0}")] pub struct QueryError(pub serde_urlencoded::de::Error); #[derive(thiserror::Error, Debug)] #[error("failed to read body")] pub struct BodyError(pub hyper::Error); #[derive(thiserror::Error, Debug)] #[error("expected Content-Type {expected} but received {}", received.as_ref().unwrap_or(&"nothing".to_owned()))] pub struct WrongContentTypeError { pub expected: Mime, pub received: Option, } #[derive(thiserror::Error, Debug)] pub enum FormError { #[error("{0}")] Body(#[from] BodyError), #[error("form deserialize error: {0}")] Deserialize(#[from] serde_urlencoded::de::Error), } #[cfg(feature = "json")] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] #[derive(thiserror::Error, Debug)] pub enum JsonError { #[error("{0}")] Body(#[from] BodyError), #[error("json deserialize error: {0}")] Deserialize(#[from] serde_json::Error), } #[cfg(test)] mod tests { use std::convert::TryInto; use hyper::{Request, header}; use super::SputnikParts; #[test] fn test_enforce_content_type() { let (mut parts, _body) = Request::new(hyper::Body::empty()).into_parts(); assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_err()); parts.headers.append(header::CONTENT_TYPE, "application/json".try_into().unwrap()); assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_ok()); parts.headers.insert(header::CONTENT_TYPE, "text/html".try_into().unwrap()); assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_err()); } }