From 9fa7442e41bc11ab3d62f43f5f6e90b59e160da2 Mon Sep 17 00:00:00 2001
From: Martin Fischer <martin@push-f.com>
Date: Mon, 25 Jan 2021 14:47:47 +0100
Subject: simplify CSRF API

This commit gets rid of the CsrfToken type,
simplifying submission handling:

  // before
  let csrf_token = req.csrf_token(&mut response);
  let msg: FormData = body.into_form_csrf(&csrf_token).await?;

  // after
  let msg: FormData = body.into_form_csrf(req).await?;

As well as HTML input retrieval:

  // before
  req.csrf_token(&mut response).html_input();

  // after
  req.csrf_html_input(&mut response);

This commit also merges the CsrfError type into CsrfProtectedFormError.

bump version to 0.3.1
---
 src/request.rs  | 75 +++++++++++++++++++++++++++++++++------------------------
 src/security.rs | 35 +--------------------------
 2 files changed, 45 insertions(+), 65 deletions(-)

(limited to 'src')

diff --git a/src/request.rs b/src/request.rs
index 1166ef2..509e2e7 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -2,14 +2,16 @@
 
 use cookie::Cookie;
 use mime::Mime;
-use rand::{Rng, distributions::Alphanumeric};
-use security::CsrfToken;
 use serde::{Deserialize, de::DeserializeOwned};
-use hyper::{body::Bytes, header};
+use hyper::{body::Bytes, header, http::{request::Parts, response::Builder}};
 use time::Duration;
 use std::{collections::HashMap, sync::Arc};
+use rand::{Rng, distributions::Alphanumeric};
+use async_trait::async_trait;
 
-use crate::{response::SputnikBuilder, security};
+use crate::response::SputnikBuilder;
+
+const CSRF_COOKIE_NAME : &str = "csrf";
 
 pub trait SputnikParts {
     /// Parses the query string of the request into a given struct.
@@ -18,12 +20,14 @@ pub trait SputnikParts {
     /// Parses the cookies of the request.
     fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>>;
 
-    /// Retrieves the CSRF token from a `csrf` cookie or generates
-    /// a new token and stores it as a cookie if it doesn't exist.
-    fn csrf_token(&mut self, builder: &mut dyn SputnikBuilder) -> CsrfToken;
-
     /// Enforces a specific Content-Type.
     fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>;
+
+    /// Retrievs the CSRF token from a cookie or generates
+    /// a new token and stores it as a cookie if it doesn't exist.
+    /// Returns a hidden HTML input to be embedded in forms that are received
+    /// with [`crate::request::SputnikBody::into_form_csrf`].
+    fn csrf_html_input(&mut self, builder: &mut Builder) -> String;
 }
 
 impl SputnikParts for hyper::http::request::Parts {
@@ -55,18 +59,6 @@ impl SputnikParts for hyper::http::request::Parts {
         cookies
     }
 
-    fn csrf_token(&mut self, builder: &mut dyn SputnikBuilder) -> CsrfToken {
-        if let Some(cookie) = self.cookies().get("csrf") {
-            return CsrfToken{token: cookie.value().to_string(), from_client: true}
-        }
-        let val: String = rand::thread_rng().sample_iter(Alphanumeric).take(16).collect();
-        let mut c = Cookie::new("csrf", val.clone());
-        c.set_secure(Some(true));
-        c.set_max_age(Some(Duration::hours(1)));
-        builder.set_cookie(c);
-        CsrfToken{token: val, from_client: false}
-    }
-
     fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> {
         if let Some(content_type) = self.headers.get(header::CONTENT_TYPE) {
             if *content_type == mime.to_string() {
@@ -75,9 +67,19 @@ impl SputnikParts for hyper::http::request::Parts {
         }
         Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))})
     }
-}
 
-use async_trait::async_trait;
+    fn csrf_html_input(&mut self, builder: &mut Builder) -> String {
+        let token = csrf_token_from_cookies(self).unwrap_or_else(|| {
+            let token: String = rand::thread_rng().sample_iter(Alphanumeric).take(16).collect();
+            let mut c = Cookie::new(CSRF_COOKIE_NAME, token.clone());
+            c.set_secure(Some(true));
+            c.set_max_age(Some(Duration::hours(1)));
+            builder.set_cookie(c);
+            token
+        });
+        format!("<input name=csrf type=hidden value=\"{}\">", token)
+    }
+}
 
 #[async_trait]
 pub trait SputnikBody {
@@ -92,9 +94,12 @@ pub trait SputnikBody {
     /// Parses a `application/x-www-form-urlencoded` request body into a given struct.
     /// Protects from CSRF by checking that the request body contains the same token retrieved from the cookies.
     ///
-    /// The CSRF parameter is expected as the `csrf` parameter in the request body.
-    /// This means for HTML forms you need to embed the token as a hidden input.
-    async fn into_form_csrf<T: DeserializeOwned>(self, csrf_token: &security::CsrfToken) -> Result<T, CsrfProtectedFormError>;
+    /// The HTML form must embed a hidden input generated with [`crate::request::SputnikParts::csrf_html_input`].
+    async fn into_form_csrf<T: DeserializeOwned>(self, req: &mut Parts) -> Result<T, CsrfProtectedFormError>;
+}
+
+fn csrf_token_from_cookies(req: &mut Parts) -> Option<String> {
+    req.cookies().get(CSRF_COOKIE_NAME).map(|c| c.value().to_string())
 }
 
 #[async_trait]
@@ -108,11 +113,17 @@ impl SputnikBody for hyper::Body {
         Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
     }
 
-    async fn into_form_csrf<T: DeserializeOwned>(self, csrf_token: &CsrfToken) -> Result<T, CsrfProtectedFormError> {
+    async fn into_form_csrf<T: DeserializeOwned>(self, req: &mut Parts) -> Result<T, CsrfProtectedFormError> {
         let full_body = self.into_bytes().await?;
         let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_| CsrfProtectedFormError::NoCsrf)?;
-        csrf_token.matches(csrf_data.csrf)?;
-        serde_urlencoded::from_bytes::<T>(&full_body).map_err(CsrfProtectedFormError::Deserialize)
+        match csrf_token_from_cookies(req) {
+            Some(token) => if token == csrf_data.csrf {
+                Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
+            } else {
+                Err(CsrfProtectedFormError::Mismatch)
+            }
+            None => Err(CsrfProtectedFormError::NoCookie)
+        }
     }
 }
 
@@ -121,7 +132,6 @@ struct CsrfData {
     csrf: String,
 }
 
-use crate::security::CsrfError;
 #[derive(thiserror::Error, Debug)]
 #[error("query deserialize error: {0}")]
 pub struct QueryError(pub serde_urlencoded::de::Error);
@@ -157,6 +167,9 @@ pub enum CsrfProtectedFormError {
     #[error("no csrf token in form data")]
     NoCsrf,
 
-    #[error("{0}")]
-    Csrf(#[from] CsrfError),
+    #[error("no csrf cookie")]
+    NoCookie,
+
+    #[error("csrf parameter doesn't match csrf cookie")]
+    Mismatch,
 }
\ No newline at end of file
diff --git a/src/security.rs b/src/security.rs
index 5247d9e..0ffa7a0 100644
--- a/src/security.rs
+++ b/src/security.rs
@@ -1,42 +1,9 @@
-//! [`CsrfToken`], [`Key`] and functions to encode & decode expiring claims.
+//! [`Key`] and functions to encode & decode expiring claims.
 
 use time::OffsetDateTime;
-use thiserror::Error;
 
 pub use crate::signed::Key;
 
-/// A cookie-based CSRF token to be used with [`crate::request::SputnikBody::into_form_csrf`].
-pub struct CsrfToken {
-    pub(crate) token: String,
-    pub(crate) from_client: bool,
-}
-
-#[derive(Error, Debug)]
-pub enum CsrfError {
-    #[error("expected csrf cookie")]
-    NoCookie,
-
-    #[error("csrf parameter doesn't match csrf cookie")]
-    Mismatch,
-}
-
-impl CsrfToken {
-    /// Wraps the token in a hidden HTML input.
-    pub fn html_input(&self) -> String {
-        format!("<input name=csrf type=hidden value=\"{}\">", self.token)
-    }
-
-    pub(crate) fn matches(&self, str: String) -> Result<(), CsrfError> {
-        if !self.from_client {
-            return Err(CsrfError::NoCookie)
-        }
-        if self.token != str {
-            return Err(CsrfError::Mismatch)
-        }
-        Ok(())
-    }
-}
-
 /// Join a string and an expiry date together into a string.
 pub fn encode_expiring_claim(claim: &str, expiry_date: OffsetDateTime) -> String {
     format!("{}:{}", claim, expiry_date.unix_timestamp())
-- 
cgit v1.2.3