aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Fischer <martin@push-f.com>2021-01-26 13:09:37 +0100
committerMartin Fischer <martin@push-f.com>2021-01-26 13:09:48 +0100
commit8e9a4400ea9bcb80c90232fecc2ad2ae5f6c3303 (patch)
tree41592b47ba32fe891e534ff71ec303605929016b
parent18b1b875c80a244f6ea29894672db18b90ec4eea (diff)
make csrf_token safe to be called multiple times
-rw-r--r--src/request.rs45
1 files changed, 39 insertions, 6 deletions
diff --git a/src/request.rs b/src/request.rs
index 7734d03..947142c 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -25,6 +25,10 @@ pub trait SputnikParts {
/// Returns a CSRF token, either extracted from the `csrf` cookie or newly
/// generated if the cookie wasn't sent (in which case the cookie is set).
+ ///
+ /// If there is no cookie, calling this method multiple times only generates
+ /// a new token on the first call, further calls return the previously
+ /// generated token.
fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken;
}
@@ -67,7 +71,11 @@ impl SputnikParts for hyper::http::request::Parts {
}
fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken {
- let token = csrf_token_from_cookies(self).unwrap_or_else(|| {
+ if let Some(token) = self.extensions.get::<CsrfToken>() {
+ return token.clone()
+ }
+ csrf_token_from_cookies(self)
+ .unwrap_or_else(|| {
let token: String = rand::thread_rng().sample_iter(
Alphanumeric
// must be HTML-safe because we embed it in CsrfToken::html_input
@@ -76,13 +84,15 @@ impl SputnikParts for hyper::http::request::Parts {
c.set_secure(Some(true));
c.set_max_age(Some(Duration::hours(1)));
builder.set_cookie(c);
+ let token = CsrfToken(token);
+ self.extensions.insert(token.clone());
token
- });
- CsrfToken(token)
+ })
}
}
/// A CSRF token retrievable with [`SputnikParts::csrf_token`].
+#[derive(Clone)]
pub struct CsrfToken(String);
impl std::fmt::Display for CsrfToken {
@@ -121,8 +131,14 @@ pub trait SputnikBody {
async fn into_json<T: DeserializeOwned>(self) -> Result<T, JsonError>;
}
-fn csrf_token_from_cookies(req: &mut Parts) -> Option<String> {
- req.cookies().get(CSRF_COOKIE_NAME).map(|c| c.value().to_string())
+fn csrf_token_from_cookies(req: &mut Parts) -> Option<CsrfToken> {
+ req.cookies()
+ .get(CSRF_COOKIE_NAME)
+ .map(|cookie| {
+ let token = CsrfToken(cookie.value().to_string());
+ req.extensions.insert(token.clone());
+ token
+ })
}
#[async_trait]
@@ -140,7 +156,7 @@ impl SputnikBody for hyper::Body {
let full_body = self.into_bytes().await?;
let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_| CsrfProtectedFormError::NoCsrf)?;
match csrf_token_from_cookies(req) {
- Some(token) => if token == csrf_data.csrf {
+ Some(token) => if token.to_string() == csrf_data.csrf {
Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
} else {
Err(CsrfProtectedFormError::Mismatch)
@@ -213,4 +229,21 @@ pub enum CsrfProtectedFormError {
#[error("csrf parameter doesn't match csrf cookie")]
Mismatch,
+}
+
+#[cfg(test)]
+mod tests {
+ use hyper::Request;
+
+ use super::*;
+
+ #[test]
+ fn test_csrf_token() {
+ let mut parts = Request::new(hyper::Body::empty()).into_parts().0;
+ let mut builder = Builder::new();
+ let tok1 = parts.csrf_token(&mut builder);
+ let tok2 = parts.csrf_token(&mut builder);
+ assert_eq!(tok1.to_string(), tok2.to_string());
+ assert_eq!(builder.body(hyper::Body::empty()).unwrap().headers().len(), 1);
+ }
} \ No newline at end of file