diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/request.rs | 82 | ||||
| -rw-r--r-- | src/response.rs | 82 | ||||
| -rw-r--r-- | src/security.rs | 17 | 
3 files changed, 122 insertions, 59 deletions
diff --git a/src/request.rs b/src/request.rs index 6efacc9..4f45e04 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,12 +1,11 @@  //! Provides the [`SputnikParts`] trait. -use cookie::Cookie;  use mime::Mime;  use serde::de::DeserializeOwned; -use time::Duration; -use std::{collections::HashMap, sync::Arc}; +use std::str::Split; +use std::time::Duration; -use crate::response::{SputnikHeaders, delete_cookie}; +use crate::response::{Cookie, SputnikHeaders, delete_cookie};  use crate::http::{HeaderMap, header, request::Parts};  /// Adds convenience methods to [`http::request::Parts`](Parts). @@ -15,7 +14,7 @@ pub trait SputnikParts {      fn query<X: DeserializeOwned>(&self) -> Result<X,QueryError>;      /// Parses the cookies of the request. -    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>>; +    fn cookies(&self) -> CookieIter;      /// Enforces a specific Content-Type.      fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>; @@ -29,6 +28,29 @@ pub trait SputnikParts {      fn response_headers(&mut self) -> &mut HeaderMap;  } +pub struct CookieIter<'a>(Split<'a, char>); + +impl<'a> Iterator for CookieIter<'a> { +    type Item = (&'a str, &'a str); + +    fn next(&mut self) -> Option<Self::Item> { +        self.0.next().and_then(|str| { +            let mut iter = str.splitn(2, '='); +            let name = iter.next().expect("first splitn().next() returns Some"); +            let value = iter.next(); +            match value { +                None => self.next(), +                Some(mut value) => { +                    if value.starts_with('"') && value.ends_with('"') && value.len() >= 2 { +                        value = &value[1..value.len()-1]; +                    } +                    Some((name, value)) +                } +            } +        }) +    } +} +  impl SputnikParts for Parts {      fn query<T: DeserializeOwned>(&self) -> Result<T,QueryError> {          serde_urlencoded::from_str::<T>(self.uri.query().unwrap_or("")).map_err(QueryError) @@ -41,28 +63,8 @@ impl SputnikParts for Parts {          self.extensions.get_mut::<HeaderMap>().unwrap()      } -    fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>> { -        let cookies: Option<&Arc<HashMap<String, Cookie>>> = self.extensions.get(); -        if let Some(cookies) = cookies { -            return cookies.clone(); -        } - -        let mut cookies = HashMap::new(); -        for header in self.headers.get_all(header::COOKIE) { -            let raw_str = match std::str::from_utf8(header.as_bytes()) { -                Ok(string) => string, -                Err(_) => continue -            }; - -            for cookie_str in raw_str.split(';').map(|s| s.trim()) { -                if let Ok(cookie) = Cookie::parse_encoded(cookie_str) { -                    cookies.insert(cookie.name().to_string(), cookie.into_owned()); -                } -            } -        } -        let cookies = Arc::new(cookies); -        self.extensions.insert(cookies.clone()); -        cookies +    fn cookies(&self) -> CookieIter { +        CookieIter(self.headers.get(header::COOKIE).and_then(|h| std::str::from_utf8(h.as_bytes()).ok()).unwrap_or("").split(';'))      }      fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> { @@ -83,10 +85,14 @@ pub struct Flash {      message: String,  } -impl From<Flash> for Cookie<'_> { +impl From<Flash> for Cookie {      fn from(flash: Flash) -> Self { -        Cookie::build(FLASH_COOKIE_NAME, format!("{}:{}", flash.name, flash.message)) -        .max_age(Duration::minutes(5)).finish() +        Cookie { +            name: FLASH_COOKIE_NAME.into(), +            value: format!("{}:{}", flash.name, flash.message), +            max_age: Some(Duration::from_secs(5 * 60)), +            ..Default::default() +        }      }  } @@ -94,15 +100,13 @@ impl Flash {      /// If the request has a flash cookie retrieve it and append a set-cookie      /// header to delete the cookie to [`SputnikParts::response_headers`].      pub fn from_request(req: &mut Parts) -> Option<Self> { -        req.cookies().get(FLASH_COOKIE_NAME) -        .and_then(|cookie| { -            req.response_headers().set_cookie(delete_cookie(FLASH_COOKIE_NAME)); -            let mut iter = cookie.value().splitn(2, ':'); -            if let (Some(name), Some(message)) = (iter.next(), iter.next()) { -                return Some(Flash{name: name.to_owned(), message: message.to_owned()}) -            } -            None -        }) +        let value = req.cookies().find(|(name, _value)| *name == FLASH_COOKIE_NAME)?.1.to_owned(); +        req.response_headers().set_cookie(delete_cookie(FLASH_COOKIE_NAME)); +        let mut iter = value.splitn(2, ':'); +        if let (Some(name), Some(message)) = (iter.next(), iter.next()) { +            return Some(Flash{name: name.to_owned(), message: message.to_owned()}) +        } +        None      }      /// Constructs a new Flash message. The name must not contain a colon (`:`). diff --git a/src/response.rs b/src/response.rs index cb87a80..c9d83a7 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,9 +1,6 @@  //! Provides convenience traits and functions to build HTTP responses. -use std::convert::TryInto; - -use cookie::Cookie; -use time::{Duration, OffsetDateTime}; +use std::{convert::TryInto, fmt::Display, time::{Duration, SystemTime}};  use crate::http::{self, HeaderMap, StatusCode, header, response::Builder}; @@ -16,6 +13,67 @@ pub trait SputnikBuilder {      fn set_cookie(self, cookie: Cookie) -> Builder;  } +#[derive(Default, Debug)] +pub struct Cookie { +    pub name: String, +    pub value: String, +    pub expires: Option<SystemTime>, +    pub max_age: Option<Duration>, +    pub domain: Option<String>, +    pub path: Option<String>, +    pub secure: Option<bool>, +    pub http_only: Option<bool>, +    pub same_site: Option<SameSite>, +} + +impl Display for Cookie { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        write!(f, "{}={}", self.name, self.value)?; +        if let Some(true) = self.http_only { +            write!(f, "; HttpOnly")?; +        } +        if let Some(same_site) = &self.same_site { +            write!(f, "; SameSite={}", same_site)?; + +            if same_site == &SameSite::None && self.secure.is_none() { +                write!(f, "; Secure")?; +            } +        } +        if let Some(true) = self.secure { +            write!(f, "; Secure")?; +        } +        if let Some(path) = &self.path { +            write!(f, "; Path={}", path)?; +        } +        if let Some(domain) = &self.domain { +            write!(f, "; Domain={}", domain)?; +        } +        if let Some(max_age) = &self.max_age { +            write!(f, "; Max-Age={}", max_age.as_secs())?; +        } +        if let Some(time) = self.expires { +            write!(f, "; Expires={}", httpdate::fmt_http_date(time))?; +        } +         +        Ok(()) +    } +} + +#[derive(Debug, PartialEq)] +pub enum SameSite { +    Strict, Lax, None +} + +impl Display for SameSite { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        match self { +            SameSite::Strict => write!(f, "Strict"), +            SameSite::Lax => write!(f, "Lax"), +            SameSite::None => write!(f, "None"), +        } +    } +} +  /// Creates a new builder with a given Location header and status code.  pub fn redirect(location: &str, code: StatusCode) -> Builder {      Builder::new().status(code).header(header::LOCATION, location) @@ -35,10 +93,12 @@ impl SputnikBuilder for Builder {  /// Constructs an expired cookie to delete a cookie.  pub fn delete_cookie(name: &str) -> Cookie { -    let mut cookie = Cookie::new(name, ""); -    cookie.set_max_age(Duration::seconds(0)); -    cookie.set_expires(OffsetDateTime::now_utc() - Duration::days(365)); -    cookie +    Cookie{ +        name: name.into(), +        max_age: Some(Duration::from_secs(0)), +        expires: Some(SystemTime::now() - Duration::from_secs(60*60*24)), +        ..Default::default() +    }  }  /// Adds convenience methods to [`HeaderMap`]. @@ -56,7 +116,7 @@ impl SputnikHeaders for HeaderMap {      }      fn set_cookie(&mut self, cookie: Cookie) { -        self.append(header::SET_COOKIE, cookie.encoded().to_string().try_into().unwrap()); +        self.append(header::SET_COOKIE, cookie.to_string().try_into().unwrap());      }  } @@ -73,8 +133,8 @@ mod tests {      #[test]      fn test_set_cookie() {          let mut map = HeaderMap::new(); -        map.set_cookie(Cookie::new("some", "cookie")); -        map.set_cookie(Cookie::new("some", "cookie")); +        map.set_cookie(Cookie{name: "some".into(), value: "cookie".into(), ..Default::default()}); +        map.set_cookie(Cookie{name: "some".into(), value: "cookie".into(), ..Default::default()});          assert_eq!(map.len(), 2);      } diff --git a/src/security.rs b/src/security.rs index abe114e..cd9d7bd 100644 --- a/src/security.rs +++ b/src/security.rs @@ -1,14 +1,13 @@  //! Provides [`Key`] and functions to encode & decode expiring claims. -use time::OffsetDateTime; -  pub use signed::Key; +pub use std::time::{SystemTime, UNIX_EPOCH};  mod signed;  /// Join a string and an expiry date together into a string. -pub fn encode_expiring_claim(claim: &str, expiry_date: OffsetDateTime) -> String { -    format!("{}:{}", claim, expiry_date.unix_timestamp()) +pub fn encode_expiring_claim(claim: &str, expiry_date: SystemTime) -> String { +    format!("{}:{}", claim, expiry_date.duration_since(UNIX_EPOCH).unwrap().as_secs())  }  /// Extract the string, failing if the expiry date is in the past. @@ -16,9 +15,9 @@ pub fn decode_expiring_claim(value: &str) -> Result<&str, &'static str> {      let mut parts = value.rsplitn(2, ':');      let expiry_date = parts.next().expect("first .rsplitn().next() is expected to return Some");      let claim = parts.next().ok_or("expected colon")?; -    let expiry_date: i64 = expiry_date.parse().map_err(|_| "failed to parse timestamp")?; +    let expiry_date: u64 = expiry_date.parse().map_err(|_| "failed to parse timestamp")?; -    if expiry_date > OffsetDateTime::now_utc().unix_timestamp() { +    if expiry_date > SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() {          Ok(claim)      } else {          Err("token is expired") @@ -27,15 +26,15 @@ pub fn decode_expiring_claim(value: &str) -> Result<&str, &'static str> {  #[cfg(test)]  mod tests { -    use time::{OffsetDateTime, Duration}; +    use std::time::{SystemTime, Duration};      #[test]      fn test_expiring_claim() {          for claim in vec!["test", "", "foo:bar"] { -            let encoded_claim = super::encode_expiring_claim(claim, OffsetDateTime::now_utc() + Duration::minutes(1)); +            let encoded_claim = super::encode_expiring_claim(claim, SystemTime::now() + Duration::from_secs(60));              assert_eq!(super::decode_expiring_claim(&encoded_claim).unwrap(), claim); -            let encoded_claim = super::encode_expiring_claim(claim, OffsetDateTime::now_utc() - Duration::minutes(1)); +            let encoded_claim = super::encode_expiring_claim(claim, SystemTime::now() - Duration::from_secs(60));              assert!(super::decode_expiring_claim(&encoded_claim).is_err());          }          assert!(super::decode_expiring_claim("test".into()).is_err());  | 
