aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml2
-rw-r--r--README.md18
-rw-r--r--examples/csrf/src/main.rs18
-rw-r--r--src/request.rs75
-rw-r--r--src/security.rs35
5 files changed, 62 insertions, 86 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 62d1f54..6925869 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "sputnik"
-version = "0.3.0"
+version = "0.3.1"
authors = ["Martin Fischer <martin@push-f.com>"]
license = "MIT"
description = "A lightweight layer on top of hyper to facilitate building web applications."
diff --git a/README.md b/README.md
index 38bc5d8..9892e21 100644
--- a/README.md
+++ b/README.md
@@ -51,29 +51,27 @@ fn render_error(err: Error) -> (StatusCode, String) {
async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> {
match (&req.method, req.uri.path()) {
- (&Method::GET, "/form") => get_form(req).await,
+ (&Method::GET, "/form") => Ok(get_form(req)),
(&Method::POST, "/form") => post_form(req, body).await,
_ => return Err(Error::NotFound("page not found".to_owned()))
}
}
-async fn get_form(req: &mut Parts) -> Result<Response, Error> {
+fn get_form(req: &mut Parts) -> Response {
let mut response = Builder::new();
- let csrf_token = req.csrf_token(&mut response);
- Ok(response.content_type(mime::TEXT_HTML).body(
+ let csrf_input = req.csrf_html_input(&mut response);
+ response.content_type(mime::TEXT_HTML).body(
format!("<form method=post>
- <input name=text>{}<button>Submit</button></form>", csrf_token.html_input()).into()
- ).unwrap())
+ <input name=text>{}<button>Submit</button></form>", csrf_input).into()
+ ).unwrap()
}
#[derive(Deserialize)]
struct FormData {text: String}
async fn post_form(req: &mut Parts, body: Body) -> Result<Response, Error> {
- let mut response = Builder::new();
- let csrf_token = req.csrf_token(&mut response);
- let msg: FormData = body.into_form_csrf(&csrf_token).await?;
- Ok(response.body(
+ let msg: FormData = body.into_form_csrf(req).await?;
+ Ok(Builder::new().body(
format!("hello {}", msg.text).into()
).unwrap())
}
diff --git a/examples/csrf/src/main.rs b/examples/csrf/src/main.rs
index e7e1bfa..94fd09c 100644
--- a/examples/csrf/src/main.rs
+++ b/examples/csrf/src/main.rs
@@ -26,29 +26,27 @@ fn render_error(err: Error) -> (StatusCode, String) {
async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> {
match (&req.method, req.uri.path()) {
- (&Method::GET, "/form") => get_form(req).await,
+ (&Method::GET, "/form") => Ok(get_form(req)),
(&Method::POST, "/form") => post_form(req, body).await,
_ => return Err(Error::NotFound("page not found".to_owned()))
}
}
-async fn get_form(req: &mut Parts) -> Result<Response, Error> {
+fn get_form(req: &mut Parts) -> Response {
let mut response = Builder::new();
- let csrf_token = req.csrf_token(&mut response);
- Ok(response.content_type(mime::TEXT_HTML).body(
+ let csrf_input = req.csrf_html_input(&mut response);
+ response.content_type(mime::TEXT_HTML).body(
format!("<form method=post>
- <input name=text>{}<button>Submit</button></form>", csrf_token.html_input()).into()
- ).unwrap())
+ <input name=text>{}<button>Submit</button></form>", csrf_input).into()
+ ).unwrap()
}
#[derive(Deserialize)]
struct FormData {text: String}
async fn post_form(req: &mut Parts, body: Body) -> Result<Response, Error> {
- let mut response = Builder::new();
- let csrf_token = req.csrf_token(&mut response);
- let msg: FormData = body.into_form_csrf(&csrf_token).await?;
- Ok(response.body(
+ let msg: FormData = body.into_form_csrf(req).await?;
+ Ok(Builder::new().body(
format!("hello {}", msg.text).into()
).unwrap())
}
diff --git a/src/request.rs b/src/request.rs
index 1166ef2..509e2e7 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -2,14 +2,16 @@
use cookie::Cookie;
use mime::Mime;
-use rand::{Rng, distributions::Alphanumeric};
-use security::CsrfToken;
use serde::{Deserialize, de::DeserializeOwned};
-use hyper::{body::Bytes, header};
+use hyper::{body::Bytes, header, http::{request::Parts, response::Builder}};
use time::Duration;
use std::{collections::HashMap, sync::Arc};
+use rand::{Rng, distributions::Alphanumeric};
+use async_trait::async_trait;
-use crate::{response::SputnikBuilder, security};
+use crate::response::SputnikBuilder;
+
+const CSRF_COOKIE_NAME : &str = "csrf";
pub trait SputnikParts {
/// Parses the query string of the request into a given struct.
@@ -18,12 +20,14 @@ pub trait SputnikParts {
/// Parses the cookies of the request.
fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>>;
- /// Retrieves the CSRF token from a `csrf` cookie or generates
- /// a new token and stores it as a cookie if it doesn't exist.
- fn csrf_token(&mut self, builder: &mut dyn SputnikBuilder) -> CsrfToken;
-
/// Enforces a specific Content-Type.
fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>;
+
+ /// Retrievs the CSRF token from a cookie or generates
+ /// a new token and stores it as a cookie if it doesn't exist.
+ /// Returns a hidden HTML input to be embedded in forms that are received
+ /// with [`crate::request::SputnikBody::into_form_csrf`].
+ fn csrf_html_input(&mut self, builder: &mut Builder) -> String;
}
impl SputnikParts for hyper::http::request::Parts {
@@ -55,18 +59,6 @@ impl SputnikParts for hyper::http::request::Parts {
cookies
}
- fn csrf_token(&mut self, builder: &mut dyn SputnikBuilder) -> CsrfToken {
- if let Some(cookie) = self.cookies().get("csrf") {
- return CsrfToken{token: cookie.value().to_string(), from_client: true}
- }
- let val: String = rand::thread_rng().sample_iter(Alphanumeric).take(16).collect();
- let mut c = Cookie::new("csrf", val.clone());
- c.set_secure(Some(true));
- c.set_max_age(Some(Duration::hours(1)));
- builder.set_cookie(c);
- CsrfToken{token: val, from_client: false}
- }
-
fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> {
if let Some(content_type) = self.headers.get(header::CONTENT_TYPE) {
if *content_type == mime.to_string() {
@@ -75,9 +67,19 @@ impl SputnikParts for hyper::http::request::Parts {
}
Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))})
}
-}
-use async_trait::async_trait;
+ fn csrf_html_input(&mut self, builder: &mut Builder) -> String {
+ let token = csrf_token_from_cookies(self).unwrap_or_else(|| {
+ let token: String = rand::thread_rng().sample_iter(Alphanumeric).take(16).collect();
+ let mut c = Cookie::new(CSRF_COOKIE_NAME, token.clone());
+ c.set_secure(Some(true));
+ c.set_max_age(Some(Duration::hours(1)));
+ builder.set_cookie(c);
+ token
+ });
+ format!("<input name=csrf type=hidden value=\"{}\">", token)
+ }
+}
#[async_trait]
pub trait SputnikBody {
@@ -92,9 +94,12 @@ pub trait SputnikBody {
/// Parses a `application/x-www-form-urlencoded` request body into a given struct.
/// Protects from CSRF by checking that the request body contains the same token retrieved from the cookies.
///
- /// The CSRF parameter is expected as the `csrf` parameter in the request body.
- /// This means for HTML forms you need to embed the token as a hidden input.
- async fn into_form_csrf<T: DeserializeOwned>(self, csrf_token: &security::CsrfToken) -> Result<T, CsrfProtectedFormError>;
+ /// The HTML form must embed a hidden input generated with [`crate::request::SputnikParts::csrf_html_input`].
+ async fn into_form_csrf<T: DeserializeOwned>(self, req: &mut Parts) -> Result<T, CsrfProtectedFormError>;
+}
+
+fn csrf_token_from_cookies(req: &mut Parts) -> Option<String> {
+ req.cookies().get(CSRF_COOKIE_NAME).map(|c| c.value().to_string())
}
#[async_trait]
@@ -108,11 +113,17 @@ impl SputnikBody for hyper::Body {
Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
}
- async fn into_form_csrf<T: DeserializeOwned>(self, csrf_token: &CsrfToken) -> Result<T, CsrfProtectedFormError> {
+ async fn into_form_csrf<T: DeserializeOwned>(self, req: &mut Parts) -> Result<T, CsrfProtectedFormError> {
let full_body = self.into_bytes().await?;
let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_| CsrfProtectedFormError::NoCsrf)?;
- csrf_token.matches(csrf_data.csrf)?;
- serde_urlencoded::from_bytes::<T>(&full_body).map_err(CsrfProtectedFormError::Deserialize)
+ match csrf_token_from_cookies(req) {
+ Some(token) => if token == csrf_data.csrf {
+ Ok(serde_urlencoded::from_bytes::<T>(&full_body)?)
+ } else {
+ Err(CsrfProtectedFormError::Mismatch)
+ }
+ None => Err(CsrfProtectedFormError::NoCookie)
+ }
}
}
@@ -121,7 +132,6 @@ struct CsrfData {
csrf: String,
}
-use crate::security::CsrfError;
#[derive(thiserror::Error, Debug)]
#[error("query deserialize error: {0}")]
pub struct QueryError(pub serde_urlencoded::de::Error);
@@ -157,6 +167,9 @@ pub enum CsrfProtectedFormError {
#[error("no csrf token in form data")]
NoCsrf,
- #[error("{0}")]
- Csrf(#[from] CsrfError),
+ #[error("no csrf cookie")]
+ NoCookie,
+
+ #[error("csrf parameter doesn't match csrf cookie")]
+ Mismatch,
} \ No newline at end of file
diff --git a/src/security.rs b/src/security.rs
index 5247d9e..0ffa7a0 100644
--- a/src/security.rs
+++ b/src/security.rs
@@ -1,42 +1,9 @@
-//! [`CsrfToken`], [`Key`] and functions to encode & decode expiring claims.
+//! [`Key`] and functions to encode & decode expiring claims.
use time::OffsetDateTime;
-use thiserror::Error;
pub use crate::signed::Key;
-/// A cookie-based CSRF token to be used with [`crate::request::SputnikBody::into_form_csrf`].
-pub struct CsrfToken {
- pub(crate) token: String,
- pub(crate) from_client: bool,
-}
-
-#[derive(Error, Debug)]
-pub enum CsrfError {
- #[error("expected csrf cookie")]
- NoCookie,
-
- #[error("csrf parameter doesn't match csrf cookie")]
- Mismatch,
-}
-
-impl CsrfToken {
- /// Wraps the token in a hidden HTML input.
- pub fn html_input(&self) -> String {
- format!("<input name=csrf type=hidden value=\"{}\">", self.token)
- }
-
- pub(crate) fn matches(&self, str: String) -> Result<(), CsrfError> {
- if !self.from_client {
- return Err(CsrfError::NoCookie)
- }
- if self.token != str {
- return Err(CsrfError::Mismatch)
- }
- Ok(())
- }
-}
-
/// Join a string and an expiry date together into a string.
pub fn encode_expiring_claim(claim: &str, expiry_date: OffsetDateTime) -> String {
format!("{}:{}", claim, expiry_date.unix_timestamp())