diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/request.rs | 45 |
1 files changed, 39 insertions, 6 deletions
diff --git a/src/request.rs b/src/request.rs index 7734d03..947142c 100644 --- a/src/request.rs +++ b/src/request.rs @@ -25,6 +25,10 @@ pub trait SputnikParts { /// Returns a CSRF token, either extracted from the `csrf` cookie or newly /// generated if the cookie wasn't sent (in which case the cookie is set). + /// + /// If there is no cookie, calling this method multiple times only generates + /// a new token on the first call, further calls return the previously + /// generated token. fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken; } @@ -67,7 +71,11 @@ impl SputnikParts for hyper::http::request::Parts { } fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken { - let token = csrf_token_from_cookies(self).unwrap_or_else(|| { + if let Some(token) = self.extensions.get::<CsrfToken>() { + return token.clone() + } + csrf_token_from_cookies(self) + .unwrap_or_else(|| { let token: String = rand::thread_rng().sample_iter( Alphanumeric // must be HTML-safe because we embed it in CsrfToken::html_input @@ -76,13 +84,15 @@ impl SputnikParts for hyper::http::request::Parts { c.set_secure(Some(true)); c.set_max_age(Some(Duration::hours(1))); builder.set_cookie(c); + let token = CsrfToken(token); + self.extensions.insert(token.clone()); token - }); - CsrfToken(token) + }) } } /// A CSRF token retrievable with [`SputnikParts::csrf_token`]. +#[derive(Clone)] pub struct CsrfToken(String); impl std::fmt::Display for CsrfToken { @@ -121,8 +131,14 @@ pub trait SputnikBody { async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError>; } -fn csrf_token_from_cookies(req: &mut Parts) -> Option<String> { - req.cookies().get(CSRF_COOKIE_NAME).map(|c| c.value().to_string()) +fn csrf_token_from_cookies(req: &mut Parts) -> Option<CsrfToken> { + req.cookies() + .get(CSRF_COOKIE_NAME) + .map(|cookie| { + let token = CsrfToken(cookie.value().to_string()); + req.extensions.insert(token.clone()); + token + }) } #[async_trait] @@ -140,7 +156,7 @@ impl SputnikBody for hyper::Body { let full_body = self.into_bytes().await?; let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_| CsrfProtectedFormError::NoCsrf)?; match csrf_token_from_cookies(req) { - Some(token) => if token == csrf_data.csrf { + Some(token) => if token.to_string() == csrf_data.csrf { Ok(serde_urlencoded::from_bytes::<T>(&full_body)?) } else { Err(CsrfProtectedFormError::Mismatch) @@ -213,4 +229,21 @@ pub enum CsrfProtectedFormError { #[error("csrf parameter doesn't match csrf cookie")] Mismatch, +} + +#[cfg(test)] +mod tests { + use hyper::Request; + + use super::*; + + #[test] + fn test_csrf_token() { + let mut parts = Request::new(hyper::Body::empty()).into_parts().0; + let mut builder = Builder::new(); + let tok1 = parts.csrf_token(&mut builder); + let tok2 = parts.csrf_token(&mut builder); + assert_eq!(tok1.to_string(), tok2.to_string()); + assert_eq!(builder.body(hyper::Body::empty()).unwrap().headers().len(), 1); + } }
\ No newline at end of file |