aboutsummaryrefslogtreecommitdiff
path: root/src/request.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/request.rs')
-rw-r--r--src/request.rs159
1 files changed, 159 insertions, 0 deletions
diff --git a/src/request.rs b/src/request.rs
new file mode 100644
index 0000000..0c25e67
--- /dev/null
+++ b/src/request.rs
@@ -0,0 +1,159 @@
+//! Provides the [`Parts`] and [`Body`] convenience wrappers.
+
+use cookie::Cookie;
+use header::CONTENT_TYPE;
+use mime::{APPLICATION_WWW_FORM_URLENCODED, Mime};
+use serde::{Deserialize, de::DeserializeOwned};
+use hyper::{body::Bytes, header};
+use hyper::http::request::Parts as ReqParts;
+use std::collections::HashMap;
+
+use crate::{Error, security};
+
+type HyperRequest = hyper::Request<hyper::Body>;
+
+/// Convenience wrapper around [`hyper::Body`].
+pub struct Body {
+ body: hyper::Body,
+ content_type: Option<header::HeaderValue>,
+}
+
+/// Convert [`hyper::Request`] to ([`Parts`], [`Body`])
+pub fn adapt<'a>(req: HyperRequest) -> (Parts, Body) {
+ let (parts, body) = req.into_parts();
+ let body = Body{body, content_type: parts.headers.get(CONTENT_TYPE).map(|x| x.to_owned())};
+ let parts = Parts{parts, cookies: None};
+ (parts, body)
+}
+
+/// Convenience wrapper around [`hyper::http::request::Parts`].
+pub struct Parts {
+ parts: ReqParts,
+ cookies: Option<HashMap<String,Cookie<'static>>>,
+}
+
+#[derive(Deserialize)]
+struct CsrfData {
+ csrf: String,
+}
+
+impl Parts {
+ pub fn cookies(&mut self) -> &HashMap<String,Cookie> {
+ if let Some(ref cookies) = self.cookies {
+ return cookies
+ }
+ let mut cookies = HashMap::new();
+ for header in self.parts.headers.get_all(header::COOKIE) {
+ let raw_str = match std::str::from_utf8(header.as_bytes()) {
+ Ok(string) => string,
+ Err(_) => continue
+ };
+
+ for cookie_str in raw_str.split(';').map(|s| s.trim()) {
+ if let Ok(cookie) = Cookie::parse_encoded(cookie_str) {
+ cookies.insert(cookie.name().to_string(), cookie.into_owned());
+ }
+ }
+ }
+ self.cookies = Some(cookies);
+ &self.cookies.as_ref().unwrap()
+ }
+
+ pub fn method(&self) -> &hyper::Method {
+ &self.parts.method
+ }
+
+ pub fn headers(&self) -> &hyper::HeaderMap<header::HeaderValue> {
+ &self.parts.headers
+ }
+
+ pub fn uri(&self) -> &hyper::Uri {
+ &self.parts.uri
+ }
+
+ /// Parses the query string of the request into a given struct.
+ pub fn query<T: DeserializeOwned>(&self) -> Result<T,Error> {
+ serde_urlencoded::from_str::<T>(self.parts.uri.query().unwrap_or("")).map_err(|e|Error::bad_request(e.to_string()))
+ }
+}
+
+impl Body {
+ pub async fn into_bytes(self) -> Result<Bytes,Error> {
+ hyper::body::to_bytes(self.body).await.map_err(|_|Error::internal("failed to read body".to_string()))
+ }
+
+ /// Parses a `application/x-www-form-urlencoded` request body into a given struct.
+ ///
+ /// This does make you vulnerable to CSRF so you normally want to use
+ /// [`parse_form_csrf()`] instead.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use hyper::{Response};
+ /// use sputnik::{request::Body, Error};
+ /// use serde::Deserialize;
+ ///
+ /// #[derive(Deserialize)]
+ /// struct Message {text: String, year: i64}
+ ///
+ /// async fn greet(body: Body) -> Result<Response<hyper::Body>, Error> {
+ /// let msg: Message = body.into_form().await?;
+ /// Ok(Response::new(format!("hello {}", msg.text).into()))
+ /// }
+ /// ```
+ pub async fn into_form<T: DeserializeOwned>(self) -> Result<T,Error> {
+ self.enforce_content_type(APPLICATION_WWW_FORM_URLENCODED)?;
+ let full_body = self.into_bytes().await?;
+ serde_urlencoded::from_bytes::<T>(&full_body).map_err(|e|Error::bad_request(e.to_string()))
+ }
+
+ /// Parses a `application/x-www-form-urlencoded` request body into a given struct.
+ /// Protects from CSRF by checking that the request body contains the same token retrieved from the cookies.
+ ///
+ /// The CSRF parameter is expected as the `csrf` parameter in the request body.
+ /// This means for HTML forms you need to embed the token as a hidden input.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use hyper::{Method};
+ /// use sputnik::{request::{Parts, Body}, response::Response, Error};
+ /// use sputnik::security::CsrfToken;
+ /// use serde::Deserialize;
+ ///
+ /// #[derive(Deserialize)]
+ /// struct Message {text: String}
+ ///
+ /// async fn greet(req: &mut Parts, body: Body) -> Result<Response, Error> {
+ /// let mut response = Response::new();
+ /// let csrf_token = CsrfToken::from_parts(req, &mut response);
+ /// *response.body() = match (req.method()) {
+ /// &Method::GET => format!("<form method=post>
+ /// <input name=text>{}<button>Submit</button></form>", csrf_token.html_input()).into(),
+ /// &Method::POST => {
+ /// let msg: Message = body.into_form_csrf(&csrf_token).await?;
+ /// format!("hello {}", msg.text).into()
+ /// },
+ /// _ => return Err(Error::method_not_allowed("only GET and POST allowed".to_owned())),
+ /// };
+ /// Ok(response)
+ /// }
+ /// ```
+ pub async fn into_form_csrf<T: DeserializeOwned>(self, csrf_token: &security::CsrfToken) -> Result<T,Error> {
+ self.enforce_content_type(APPLICATION_WWW_FORM_URLENCODED)?;
+ let full_body = self.into_bytes().await?;
+ let csrf_data = serde_urlencoded::from_bytes::<CsrfData>(&full_body).map_err(|_|Error::bad_request("no csrf token".to_string()))?;
+ csrf_token.matches(csrf_data.csrf)?;
+ serde_urlencoded::from_bytes::<T>(&full_body).map_err(|e|Error::bad_request(e.to_string()))
+ }
+
+ fn enforce_content_type(&self, mime: Mime) -> Result<(),Error> {
+ if let Some(content_type) = &self.content_type {
+ if *content_type == mime.to_string() {
+ return Ok(())
+ }
+ }
+ Err(Error::bad_request(format!("expected content-type: {}", mime)))
+ }
+} \ No newline at end of file