From 8e9a4400ea9bcb80c90232fecc2ad2ae5f6c3303 Mon Sep 17 00:00:00 2001
From: Martin Fischer <martin@push-f.com>
Date: Tue, 26 Jan 2021 13:09:37 +0100
Subject: make csrf_token safe to be called multiple times

---
 src/request.rs | 45 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 39 insertions(+), 6 deletions(-)

diff --git a/src/request.rs b/src/request.rs
index 7734d03..947142c 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -25,6 +25,10 @@ pub trait SputnikParts {
 
     /// Returns a CSRF token, either extracted from the `csrf` cookie or newly
     /// generated if the cookie wasn't sent (in which case the cookie is set).
+    ///
+    /// If there is no cookie, calling this method multiple times only generates
+    /// a new token on the first call, further calls return the previously
+    /// generated token.
     fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken;
 }
 
@@ -67,7 +71,11 @@ impl SputnikParts for hyper::http::request::Parts {
     }
 
     fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken {
-        let token = csrf_token_from_cookies(self).unwrap_or_else(|| {
+        if let Some(token) = self.extensions.get::<CsrfToken>() {
+            return token.clone()
+        }
+        csrf_token_from_cookies(self)
+        .unwrap_or_else(|| {
             let token: String = rand::thread_rng().sample_iter(
                 Alphanumeric
                 // must be HTML-safe because we embed it in CsrfToken::html_input
@@ -76,13 +84,15 @@ impl SputnikParts for hyper::http::request::Parts {
             c.set_secure(Some(true));
             c.set_max_age(Some(Duration::hours(1)));
             builder.set_cookie(c);
+            let token = CsrfToken(token);
+            self.extensions.insert(token.clone());
             token
-        });
-        CsrfToken(token)
+        })
     }
 }
 
 /// A CSRF token retrievable with [`SputnikParts::csrf_token`].
+#[derive(Clone)]
 pub struct CsrfToken(String);
 
 impl std::fmt::Display for CsrfToken {
@@ -121,8 +131,14 @@ pub trait SputnikBody {
     async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError>;
 }
 
-fn csrf_token_from_cookies(req: &mut Parts) -> Option<String> {
-    req.cookies().get(CSRF_COOKIE_NAME).map(|c| c.value().to_string())
+fn csrf_token_from_cookies(req: &mut Parts) -> Option<CsrfToken> {
+    req.cookies()
+        .get(CSRF_COOKIE_NAME)
+        .map(|cookie| {
+            let token = CsrfToken(cookie.value().to_string());
+            req.extensions.insert(token.clone());
+            token
+        })
 }
 
 #[async_trait]
@@ -140,7 +156,7 @@ impl SputnikBody for hyper::Body {
         let full_body = self.into_bytes().await?;
         let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_| CsrfProtectedFormError::NoCsrf)?;
         match csrf_token_from_cookies(req) {
-            Some(token) => if token == csrf_data.csrf {
+            Some(token) => if token.to_string() == csrf_data.csrf {
                 Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
             } else {
                 Err(CsrfProtectedFormError::Mismatch)
@@ -213,4 +229,21 @@ pub enum CsrfProtectedFormError {
 
     #[error("csrf parameter doesn't match csrf cookie")]
     Mismatch,
+}
+
+#[cfg(test)]
+mod tests {
+    use hyper::Request;
+
+    use super::*;
+
+    #[test]
+    fn test_csrf_token() {
+        let mut parts = Request::new(hyper::Body::empty()).into_parts().0;
+        let mut builder = Builder::new();
+        let tok1 = parts.csrf_token(&mut builder);
+        let tok2 = parts.csrf_token(&mut builder);
+        assert_eq!(tok1.to_string(), tok2.to_string());
+        assert_eq!(builder.body(hyper::Body::empty()).unwrap().headers().len(), 1);
+    }
 }
\ No newline at end of file
-- 
cgit v1.2.3