diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/request.rs | 36 | ||||
-rw-r--r-- | src/response.rs | 49 |
2 files changed, 48 insertions, 37 deletions
diff --git a/src/request.rs b/src/request.rs index 947142c..e64f9dd 100644 --- a/src/request.rs +++ b/src/request.rs @@ -3,13 +3,13 @@ use cookie::Cookie; use mime::Mime; use serde::{Deserialize, de::DeserializeOwned}; -use hyper::{body::Bytes, header, http::{request::Parts, response::Builder}}; +use hyper::{HeaderMap, body::Bytes, header, http::request::Parts}; use time::Duration; use std::{collections::HashMap, sync::Arc}; use rand::{Rng, distributions::Alphanumeric}; use async_trait::async_trait; -use crate::response::SputnikBuilder; +use crate::response::SputnikHeaders; const CSRF_COOKIE_NAME : &str = "csrf"; @@ -23,13 +23,22 @@ pub trait SputnikParts { /// Enforces a specific Content-Type. fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>; + /// A map of response headers to allow methods of this trait to set response + /// headers without needing to take a [`Response`](hyper::http::response::Response) as an argument. + /// + /// You need to take care to append these headers to the response yourself. + /// This is intended to be done after your routing logic so that your + /// individual request handlers don't have to worry about it. + fn response_headers(&mut self) -> &mut HeaderMap; + /// Returns a CSRF token, either extracted from the `csrf` cookie or newly - /// generated if the cookie wasn't sent (in which case the cookie is set). + /// generated if the cookie wasn't sent (in which case a set-cookie header is + /// appended to [`Self::response_headers`]). /// /// If there is no cookie, calling this method multiple times only generates /// a new token on the first call, further calls return the previously /// generated token. - fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken; + fn csrf_token(&mut self) -> CsrfToken; } impl SputnikParts for hyper::http::request::Parts { @@ -37,6 +46,13 @@ impl SputnikParts for hyper::http::request::Parts { serde_urlencoded::from_str::<T>(self.uri.query().unwrap_or("")).map_err(QueryError) } + fn response_headers(&mut self) -> &mut HeaderMap { + if self.extensions.get::<HeaderMap>().is_none() { + self.extensions.insert(HeaderMap::new()); + } + self.extensions.get_mut::<HeaderMap>().unwrap() + } + fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>> { let cookies: Option<&Arc<HashMap<String, Cookie>>> = self.extensions.get(); if let Some(cookies) = cookies { @@ -70,7 +86,7 @@ impl SputnikParts for hyper::http::request::Parts { Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))}) } - fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken { + fn csrf_token(&mut self) -> CsrfToken { if let Some(token) = self.extensions.get::<CsrfToken>() { return token.clone() } @@ -83,7 +99,8 @@ impl SputnikParts for hyper::http::request::Parts { let mut c = Cookie::new(CSRF_COOKIE_NAME, token.clone()); c.set_secure(Some(true)); c.set_max_age(Some(Duration::hours(1))); - builder.set_cookie(c); + + self.response_headers().set_cookie(c); let token = CsrfToken(token); self.extensions.insert(token.clone()); token @@ -240,10 +257,9 @@ mod tests { #[test] fn test_csrf_token() { let mut parts = Request::new(hyper::Body::empty()).into_parts().0; - let mut builder = Builder::new(); - let tok1 = parts.csrf_token(&mut builder); - let tok2 = parts.csrf_token(&mut builder); + let tok1 = parts.csrf_token(); + let tok2 = parts.csrf_token(); assert_eq!(tok1.to_string(), tok2.to_string()); - assert_eq!(builder.body(hyper::Body::empty()).unwrap().headers().len(), 1); + assert_eq!(parts.response_headers().len(), 1); } }
\ No newline at end of file diff --git a/src/response.rs b/src/response.rs index 78e9a70..ceb0f61 100644 --- a/src/response.rs +++ b/src/response.rs @@ -3,31 +3,39 @@ use std::convert::TryInto; use cookie::Cookie; -use hyper::{StatusCode, header, http}; +use hyper::{HeaderMap, StatusCode, header, http}; use time::{Duration, OffsetDateTime}; use hyper::http::response::Builder; pub trait SputnikBuilder { - /// Appends a Set-Cookie header. - fn set_cookie(&mut self, cookie: Cookie); - - /// Appends a Set-Cookie header to delete a cookie. - fn delete_cookie(&mut self, name: &str); - /// Sets the Content-Type. fn content_type(self, mime: mime::Mime) -> Builder; } - pub fn redirect(location: &str, code: StatusCode) -> Builder { Builder::new().status(code).header(header::LOCATION, location) } impl SputnikBuilder for Builder { - fn set_cookie(&mut self, cookie: Cookie) { + fn content_type(mut self, mime: mime::Mime) -> Self { if let Some(headers) = self.headers_mut() { - headers.append(header::SET_COOKIE, cookie.encoded().to_string().try_into().unwrap()); + headers.insert(header::CONTENT_TYPE, mime.to_string().try_into().unwrap()); } + self + } +} + +pub trait SputnikHeaders { + /// Appends a Set-Cookie header. + fn set_cookie(&mut self, cookie: Cookie); + + /// Appends a Set-Cookie header to delete a cookie. + fn delete_cookie(&mut self, name: &str); +} + +impl SputnikHeaders for HeaderMap { + fn set_cookie(&mut self, cookie: Cookie) { + self.append(header::SET_COOKIE, cookie.encoded().to_string().try_into().unwrap()); } fn delete_cookie(&mut self, name: &str) { @@ -36,13 +44,6 @@ impl SputnikBuilder for Builder { cookie.set_expires(OffsetDateTime::now_utc() - Duration::days(365)); self.set_cookie(cookie); } - - fn content_type(mut self, mime: mime::Mime) -> Self { - if let Some(headers) = self.headers_mut() { - headers.insert(header::CONTENT_TYPE, mime.to_string().try_into().unwrap()); - } - self - } } pub trait EmptyBuilder<B> { @@ -62,16 +63,10 @@ mod tests { #[test] fn test_set_cookie() { - let mut builder = Builder::new(); - builder.set_cookie(Cookie::new("some", "cookie")); - builder.set_cookie(Cookie::new("some", "cookie")); - let resp = builder.body(hyper::Body::empty()).unwrap(); - assert_eq!(resp.headers().len(), 2); - - let mut builder = Builder::new() - .header("foo", "invalid\r\n"); - // doesn't panic after invalid header - builder.set_cookie(Cookie::new("some", "cookie")); + let mut map = HeaderMap::new(); + map.set_cookie(Cookie::new("some", "cookie")); + map.set_cookie(Cookie::new("some", "cookie")); + assert_eq!(map.len(), 2); } #[test] |