//! Provides the [`SputnikParts`] trait. use mime::Mime; use serde::Deserialize; use std::str::Split; use std::time::Duration; use crate::http::{header, request::Parts, HeaderMap}; use crate::response::{delete_cookie, Cookie, SputnikHeaders}; /// Adds convenience methods to [`http::request::Parts`](Parts). pub trait SputnikParts { /// Parses the query string of the request into a given struct. fn query<'a, X: Deserialize<'a>>(&'a self) -> Result; /// Parses the cookies of the request. fn cookies(&self) -> CookieIter; /// Enforces a specific Content-Type. fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>; /// A map of response headers to allow methods of this trait to set response /// headers without needing to take a [`Response`](crate::http::response::Response) as an argument. /// /// You need to take care to append these headers to the response yourself. /// This is intended to be done after your routing logic so that your /// individual request handlers don't have to worry about it. fn response_headers(&mut self) -> &mut HeaderMap; } pub struct CookieIter<'a>(Split<'a, char>); impl<'a> Iterator for CookieIter<'a> { type Item = (&'a str, &'a str); fn next(&mut self) -> Option { self.0.next().and_then(|str| { let mut iter = str.splitn(2, '='); let name = iter.next().expect("first splitn().next() returns Some"); let value = iter.next(); match value { None => self.next(), Some(mut value) => { if value.starts_with('"') && value.ends_with('"') && value.len() >= 2 { value = &value[1..value.len() - 1]; } Some((name, value)) } } }) } } impl SputnikParts for Parts { fn query<'a, T: Deserialize<'a>>(&'a self) -> Result { serde_urlencoded::from_str::(self.uri.query().unwrap_or("")).map_err(QueryError) } fn response_headers(&mut self) -> &mut HeaderMap { if self.extensions.get::().is_none() { self.extensions.insert(HeaderMap::new()); } self.extensions.get_mut::().unwrap() } fn cookies(&self) -> CookieIter { CookieIter( self.headers .get(header::COOKIE) .and_then(|h| std::str::from_utf8(h.as_bytes()).ok()) .unwrap_or("") .split(';'), ) } fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> { if let Some(content_type) = self.headers.get(header::CONTENT_TYPE) { if *content_type == mime.to_string() { return Ok(()); } } Err(WrongContentTypeError { expected: mime, received: self .headers .get(header::CONTENT_TYPE) .as_ref() .and_then(|h| h.to_str().ok().map(|s| s.to_owned())), }) } } const FLASH_COOKIE_NAME: &str = "flash"; /// Show the user a message after redirecting them. pub struct Flash { name: String, message: String, } impl From for Cookie { fn from(flash: Flash) -> Self { Cookie { name: FLASH_COOKIE_NAME.into(), value: format!("{}:{}", flash.name, flash.message), max_age: Some(Duration::from_secs(5 * 60)), ..Default::default() } } } impl Flash { /// If the request has a flash cookie retrieve it and append a set-cookie /// header to delete the cookie to [`SputnikParts::response_headers`]. pub fn from_request(req: &mut Parts) -> Option { let value = req .cookies() .find(|(name, _value)| *name == FLASH_COOKIE_NAME)? .1 .to_owned(); req.response_headers() .set_cookie(delete_cookie(FLASH_COOKIE_NAME)); let mut iter = value.splitn(2, ':'); if let (Some(name), Some(message)) = (iter.next(), iter.next()) { return Some(Flash { name: name.to_owned(), message: message.to_owned(), }); } None } /// Constructs a new Flash message. The name must not contain a colon (`:`). pub fn new(name: String, message: String) -> Self { Flash { name, message } } /// Constructs a new "success" Flash message. pub fn success(message: String) -> Self { Flash { name: "success".to_owned(), message, } } /// Constructs a new "warning" Flash message. pub fn warning(message: String) -> Self { Flash { name: "warning".to_owned(), message, } } /// Constructs a new "error" Flash message. pub fn error(message: String) -> Self { Flash { name: "error".to_owned(), message, } } /// Returns the name of the Flash message. pub fn name(&self) -> &str { &self.name } /// Returns the message of the Flash message. pub fn message(&self) -> &str { &self.message } } #[derive(thiserror::Error, Debug)] #[error("query deserialize error: {0}")] pub struct QueryError(pub serde_urlencoded::de::Error); #[derive(thiserror::Error, Debug)] #[error("expected Content-Type {expected} but received {}", received.as_ref().unwrap_or(&"nothing".to_owned()))] pub struct WrongContentTypeError { pub expected: Mime, pub received: Option, } #[cfg(test)] mod tests { use std::convert::TryInto; use crate::http::{header, Request}; use super::SputnikParts; #[test] fn test_enforce_content_type() { let (mut parts, _body) = Request::new("").into_parts(); assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_err()); parts .headers .append(header::CONTENT_TYPE, "application/json".try_into().unwrap()); assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_ok()); parts .headers .insert(header::CONTENT_TYPE, "text/html".try_into().unwrap()); assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_err()); } }