//! Provides the [`SputnikParts`] and [`SputnikBody`] traits. use cookie::Cookie; use mime::Mime; use serde::{Deserialize, de::DeserializeOwned}; use hyper::{body::Bytes, header, http::{request::Parts, response::Builder}}; use time::Duration; use std::{collections::HashMap, sync::Arc}; use rand::{Rng, distributions::Alphanumeric}; use async_trait::async_trait; use crate::response::SputnikBuilder; const CSRF_COOKIE_NAME : &str = "csrf"; pub trait SputnikParts { /// Parses the query string of the request into a given struct. fn query(&self) -> Result; /// Parses the cookies of the request. fn cookies(&mut self) -> Arc>>; /// Enforces a specific Content-Type. fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>; /// Returns a CSRF token, either extracted from the `csrf` cookie or newly /// generated if the cookie wasn't sent (in which case the cookie is set). /// /// If there is no cookie, calling this method multiple times only generates /// a new token on the first call, further calls return the previously /// generated token. fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken; } impl SputnikParts for hyper::http::request::Parts { fn query(&self) -> Result { serde_urlencoded::from_str::(self.uri.query().unwrap_or("")).map_err(QueryError) } fn cookies(&mut self) -> Arc>> { let cookies: Option<&Arc>> = self.extensions.get(); if let Some(cookies) = cookies { return cookies.clone(); } let mut cookies = HashMap::new(); for header in self.headers.get_all(header::COOKIE) { let raw_str = match std::str::from_utf8(header.as_bytes()) { Ok(string) => string, Err(_) => continue }; for cookie_str in raw_str.split(';').map(|s| s.trim()) { if let Ok(cookie) = Cookie::parse_encoded(cookie_str) { cookies.insert(cookie.name().to_string(), cookie.into_owned()); } } } let cookies = Arc::new(cookies); self.extensions.insert(cookies.clone()); cookies } fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> { if let Some(content_type) = self.headers.get(header::CONTENT_TYPE) { if *content_type == mime.to_string() { return Ok(()) } } Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))}) } fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken { if let Some(token) = self.extensions.get::() { return token.clone() } csrf_token_from_cookies(self) .unwrap_or_else(|| { let token: String = rand::thread_rng().sample_iter( Alphanumeric // must be HTML-safe because we embed it in CsrfToken::html_input ).take(16).collect(); let mut c = Cookie::new(CSRF_COOKIE_NAME, token.clone()); c.set_secure(Some(true)); c.set_max_age(Some(Duration::hours(1))); builder.set_cookie(c); let token = CsrfToken(token); self.extensions.insert(token.clone()); token }) } } /// A CSRF token retrievable with [`SputnikParts::csrf_token`]. #[derive(Clone)] pub struct CsrfToken(String); impl std::fmt::Display for CsrfToken { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } impl CsrfToken { /// Returns a hidden HTML input to be embedded in forms that are received /// with [`crate::request::SputnikBody::into_form_csrf`]. pub fn html_input(&self) -> String { format!("", self) } } #[async_trait] pub trait SputnikBody { async fn into_bytes(self) -> Result; /// Parses a `application/x-www-form-urlencoded` request body into a given struct. /// /// This does make you vulnerable to CSRF, so you normally want to use /// [`SputnikBody::into_form_csrf()`] instead. async fn into_form(self) -> Result; /// Parses a `application/x-www-form-urlencoded` request body into a given struct. /// Protects from CSRF by checking that the request body contains the same token retrieved from the cookies. /// /// The HTML form must embed a hidden input generated with [`CsrfToken::html_input`]. async fn into_form_csrf(self, req: &mut Parts) -> Result; /// Attempts to deserialize the request body as JSON. #[cfg(feature = "json")] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] async fn into_json(self) -> Result; } fn csrf_token_from_cookies(req: &mut Parts) -> Option { req.cookies() .get(CSRF_COOKIE_NAME) .map(|cookie| { let token = CsrfToken(cookie.value().to_string()); req.extensions.insert(token.clone()); token }) } #[async_trait] impl SputnikBody for hyper::Body { async fn into_bytes(self) -> Result { hyper::body::to_bytes(self).await.map_err(BodyError) } async fn into_form(self) -> Result { let full_body = self.into_bytes().await?; Ok(serde_urlencoded::from_bytes::(&full_body)?) } async fn into_form_csrf(self, req: &mut Parts) -> Result { let full_body = self.into_bytes().await?; let csrf_data = serde_urlencoded::from_bytes::(&full_body).map_err(|_| CsrfProtectedFormError::NoCsrf)?; match csrf_token_from_cookies(req) { Some(token) => if token.to_string() == csrf_data.csrf { Ok(serde_urlencoded::from_bytes::(&full_body)?) } else { Err(CsrfProtectedFormError::Mismatch) } None => Err(CsrfProtectedFormError::NoCookie) } } #[cfg(feature = "json")] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] async fn into_json(self) -> Result { let full_body = self.into_bytes().await?; Ok(serde_json::from_slice::(&full_body)?) } } #[derive(Deserialize)] struct CsrfData { csrf: String, } #[derive(thiserror::Error, Debug)] #[error("query deserialize error: {0}")] pub struct QueryError(pub serde_urlencoded::de::Error); #[derive(thiserror::Error, Debug)] #[error("failed to read body")] pub struct BodyError(pub hyper::Error); #[derive(thiserror::Error, Debug)] #[error("expected Content-Type {expected} but received {}", received.as_ref().unwrap_or(&"nothing".to_owned()))] pub struct WrongContentTypeError { pub expected: Mime, pub received: Option, } #[derive(thiserror::Error, Debug)] pub enum FormError { #[error("{0}")] Body(#[from] BodyError), #[error("form deserialize error: {0}")] Deserialize(#[from] serde_urlencoded::de::Error), } #[cfg(feature = "json")] #[cfg_attr(docsrs, doc(cfg(feature = "json")))] #[derive(thiserror::Error, Debug)] pub enum JsonError { #[error("{0}")] Body(#[from] BodyError), #[error("json deserialize error: {0}")] Deserialize(#[from] serde_json::Error), } #[derive(thiserror::Error, Debug)] pub enum CsrfProtectedFormError { #[error("{0}")] Body(#[from] BodyError), #[error("form deserialize error: {0}")] Deserialize(#[from] serde_urlencoded::de::Error), #[error("no csrf token in form data")] NoCsrf, #[error("no csrf cookie")] NoCookie, #[error("csrf parameter doesn't match csrf cookie")] Mismatch, } #[cfg(test)] mod tests { use hyper::Request; use super::*; #[test] fn test_csrf_token() { let mut parts = Request::new(hyper::Body::empty()).into_parts().0; let mut builder = Builder::new(); let tok1 = parts.csrf_token(&mut builder); let tok2 = parts.csrf_token(&mut builder); assert_eq!(tok1.to_string(), tok2.to_string()); assert_eq!(builder.body(hyper::Body::empty()).unwrap().headers().len(), 1); } }