diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/request.rs | 56 | 
1 files changed, 27 insertions, 29 deletions
| diff --git a/src/request.rs b/src/request.rs index e64f9dd..f2392b8 100644 --- a/src/request.rs +++ b/src/request.rs @@ -30,15 +30,6 @@ pub trait SputnikParts {      /// This is intended to be done after your routing logic so that your      /// individual request handlers don't have to worry about it.      fn response_headers(&mut self) -> &mut HeaderMap; - -    /// Returns a CSRF token, either extracted from the `csrf` cookie or newly -    /// generated if the cookie wasn't sent (in which case a set-cookie header is -    /// appended to [`Self::response_headers`]). -    /// -    /// If there is no cookie, calling this method multiple times only generates -    /// a new token on the first call, further calls return the previously -    /// generated token. -    fn csrf_token(&mut self) -> CsrfToken;  }  impl SputnikParts for hyper::http::request::Parts { @@ -85,12 +76,31 @@ impl SputnikParts for hyper::http::request::Parts {          }          Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))})      } +} + +/// A cookie-based CSRF token to be used with [`SputnikBody::into_form_csrf`]. +#[derive(Clone)] +pub struct CsrfToken(String); + +impl std::fmt::Display for CsrfToken { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        write!(f, "{}", self.0) +    } +} -    fn csrf_token(&mut self) -> CsrfToken { -        if let Some(token) = self.extensions.get::<CsrfToken>() { +impl CsrfToken { +    /// Returns a CSRF token, either extracted from the `csrf` cookie or newly +    /// generated if the cookie wasn't sent (in which case a set-cookie header is +    /// appended to [`SputnikParts::response_headers`]). +    /// +    /// If there is no cookie, calling this method multiple times only generates +    /// a new token on the first call, further calls return the previously +    /// generated token. +    pub fn from_request(req: &mut Parts) -> CsrfToken { +        if let Some(token) = req.extensions.get::<CsrfToken>() {              return token.clone()          } -        csrf_token_from_cookies(self) +        csrf_token_from_cookies(req)          .unwrap_or_else(|| {              let token: String = rand::thread_rng().sample_iter(                  Alphanumeric @@ -100,27 +110,15 @@ impl SputnikParts for hyper::http::request::Parts {              c.set_secure(Some(true));              c.set_max_age(Some(Duration::hours(1))); -            self.response_headers().set_cookie(c); +            req.response_headers().set_cookie(c);              let token = CsrfToken(token); -            self.extensions.insert(token.clone()); +            req.extensions.insert(token.clone());              token          })      } -} - -/// A CSRF token retrievable with [`SputnikParts::csrf_token`]. -#[derive(Clone)] -pub struct CsrfToken(String); - -impl std::fmt::Display for CsrfToken { -    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { -        write!(f, "{}", self.0) -    } -} -impl CsrfToken {      /// Returns a hidden HTML input to be embedded in forms that are received -    /// with [`crate::request::SputnikBody::into_form_csrf`]. +    /// with [`SputnikBody::into_form_csrf`].      pub fn html_input(&self) -> String {          format!("<input name=csrf type=hidden value=\"{}\">", self)      } @@ -257,8 +255,8 @@ mod tests {      #[test]      fn test_csrf_token() {          let mut parts = Request::new(hyper::Body::empty()).into_parts().0; -        let tok1 = parts.csrf_token(); -        let tok2 = parts.csrf_token(); +        let tok1 = CsrfToken::from_request(&mut parts); +        let tok2 = CsrfToken::from_request(&mut parts);          assert_eq!(tok1.to_string(), tok2.to_string());          assert_eq!(parts.response_headers().len(), 1);      } | 
