diff options
-rw-r--r-- | README.md | 4 | ||||
-rw-r--r-- | examples/csrf/src/main.rs | 4 | ||||
-rw-r--r-- | src/request.rs | 56 |
3 files changed, 31 insertions, 33 deletions
@@ -29,7 +29,7 @@ use hyper::{Method, Server, StatusCode, Body}; use hyper::http::request::Parts; use hyper::http::response::Builder; use serde::Deserialize; -use sputnik::{mime, request::{SputnikParts, SputnikBody}, response::SputnikBuilder}; +use sputnik::{mime, request::{SputnikParts, SputnikBody, CsrfToken}, response::SputnikBuilder}; use sputnik::request::CsrfProtectedFormError; type Response = hyper::Response<Body>; @@ -63,7 +63,7 @@ fn get_form(req: &mut Parts) -> Response { .body( format!( "<form method=post><input name=text>{}<button>Submit</button></form>", - req.csrf_token().html_input() + CsrfToken::from_request(req).html_input() ).into() ).unwrap() } diff --git a/examples/csrf/src/main.rs b/examples/csrf/src/main.rs index 7259abd..53ea87f 100644 --- a/examples/csrf/src/main.rs +++ b/examples/csrf/src/main.rs @@ -4,7 +4,7 @@ use hyper::{Method, Server, StatusCode, Body}; use hyper::http::request::Parts; use hyper::http::response::Builder; use serde::Deserialize; -use sputnik::{mime, request::{SputnikParts, SputnikBody}, response::SputnikBuilder}; +use sputnik::{mime, request::{SputnikParts, SputnikBody, CsrfToken}, response::SputnikBuilder}; use sputnik::request::CsrfProtectedFormError; type Response = hyper::Response<Body>; @@ -38,7 +38,7 @@ fn get_form(req: &mut Parts) -> Response { .body( format!( "<form method=post><input name=text>{}<button>Submit</button></form>", - req.csrf_token().html_input() + CsrfToken::from_request(req).html_input() ).into() ).unwrap() } diff --git a/src/request.rs b/src/request.rs index e64f9dd..f2392b8 100644 --- a/src/request.rs +++ b/src/request.rs @@ -30,15 +30,6 @@ pub trait SputnikParts { /// This is intended to be done after your routing logic so that your /// individual request handlers don't have to worry about it. fn response_headers(&mut self) -> &mut HeaderMap; - - /// Returns a CSRF token, either extracted from the `csrf` cookie or newly - /// generated if the cookie wasn't sent (in which case a set-cookie header is - /// appended to [`Self::response_headers`]). - /// - /// If there is no cookie, calling this method multiple times only generates - /// a new token on the first call, further calls return the previously - /// generated token. - fn csrf_token(&mut self) -> CsrfToken; } impl SputnikParts for hyper::http::request::Parts { @@ -85,12 +76,31 @@ impl SputnikParts for hyper::http::request::Parts { } Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))}) } +} + +/// A cookie-based CSRF token to be used with [`SputnikBody::into_form_csrf`]. +#[derive(Clone)] +pub struct CsrfToken(String); + +impl std::fmt::Display for CsrfToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} - fn csrf_token(&mut self) -> CsrfToken { - if let Some(token) = self.extensions.get::<CsrfToken>() { +impl CsrfToken { + /// Returns a CSRF token, either extracted from the `csrf` cookie or newly + /// generated if the cookie wasn't sent (in which case a set-cookie header is + /// appended to [`SputnikParts::response_headers`]). + /// + /// If there is no cookie, calling this method multiple times only generates + /// a new token on the first call, further calls return the previously + /// generated token. + pub fn from_request(req: &mut Parts) -> CsrfToken { + if let Some(token) = req.extensions.get::<CsrfToken>() { return token.clone() } - csrf_token_from_cookies(self) + csrf_token_from_cookies(req) .unwrap_or_else(|| { let token: String = rand::thread_rng().sample_iter( Alphanumeric @@ -100,27 +110,15 @@ impl SputnikParts for hyper::http::request::Parts { c.set_secure(Some(true)); c.set_max_age(Some(Duration::hours(1))); - self.response_headers().set_cookie(c); + req.response_headers().set_cookie(c); let token = CsrfToken(token); - self.extensions.insert(token.clone()); + req.extensions.insert(token.clone()); token }) } -} - -/// A CSRF token retrievable with [`SputnikParts::csrf_token`]. -#[derive(Clone)] -pub struct CsrfToken(String); - -impl std::fmt::Display for CsrfToken { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} -impl CsrfToken { /// Returns a hidden HTML input to be embedded in forms that are received - /// with [`crate::request::SputnikBody::into_form_csrf`]. + /// with [`SputnikBody::into_form_csrf`]. pub fn html_input(&self) -> String { format!("<input name=csrf type=hidden value=\"{}\">", self) } @@ -257,8 +255,8 @@ mod tests { #[test] fn test_csrf_token() { let mut parts = Request::new(hyper::Body::empty()).into_parts().0; - let tok1 = parts.csrf_token(); - let tok2 = parts.csrf_token(); + let tok1 = CsrfToken::from_request(&mut parts); + let tok2 = CsrfToken::from_request(&mut parts); assert_eq!(tok1.to_string(), tok2.to_string()); assert_eq!(parts.response_headers().len(), 1); } |