diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9ea1c0f --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,207 @@ +//! A lightweight layer on top of [Hyper](https://hyper.rs/) +//! to facilitate building web applications. +use std::collections::HashMap; + +pub use error::Error; +pub use mime; +use cookie::Cookie; +use header::HeaderName; +use mime::{APPLICATION_WWW_FORM_URLENCODED, Mime}; +use serde::{Deserialize, de::DeserializeOwned}; +use hyper::{Body, StatusCode, body::Bytes, header::{self, HeaderValue}, http::request::Parts}; +use time::{Duration, OffsetDateTime}; + +pub use httpdate; + +type HyperRequest = hyper::Request<Body>; +type HyperResponse = hyper::Response<Body>; + +pub mod security; +mod error; +mod signed; + +/// Convenience wrapper around [`hyper::Request`]. +pub struct Request { + body: Body, + parts: Parts, + cookies: Option<HashMap<String,Cookie<'static>>>, +} + +impl From<HyperRequest> for Request { + fn from(req: HyperRequest) -> Self { + let (parts, body) = req.into_parts(); + Request{parts, body, cookies: None} + } +} + +impl Into<HyperResponse> for Response { + fn into(self) -> HyperResponse { + self.res + } +} + +fn enforce_content_type(req: &Parts, mime: Mime) -> Result<(),Error> { + let received_type = req.headers.get(header::CONTENT_TYPE).ok_or(Error::bad_request(format!("expected content-type: {}", mime)))?; + if *received_type != mime.to_string() { + return Err(Error::bad_request(format!("expected content-type: {}", mime))) + } + Ok(()) +} + +#[derive(Deserialize)] +struct CsrfData { + csrf: String, +} + +impl Request { + pub fn cookies(&mut self) -> &HashMap<String,Cookie> { + if let Some(ref cookies) = self.cookies { + return cookies + } + let mut cookies = HashMap::new(); + for header in self.parts.headers.get_all(header::COOKIE) { + let raw_str = match std::str::from_utf8(header.as_bytes()) { + Ok(string) => string, + Err(_) => continue + }; + + for cookie_str in raw_str.split(';').map(|s| s.trim()) { + if let Ok(cookie) = Cookie::parse_encoded(cookie_str) { + cookies.insert(cookie.name().to_string(), cookie.into_owned()); + } + } + } + self.cookies = Some(cookies); + &self.cookies.as_ref().unwrap() + } + + pub fn method(&self) -> &hyper::Method { + &self.parts.method + } + + pub fn uri(&self) -> &hyper::Uri { + &self.parts.uri + } + + /// Parses a `application/x-www-form-urlencoded` request body into a given struct. + /// + /// This does make you vulnerable to CSRF so you normally want to use + /// [`parse_form_csrf()`] instead. + /// + /// # Example + /// + /// ``` + /// use hyper::{Response, Body}; + /// use sputnik::{Request, Error}; + /// use serde::Deserialize; + /// + /// #[derive(Deserialize)] + /// struct Message {text: String, year: i64} + /// + /// async fn greet(req: &mut Request) -> Result<Response<Body>, Error> { + /// let msg: Message = req.into_form().await?; + /// Ok(Response::new(format!("hello {}", msg.text).into())) + /// } + /// ``` + pub async fn into_form<T: DeserializeOwned>(&mut self) -> Result<T,Error> { + enforce_content_type(&self.parts, APPLICATION_WWW_FORM_URLENCODED)?; + let full_body = self.into_body().await?; + serde_urlencoded::from_bytes::<T>(&full_body).map_err(|e|Error::bad_request(e.to_string())) + } + + /// Parses a `application/x-www-form-urlencoded` request body into a given struct. + /// Protects from CSRF by checking that the request body contains the same token retrieved from the cookies. + /// + /// The CSRF parameter is expected as the `csrf` parameter in the request body. + /// This means for HTML forms you need to embed the token as a hidden input. + /// + /// # Example + /// + /// ``` + /// use hyper::{Method}; + /// use sputnik::{Request, Response, Error}; + /// use sputnik::security::CsrfToken; + /// use serde::Deserialize; + /// + /// #[derive(Deserialize)] + /// struct Message {text: String} + /// + /// async fn greet(req: &mut Request) -> Result<Response, Error> { + /// let mut response = Response::new(); + /// let csrf_token = CsrfToken::from_request(req, &mut response); + /// *response.body() = match (req.method()) { + /// &Method::GET => format!("<form method=post> + /// <input name=text>{}<button>Submit</button></form>", csrf_token.html_input()).into(), + /// &Method::POST => { + /// let msg: Message = req.into_form_csrf(&csrf_token).await?; + /// format!("hello {}", msg.text).into() + /// }, + /// _ => return Err(Error::method_not_allowed("only GET and POST allowed".to_owned())), + /// }; + /// Ok(response) + /// } + /// ``` + pub async fn into_form_csrf<T: DeserializeOwned>(&mut self, csrf_token: &security::CsrfToken) -> Result<T,Error> { + enforce_content_type(&self.parts, APPLICATION_WWW_FORM_URLENCODED)?; + let full_body = self.into_body().await?; + let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_|Error::bad_request("no csrf token".to_string()))?; + csrf_token.matches(csrf_data.csrf)?; + serde_urlencoded::from_bytes::<T>(&full_body).map_err(|e|Error::bad_request(e.to_string())) + } + + pub async fn into_body(&mut self) -> Result<Bytes,Error> { + hyper::body::to_bytes(&mut self.body).await.map_err(|_|Error::internal("failed to read body".to_string())) + } + + /// Parses the query string of the request into a given struct. + pub fn query<T: DeserializeOwned>(&self) -> Result<T,Error> { + serde_urlencoded::from_str::<T>(self.parts.uri.query().unwrap_or("")).map_err(|e|Error::bad_request(e.to_string())) + } +} + +/// Convenience wrapper around [`hyper::Response`]. +pub struct Response { + res: HyperResponse +} + +impl Response { + pub fn new() -> Self { + Response{res: HyperResponse::new(Body::empty())} + } + + pub fn status(&mut self) -> &mut StatusCode { + self.res.status_mut() + } + + pub fn body(&mut self) -> &mut Body { + self.res.body_mut() + } + + pub fn headers(&mut self) -> &mut hyper::HeaderMap<header::HeaderValue> { + self.res.headers_mut() + } + + pub fn set_header<S: AsRef<str>>(&mut self, header: HeaderName, value: S) { + self.res.headers_mut().insert(header, HeaderValue::from_str(value.as_ref()).unwrap()); + } + + pub fn set_content_type(&mut self, mime: mime::Mime) { + self.res.headers_mut().insert(header::CONTENT_TYPE, mime.to_string().parse().unwrap()); + } + + pub fn redirect<S: AsRef<str>>(&mut self, location: S, code: StatusCode) { + *self.res.status_mut() = code; + self.set_header(header::LOCATION, location); + } + + pub fn set_cookie(&mut self, cookie: Cookie) { + self.res.headers_mut().append(header::SET_COOKIE, cookie.encoded().to_string().parse().unwrap()); + } + + pub fn delete_cookie(&mut self, name: &str) { + let mut cookie = Cookie::new(name, ""); + cookie.set_max_age(Duration::seconds(0)); + cookie.set_expires(OffsetDateTime::now_utc() - Duration::days(365)); + self.set_cookie(cookie); + } +}
\ No newline at end of file |