diff options
| -rw-r--r-- | README.md | 51 | ||||
| -rw-r--r-- | examples/form/src/main.rs | 51 | ||||
| -rw-r--r-- | src/hyper_body.rs | 19 | ||||
| -rw-r--r-- | src/lib.rs | 12 | ||||
| -rw-r--r-- | src/request.rs | 74 | ||||
| -rw-r--r-- | src/response.rs | 34 | ||||
| -rw-r--r-- | src/security.rs | 33 | ||||
| -rw-r--r-- | src/security/signed.rs | 8 | 
8 files changed, 179 insertions, 103 deletions
| @@ -35,14 +35,14 @@ header matches your domain name (especially if you have unauthenticated POST end  ## Hyper Example  ```rust -use std::convert::Infallible; -use hyper::service::{service_fn, make_service_fn}; -use hyper::{Method, Server, StatusCode, Body};  use hyper::http::request::Parts;  use hyper::http::response::Builder; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Method, Server, StatusCode};  use serde::Deserialize; +use sputnik::hyper_body::{FormError, SputnikBody};  use sputnik::{html_escape, mime, request::SputnikParts, response::SputnikBuilder}; -use sputnik::hyper_body::{SputnikBody, FormError}; +use std::convert::Infallible;  type Response = hyper::Response<Body>; @@ -51,7 +51,7 @@ enum Error {      #[error("page not found")]      NotFound(String),      #[error("{0}")] -    FormError(#[from] FormError) +    FormError(#[from] FormError),  }  fn render_error(err: Error) -> (StatusCode, String) { @@ -65,33 +65,37 @@ async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> {      match (&req.method, req.uri.path()) {          (&Method::GET, "/form") => Ok(get_form(req)),          (&Method::POST, "/form") => post_form(req, body).await, -        _ => return Err(Error::NotFound("page not found".to_owned())) +        _ => return Err(Error::NotFound("page not found".to_owned())),      }  }  fn get_form(_req: &mut Parts) -> Response {      Builder::new() -    .content_type(mime::TEXT_HTML) -    .body( -        "<form method=post><input name=text> <button>Submit</button></form>".into() -    ).unwrap() +        .content_type(mime::TEXT_HTML) +        .body("<form method=post><input name=text> <button>Submit</button></form>".into()) +        .unwrap()  }  #[derive(Deserialize)] -struct FormData {text: String} +struct FormData { +    text: String, +}  async fn post_form(_req: &mut Parts, body: Body) -> Result<Response, Error> {      let msg: FormData = body.into_form().await?; -    Ok(Builder::new().content_type(mime::TEXT_HTML).body( -        format!("hello <em>{}</em>", html_escape(msg.text)).into() -    ).unwrap()) +    Ok(Builder::new() +        .content_type(mime::TEXT_HTML) +        .body(format!("hello <em>{}</em>", html_escape(msg.text)).into()) +        .unwrap())  } -async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, Infallible> { +async fn service( +    req: hyper::Request<hyper::Body>, +) -> Result<hyper::Response<hyper::Body>, Infallible> {      let (mut parts, body) = req.into_parts();      match route(&mut parts, body).await {          Ok(mut res) => { -            for (k,v) in parts.response_headers().iter() { +            for (k, v) in parts.response_headers().iter() {                  res.headers_mut().append(k, v.clone());              }              Ok(res) @@ -99,19 +103,18 @@ async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyp          Err(err) => {              let (code, message) = render_error(err);              // you can easily wrap or log errors here -            Ok(hyper::Response::builder().status(code).body(message.into()).unwrap()) +            Ok(hyper::Response::builder() +                .status(code) +                .body(message.into()) +                .unwrap())          }      }  }  #[tokio::main]  async fn main() { -    let service = make_service_fn(move |_| { -        async move { -            Ok::<_, hyper::Error>(service_fn(move |req| { -                service(req) -            })) -        } +    let service = make_service_fn(move |_| async move { +        Ok::<_, hyper::Error>(service_fn(move |req| service(req)))      });      let addr = ([127, 0, 0, 1], 8000).into(); @@ -155,4 +158,4 @@ let userid = req.cookies().find(|(name, _value)| *name == "userid")  Tip: If you want to store multiple claims in the cookie, you can  (de)serialize a struct with [serde_json](https://docs.serde.rs/serde_json/).  This approach can pose a lightweight alternative to JWT, if you don't care -about the standardization aspect.
\ No newline at end of file +about the standardization aspect. diff --git a/examples/form/src/main.rs b/examples/form/src/main.rs index a63560c..af6b2af 100644 --- a/examples/form/src/main.rs +++ b/examples/form/src/main.rs @@ -1,11 +1,11 @@ -use std::convert::Infallible; -use hyper::service::{service_fn, make_service_fn}; -use hyper::{Method, Server, StatusCode, Body};  use hyper::http::request::Parts;  use hyper::http::response::Builder; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Method, Server, StatusCode};  use serde::Deserialize; +use sputnik::hyper_body::{FormError, SputnikBody};  use sputnik::{html_escape, mime, request::SputnikParts, response::SputnikBuilder}; -use sputnik::hyper_body::{SputnikBody, FormError}; +use std::convert::Infallible;  type Response = hyper::Response<Body>; @@ -14,7 +14,7 @@ enum Error {      #[error("page not found")]      NotFound(String),      #[error("{0}")] -    FormError(#[from] FormError) +    FormError(#[from] FormError),  }  fn render_error(err: Error) -> (StatusCode, String) { @@ -28,33 +28,37 @@ async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> {      match (&req.method, req.uri.path()) {          (&Method::GET, "/form") => Ok(get_form(req)),          (&Method::POST, "/form") => post_form(req, body).await, -        _ => return Err(Error::NotFound("page not found".to_owned())) +        _ => return Err(Error::NotFound("page not found".to_owned())),      }  }  fn get_form(_req: &mut Parts) -> Response {      Builder::new() -    .content_type(mime::TEXT_HTML) -    .body( -        "<form method=post><input name=text> <button>Submit</button></form>".into() -    ).unwrap() +        .content_type(mime::TEXT_HTML) +        .body("<form method=post><input name=text> <button>Submit</button></form>".into()) +        .unwrap()  }  #[derive(Deserialize)] -struct FormData {text: String} +struct FormData { +    text: String, +}  async fn post_form(_req: &mut Parts, body: Body) -> Result<Response, Error> {      let msg: FormData = body.into_form().await?; -    Ok(Builder::new().content_type(mime::TEXT_HTML).body( -        format!("hello <em>{}</em>", html_escape(msg.text)).into() -    ).unwrap()) +    Ok(Builder::new() +        .content_type(mime::TEXT_HTML) +        .body(format!("hello <em>{}</em>", html_escape(msg.text)).into()) +        .unwrap())  } -async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, Infallible> { +async fn service( +    req: hyper::Request<hyper::Body>, +) -> Result<hyper::Response<hyper::Body>, Infallible> {      let (mut parts, body) = req.into_parts();      match route(&mut parts, body).await {          Ok(mut res) => { -            for (k,v) in parts.response_headers().iter() { +            for (k, v) in parts.response_headers().iter() {                  res.headers_mut().append(k, v.clone());              }              Ok(res) @@ -62,23 +66,22 @@ async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyp          Err(err) => {              let (code, message) = render_error(err);              // you can easily wrap or log errors here -            Ok(hyper::Response::builder().status(code).body(message.into()).unwrap()) +            Ok(hyper::Response::builder() +                .status(code) +                .body(message.into()) +                .unwrap())          }      }  }  #[tokio::main]  async fn main() { -    let service = make_service_fn(move |_| { -        async move { -            Ok::<_, hyper::Error>(service_fn(move |req| { -                service(req) -            })) -        } +    let service = make_service_fn(move |_| async move { +        Ok::<_, hyper::Error>(service_fn(move |req| service(req)))      });      let addr = ([127, 0, 0, 1], 8000).into();      let server = Server::bind(&addr).serve(service);      println!("Listening on http://{}", addr);      server.await; -}
\ No newline at end of file +} diff --git a/src/hyper_body.rs b/src/hyper_body.rs index 50efdc3..7400633 100644 --- a/src/hyper_body.rs +++ b/src/hyper_body.rs @@ -15,17 +15,22 @@ impl EmptyBuilder<hyper::Body> for Builder {  /// Adds deserialization methods to `hyper::Body`.  pub trait SputnikBody {      /// Parses a `application/x-www-form-urlencoded` request body into a given struct. -    fn into_form<T: DeserializeOwned>(self) -> Pin<Box<dyn Future<Output=Result<T, FormError>> + Send + Sync>>; +    fn into_form<T: DeserializeOwned>( +        self, +    ) -> Pin<Box<dyn Future<Output = Result<T, FormError>> + Send + Sync>>;      /// Attempts to deserialize the request body as JSON.      #[cfg(feature = "hyper_body_json")]      #[cfg_attr(docsrs, doc(cfg(feature = "hyper_body_json")))] -    fn into_json<T: DeserializeOwned>(self) -> Pin<Box<dyn Future<Output=Result<T, JsonError>> + Send + Sync>>; +    fn into_json<T: DeserializeOwned>( +        self, +    ) -> Pin<Box<dyn Future<Output = Result<T, JsonError>> + Send + Sync>>;  }  impl SputnikBody for hyper::Body { - -    fn into_form<T: DeserializeOwned>(self) -> Pin<Box<dyn Future<Output=Result<T, FormError>> + Send + Sync>> { +    fn into_form<T: DeserializeOwned>( +        self, +    ) -> Pin<Box<dyn Future<Output = Result<T, FormError>> + Send + Sync>> {          Box::pin(async move {              let full_body = hyper::body::to_bytes(self).await.map_err(BodyError)?;              Ok(serde_urlencoded::from_bytes::<T>(&full_body)?) @@ -34,7 +39,9 @@ impl SputnikBody for hyper::Body {      #[cfg(feature = "hyper_body_json")]      #[cfg_attr(docsrs, doc(cfg(feature = "hyper_body_json")))] -    fn into_json<T: DeserializeOwned>(self) -> Pin<Box<dyn Future<Output=Result<T, JsonError>> + Send + Sync>> { +    fn into_json<T: DeserializeOwned>( +        self, +    ) -> Pin<Box<dyn Future<Output = Result<T, JsonError>> + Send + Sync>> {          Box::pin(async move {              let full_body = hyper::body::to_bytes(self).await.map_err(BodyError)?;              Ok(serde_json::from_slice::<T>(&full_body)?) @@ -64,4 +71,4 @@ pub enum JsonError {      #[error("json deserialize error: {0}")]      Deserialize(#[from] serde_json::Error), -}
\ No newline at end of file +} @@ -5,23 +5,23 @@  use std::borrow::Cow; -pub use mime;  pub use httpdate; +pub use mime;  pub mod request;  pub mod response; -#[cfg(feature="security")] +#[cfg(feature = "security")]  #[cfg_attr(docsrs, doc(cfg(feature = "security")))]  pub mod security; -#[cfg(feature="hyper_body")] +#[cfg(feature = "hyper_body")]  #[cfg_attr(docsrs, doc(cfg(feature = "hyper_body")))]  pub mod hyper_body; -#[cfg(not(feature="hyper_body"))] +#[cfg(not(feature = "hyper_body"))]  use http; -#[cfg(feature="hyper_body")] +#[cfg(feature = "hyper_body")]  use hyper::http;  /// HTML escapes the given string. @@ -52,4 +52,4 @@ pub fn html_escape<'a, S: Into<Cow<'a, str>>>(input: S) -> Cow<'a, str> {      } else {          input      } -}
\ No newline at end of file +} diff --git a/src/request.rs b/src/request.rs index 19292a1..6b8dbe1 100644 --- a/src/request.rs +++ b/src/request.rs @@ -5,13 +5,13 @@ use serde::Deserialize;  use std::str::Split;  use std::time::Duration; -use crate::response::{Cookie, SputnikHeaders, delete_cookie}; -use crate::http::{HeaderMap, header, request::Parts}; +use crate::http::{header, request::Parts, HeaderMap}; +use crate::response::{delete_cookie, Cookie, SputnikHeaders};  /// Adds convenience methods to [`http::request::Parts`](Parts).  pub trait SputnikParts {      /// Parses the query string of the request into a given struct. -    fn query<'a, X: Deserialize<'a>>(&'a self) -> Result<X,QueryError>; +    fn query<'a, X: Deserialize<'a>>(&'a self) -> Result<X, QueryError>;      /// Parses the cookies of the request.      fn cookies(&self) -> CookieIter; @@ -42,7 +42,7 @@ impl<'a> Iterator for CookieIter<'a> {                  None => self.next(),                  Some(mut value) => {                      if value.starts_with('"') && value.ends_with('"') && value.len() >= 2 { -                        value = &value[1..value.len()-1]; +                        value = &value[1..value.len() - 1];                      }                      Some((name, value))                  } @@ -52,7 +52,7 @@ impl<'a> Iterator for CookieIter<'a> {  }  impl SputnikParts for Parts { -    fn query<'a, T: Deserialize<'a>>(&'a self) -> Result<T,QueryError> { +    fn query<'a, T: Deserialize<'a>>(&'a self) -> Result<T, QueryError> {          serde_urlencoded::from_str::<T>(self.uri.query().unwrap_or("")).map_err(QueryError)      } @@ -64,16 +64,29 @@ impl SputnikParts for Parts {      }      fn cookies(&self) -> CookieIter { -        CookieIter(self.headers.get(header::COOKIE).and_then(|h| std::str::from_utf8(h.as_bytes()).ok()).unwrap_or("").split(';')) +        CookieIter( +            self.headers +                .get(header::COOKIE) +                .and_then(|h| std::str::from_utf8(h.as_bytes()).ok()) +                .unwrap_or("") +                .split(';'), +        )      }      fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError> {          if let Some(content_type) = self.headers.get(header::CONTENT_TYPE) {              if *content_type == mime.to_string() { -                return Ok(()) +                return Ok(());              }          } -        Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))}) +        Err(WrongContentTypeError { +            expected: mime, +            received: self +                .headers +                .get(header::CONTENT_TYPE) +                .as_ref() +                .and_then(|h| h.to_str().ok().map(|s| s.to_owned())), +        })      }  } @@ -100,33 +113,50 @@ impl Flash {      /// If the request has a flash cookie retrieve it and append a set-cookie      /// header to delete the cookie to [`SputnikParts::response_headers`].      pub fn from_request(req: &mut Parts) -> Option<Self> { -        let value = req.cookies().find(|(name, _value)| *name == FLASH_COOKIE_NAME)?.1.to_owned(); -        req.response_headers().set_cookie(delete_cookie(FLASH_COOKIE_NAME)); +        let value = req +            .cookies() +            .find(|(name, _value)| *name == FLASH_COOKIE_NAME)? +            .1 +            .to_owned(); +        req.response_headers() +            .set_cookie(delete_cookie(FLASH_COOKIE_NAME));          let mut iter = value.splitn(2, ':');          if let (Some(name), Some(message)) = (iter.next(), iter.next()) { -            return Some(Flash{name: name.to_owned(), message: message.to_owned()}) +            return Some(Flash { +                name: name.to_owned(), +                message: message.to_owned(), +            });          }          None      }      /// Constructs a new Flash message. The name must not contain a colon (`:`).      pub fn new(name: String, message: String) -> Self { -        Flash{name, message} +        Flash { name, message }      }      /// Constructs a new "success" Flash message.      pub fn success(message: String) -> Self { -        Flash{name: "success".to_owned(), message} +        Flash { +            name: "success".to_owned(), +            message, +        }      }      /// Constructs a new "warning" Flash message.      pub fn warning(message: String) -> Self { -        Flash{name: "warning".to_owned(), message} +        Flash { +            name: "warning".to_owned(), +            message, +        }      }      /// Constructs a new "error" Flash message.      pub fn error(message: String) -> Self { -        Flash{name: "error".to_owned(), message} +        Flash { +            name: "error".to_owned(), +            message, +        }      }      /// Returns the name of the Flash message. @@ -140,7 +170,6 @@ impl Flash {      }  } -  #[derive(thiserror::Error, Debug)]  #[error("query deserialize error: {0}")]  pub struct QueryError(pub serde_urlencoded::de::Error); @@ -152,12 +181,11 @@ pub struct WrongContentTypeError {      pub received: Option<String>,  } -  #[cfg(test)]  mod tests {      use std::convert::TryInto; -    use crate::http::{Request, header}; +    use crate::http::{header, Request};      use super::SputnikParts; @@ -166,10 +194,14 @@ mod tests {          let (mut parts, _body) = Request::new("").into_parts();          assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_err()); -        parts.headers.append(header::CONTENT_TYPE, "application/json".try_into().unwrap()); +        parts +            .headers +            .append(header::CONTENT_TYPE, "application/json".try_into().unwrap());          assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_ok()); -        parts.headers.insert(header::CONTENT_TYPE, "text/html".try_into().unwrap()); +        parts +            .headers +            .insert(header::CONTENT_TYPE, "text/html".try_into().unwrap());          assert!(parts.enforce_content_type(mime::APPLICATION_JSON).is_err());      } -}
\ No newline at end of file +} diff --git a/src/response.rs b/src/response.rs index c9d83a7..8035c4f 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,8 +1,12 @@  //! Provides convenience traits and functions to build HTTP responses. -use std::{convert::TryInto, fmt::Display, time::{Duration, SystemTime}}; +use std::{ +    convert::TryInto, +    fmt::Display, +    time::{Duration, SystemTime}, +}; -use crate::http::{self, HeaderMap, StatusCode, header, response::Builder}; +use crate::http::{self, header, response::Builder, HeaderMap, StatusCode};  /// Adds convenience methods to [`Builder`].  pub trait SputnikBuilder { @@ -54,14 +58,16 @@ impl Display for Cookie {          if let Some(time) = self.expires {              write!(f, "; Expires={}", httpdate::fmt_http_date(time))?;          } -         +          Ok(())      }  }  #[derive(Debug, PartialEq)]  pub enum SameSite { -    Strict, Lax, None +    Strict, +    Lax, +    None,  }  impl Display for SameSite { @@ -76,7 +82,9 @@ impl Display for SameSite {  /// Creates a new builder with a given Location header and status code.  pub fn redirect(location: &str, code: StatusCode) -> Builder { -    Builder::new().status(code).header(header::LOCATION, location) +    Builder::new() +        .status(code) +        .header(header::LOCATION, location)  }  impl SputnikBuilder for Builder { @@ -93,10 +101,10 @@ impl SputnikBuilder for Builder {  /// Constructs an expired cookie to delete a cookie.  pub fn delete_cookie(name: &str) -> Cookie { -    Cookie{ +    Cookie {          name: name.into(),          max_age: Some(Duration::from_secs(0)), -        expires: Some(SystemTime::now() - Duration::from_secs(60*60*24)), +        expires: Some(SystemTime::now() - Duration::from_secs(60 * 60 * 24)),          ..Default::default()      }  } @@ -133,8 +141,16 @@ mod tests {      #[test]      fn test_set_cookie() {          let mut map = HeaderMap::new(); -        map.set_cookie(Cookie{name: "some".into(), value: "cookie".into(), ..Default::default()}); -        map.set_cookie(Cookie{name: "some".into(), value: "cookie".into(), ..Default::default()}); +        map.set_cookie(Cookie { +            name: "some".into(), +            value: "cookie".into(), +            ..Default::default() +        }); +        map.set_cookie(Cookie { +            name: "some".into(), +            value: "cookie".into(), +            ..Default::default() +        });          assert_eq!(map.len(), 2);      } diff --git a/src/security.rs b/src/security.rs index cd9d7bd..bc3b381 100644 --- a/src/security.rs +++ b/src/security.rs @@ -7,17 +7,30 @@ mod signed;  /// Join a string and an expiry date together into a string.  pub fn encode_expiring_claim(claim: &str, expiry_date: SystemTime) -> String { -    format!("{}:{}", claim, expiry_date.duration_since(UNIX_EPOCH).unwrap().as_secs()) +    format!( +        "{}:{}", +        claim, +        expiry_date.duration_since(UNIX_EPOCH).unwrap().as_secs() +    )  }  /// Extract the string, failing if the expiry date is in the past.  pub fn decode_expiring_claim(value: &str) -> Result<&str, &'static str> {      let mut parts = value.rsplitn(2, ':'); -    let expiry_date = parts.next().expect("first .rsplitn().next() is expected to return Some"); +    let expiry_date = parts +        .next() +        .expect("first .rsplitn().next() is expected to return Some");      let claim = parts.next().ok_or("expected colon")?; -    let expiry_date: u64 = expiry_date.parse().map_err(|_| "failed to parse timestamp")?; - -    if expiry_date > SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() { +    let expiry_date: u64 = expiry_date +        .parse() +        .map_err(|_| "failed to parse timestamp")?; + +    if expiry_date +        > SystemTime::now() +            .duration_since(UNIX_EPOCH) +            .unwrap() +            .as_secs() +    {          Ok(claim)      } else {          Err("token is expired") @@ -26,17 +39,19 @@ pub fn decode_expiring_claim(value: &str) -> Result<&str, &'static str> {  #[cfg(test)]  mod tests { -    use std::time::{SystemTime, Duration}; +    use std::time::{Duration, SystemTime};      #[test]      fn test_expiring_claim() {          for claim in vec!["test", "", "foo:bar"] { -            let encoded_claim = super::encode_expiring_claim(claim, SystemTime::now() + Duration::from_secs(60)); +            let encoded_claim = +                super::encode_expiring_claim(claim, SystemTime::now() + Duration::from_secs(60));              assert_eq!(super::decode_expiring_claim(&encoded_claim).unwrap(), claim); -            let encoded_claim = super::encode_expiring_claim(claim, SystemTime::now() - Duration::from_secs(60)); +            let encoded_claim = +                super::encode_expiring_claim(claim, SystemTime::now() - Duration::from_secs(60));              assert!(super::decode_expiring_claim(&encoded_claim).is_err());          }          assert!(super::decode_expiring_claim("test".into()).is_err());      } -}
\ No newline at end of file +} diff --git a/src/security/signed.rs b/src/security/signed.rs index 4da6760..2954954 100644 --- a/src/security/signed.rs +++ b/src/security/signed.rs @@ -1,4 +1,4 @@ -use hmac::{Hmac,NewMac,Mac}; +use hmac::{Hmac, Mac, NewMac};  use sha2::Sha256;  const SIGNED_KEY_LEN: usize = 32; @@ -9,11 +9,11 @@ const BASE64_DIGEST_LEN: usize = 44;  /// This code was adapted from the [`cookie`] crate which does not make the sign and verify functions public  /// forcing the use of [`CookieJar`](cookie::CookieJar)s, which are akward to work with without a high-level framework.  // Thanks to Sergio Benitez for writing the original code and releasing it under MIT! -pub struct Key (pub [u8; SIGNED_KEY_LEN]); +pub struct Key(pub [u8; SIGNED_KEY_LEN]);  impl Key {      const fn zero() -> Self { -        Key ( [0; SIGNED_KEY_LEN]) +        Key([0; SIGNED_KEY_LEN])      }      /// Attempts to generate signing/encryption keys from a secure, random @@ -69,4 +69,4 @@ impl Key {              .map(|_| value.to_string())              .map_err(|_| "value did not verify".to_string())      } -}
\ No newline at end of file +} | 
