diff options
author | Martin Fischer <martin@push-f.com> | 2021-01-26 14:37:04 +0100 |
---|---|---|
committer | Martin Fischer <martin@push-f.com> | 2021-01-26 15:48:17 +0100 |
commit | fc15b41a37e123434ec39a277f107b78c1507bd8 (patch) | |
tree | 78d450bc04ef64e59a636a37c2147ec9bffba40c | |
parent | 8e9a4400ea9bcb80c90232fecc2ad2ae5f6c3303 (diff) |
introduce SputnikParts::response_headers
-rw-r--r-- | README.md | 19 | ||||
-rw-r--r-- | examples/csrf/src/main.rs | 19 | ||||
-rw-r--r-- | src/request.rs | 36 | ||||
-rw-r--r-- | src/response.rs | 49 |
4 files changed, 74 insertions, 49 deletions
@@ -58,11 +58,13 @@ async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> { } fn get_form(req: &mut Parts) -> Response { - let mut response = Builder::new(); - let csrf_input = req.csrf_token(&mut response).html_input(); - response.content_type(mime::TEXT_HTML).body( - format!("<form method=post> - <input name=text>{}<button>Submit</button></form>", csrf_input).into() + Builder::new() + .content_type(mime::TEXT_HTML) + .body( + format!( + "<form method=post><input name=text>{}<button>Submit</button></form>", + req.csrf_token().html_input() + ).into() ).unwrap() } @@ -79,7 +81,12 @@ async fn post_form(req: &mut Parts, body: Body) -> Result<Response, Error> { async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, Infallible> { let (mut parts, body) = req.into_parts(); match route(&mut parts, body).await { - Ok(res) => Ok(res), + Ok(mut res) => { + for (k,v) in parts.response_headers().iter() { + res.headers_mut().append(k, v.clone()); + } + Ok(res) + } Err(err) => { let (code, message) = render_error(err); // you can easily wrap or log errors here diff --git a/examples/csrf/src/main.rs b/examples/csrf/src/main.rs index 1048689..7259abd 100644 --- a/examples/csrf/src/main.rs +++ b/examples/csrf/src/main.rs @@ -33,11 +33,13 @@ async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> { } fn get_form(req: &mut Parts) -> Response { - let mut response = Builder::new(); - let csrf_input = req.csrf_token(&mut response).html_input(); - response.content_type(mime::TEXT_HTML).body( - format!("<form method=post> - <input name=text>{}<button>Submit</button></form>", csrf_input).into() + Builder::new() + .content_type(mime::TEXT_HTML) + .body( + format!( + "<form method=post><input name=text>{}<button>Submit</button></form>", + req.csrf_token().html_input() + ).into() ).unwrap() } @@ -54,7 +56,12 @@ async fn post_form(req: &mut Parts, body: Body) -> Result<Response, Error> { async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, Infallible> { let (mut parts, body) = req.into_parts(); match route(&mut parts, body).await { - Ok(res) => Ok(res), + Ok(mut res) => { + for (k,v) in parts.response_headers().iter() { + res.headers_mut().append(k, v.clone()); + } + Ok(res) + } Err(err) => { let (code, message) = render_error(err); // you can easily wrap or log errors here diff --git a/src/request.rs b/src/request.rs index 947142c..e64f9dd 100644 --- a/src/request.rs +++ b/src/request.rs @@ -3,13 +3,13 @@ use cookie::Cookie; use mime::Mime; use serde::{Deserialize, de::DeserializeOwned}; -use hyper::{body::Bytes, header, http::{request::Parts, response::Builder}}; +use hyper::{HeaderMap, body::Bytes, header, http::request::Parts}; use time::Duration; use std::{collections::HashMap, sync::Arc}; use rand::{Rng, distributions::Alphanumeric}; use async_trait::async_trait; -use crate::response::SputnikBuilder; +use crate::response::SputnikHeaders; const CSRF_COOKIE_NAME : &str = "csrf"; @@ -23,13 +23,22 @@ pub trait SputnikParts { /// Enforces a specific Content-Type. fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>; + /// A map of response headers to allow methods of this trait to set response + /// headers without needing to take a [`Response`](hyper::http::response::Response) as an argument. + /// + /// You need to take care to append these headers to the response yourself. + /// This is intended to be done after your routing logic so that your + /// individual request handlers don't have to worry about it. + fn response_headers(&mut self) -> &mut HeaderMap; + /// Returns a CSRF token, either extracted from the `csrf` cookie or newly - /// generated if the cookie wasn't sent (in which case the cookie is set). + /// generated if the cookie wasn't sent (in which case a set-cookie header is + /// appended to [`Self::response_headers`]). /// /// If there is no cookie, calling this method multiple times only generates /// a new token on the first call, further calls return the previously /// generated token. - fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken; + fn csrf_token(&mut self) -> CsrfToken; } impl SputnikParts for hyper::http::request::Parts { @@ -37,6 +46,13 @@ impl SputnikParts for hyper::http::request::Parts { serde_urlencoded::from_str::<T>(self.uri.query().unwrap_or("")).map_err(QueryError) } + fn response_headers(&mut self) -> &mut HeaderMap { + if self.extensions.get::<HeaderMap>().is_none() { + self.extensions.insert(HeaderMap::new()); + } + self.extensions.get_mut::<HeaderMap>().unwrap() + } + fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>> { let cookies: Option<&Arc<HashMap<String, Cookie>>> = self.extensions.get(); if let Some(cookies) = cookies { @@ -70,7 +86,7 @@ impl SputnikParts for hyper::http::request::Parts { Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))}) } - fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken { + fn csrf_token(&mut self) -> CsrfToken { if let Some(token) = self.extensions.get::<CsrfToken>() { return token.clone() } @@ -83,7 +99,8 @@ impl SputnikParts for hyper::http::request::Parts { let mut c = Cookie::new(CSRF_COOKIE_NAME, token.clone()); c.set_secure(Some(true)); c.set_max_age(Some(Duration::hours(1))); - builder.set_cookie(c); + + self.response_headers().set_cookie(c); let token = CsrfToken(token); self.extensions.insert(token.clone()); token @@ -240,10 +257,9 @@ mod tests { #[test] fn test_csrf_token() { let mut parts = Request::new(hyper::Body::empty()).into_parts().0; - let mut builder = Builder::new(); - let tok1 = parts.csrf_token(&mut builder); - let tok2 = parts.csrf_token(&mut builder); + let tok1 = parts.csrf_token(); + let tok2 = parts.csrf_token(); assert_eq!(tok1.to_string(), tok2.to_string()); - assert_eq!(builder.body(hyper::Body::empty()).unwrap().headers().len(), 1); + assert_eq!(parts.response_headers().len(), 1); } }
\ No newline at end of file diff --git a/src/response.rs b/src/response.rs index 78e9a70..ceb0f61 100644 --- a/src/response.rs +++ b/src/response.rs @@ -3,31 +3,39 @@ use std::convert::TryInto; use cookie::Cookie; -use hyper::{StatusCode, header, http}; +use hyper::{HeaderMap, StatusCode, header, http}; use time::{Duration, OffsetDateTime}; use hyper::http::response::Builder; pub trait SputnikBuilder { - /// Appends a Set-Cookie header. - fn set_cookie(&mut self, cookie: Cookie); - - /// Appends a Set-Cookie header to delete a cookie. - fn delete_cookie(&mut self, name: &str); - /// Sets the Content-Type. fn content_type(self, mime: mime::Mime) -> Builder; } - pub fn redirect(location: &str, code: StatusCode) -> Builder { Builder::new().status(code).header(header::LOCATION, location) } impl SputnikBuilder for Builder { - fn set_cookie(&mut self, cookie: Cookie) { + fn content_type(mut self, mime: mime::Mime) -> Self { if let Some(headers) = self.headers_mut() { - headers.append(header::SET_COOKIE, cookie.encoded().to_string().try_into().unwrap()); + headers.insert(header::CONTENT_TYPE, mime.to_string().try_into().unwrap()); } + self + } +} + +pub trait SputnikHeaders { + /// Appends a Set-Cookie header. + fn set_cookie(&mut self, cookie: Cookie); + + /// Appends a Set-Cookie header to delete a cookie. + fn delete_cookie(&mut self, name: &str); +} + +impl SputnikHeaders for HeaderMap { + fn set_cookie(&mut self, cookie: Cookie) { + self.append(header::SET_COOKIE, cookie.encoded().to_string().try_into().unwrap()); } fn delete_cookie(&mut self, name: &str) { @@ -36,13 +44,6 @@ impl SputnikBuilder for Builder { cookie.set_expires(OffsetDateTime::now_utc() - Duration::days(365)); self.set_cookie(cookie); } - - fn content_type(mut self, mime: mime::Mime) -> Self { - if let Some(headers) = self.headers_mut() { - headers.insert(header::CONTENT_TYPE, mime.to_string().try_into().unwrap()); - } - self - } } pub trait EmptyBuilder<B> { @@ -62,16 +63,10 @@ mod tests { #[test] fn test_set_cookie() { - let mut builder = Builder::new(); - builder.set_cookie(Cookie::new("some", "cookie")); - builder.set_cookie(Cookie::new("some", "cookie")); - let resp = builder.body(hyper::Body::empty()).unwrap(); - assert_eq!(resp.headers().len(), 2); - - let mut builder = Builder::new() - .header("foo", "invalid\r\n"); - // doesn't panic after invalid header - builder.set_cookie(Cookie::new("some", "cookie")); + let mut map = HeaderMap::new(); + map.set_cookie(Cookie::new("some", "cookie")); + map.set_cookie(Cookie::new("some", "cookie")); + assert_eq!(map.len(), 2); } #[test] |