aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Fischer <martin@push-f.com>2021-01-26 14:37:04 +0100
committerMartin Fischer <martin@push-f.com>2021-01-26 15:48:17 +0100
commitfc15b41a37e123434ec39a277f107b78c1507bd8 (patch)
tree78d450bc04ef64e59a636a37c2147ec9bffba40c
parent8e9a4400ea9bcb80c90232fecc2ad2ae5f6c3303 (diff)
introduce SputnikParts::response_headers
-rw-r--r--README.md19
-rw-r--r--examples/csrf/src/main.rs19
-rw-r--r--src/request.rs36
-rw-r--r--src/response.rs49
4 files changed, 74 insertions, 49 deletions
diff --git a/README.md b/README.md
index 4ee6456..56078ba 100644
--- a/README.md
+++ b/README.md
@@ -58,11 +58,13 @@ async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> {
}
fn get_form(req: &mut Parts) -> Response {
- let mut response = Builder::new();
- let csrf_input = req.csrf_token(&mut response).html_input();
- response.content_type(mime::TEXT_HTML).body(
- format!("<form method=post>
- <input name=text>{}<button>Submit</button></form>", csrf_input).into()
+ Builder::new()
+ .content_type(mime::TEXT_HTML)
+ .body(
+ format!(
+ "<form method=post><input name=text>{}<button>Submit</button></form>",
+ req.csrf_token().html_input()
+ ).into()
).unwrap()
}
@@ -79,7 +81,12 @@ async fn post_form(req: &mut Parts, body: Body) -> Result<Response, Error> {
async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, Infallible> {
let (mut parts, body) = req.into_parts();
match route(&mut parts, body).await {
- Ok(res) => Ok(res),
+ Ok(mut res) => {
+ for (k,v) in parts.response_headers().iter() {
+ res.headers_mut().append(k, v.clone());
+ }
+ Ok(res)
+ }
Err(err) => {
let (code, message) = render_error(err);
// you can easily wrap or log errors here
diff --git a/examples/csrf/src/main.rs b/examples/csrf/src/main.rs
index 1048689..7259abd 100644
--- a/examples/csrf/src/main.rs
+++ b/examples/csrf/src/main.rs
@@ -33,11 +33,13 @@ async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> {
}
fn get_form(req: &mut Parts) -> Response {
- let mut response = Builder::new();
- let csrf_input = req.csrf_token(&mut response).html_input();
- response.content_type(mime::TEXT_HTML).body(
- format!("<form method=post>
- <input name=text>{}<button>Submit</button></form>", csrf_input).into()
+ Builder::new()
+ .content_type(mime::TEXT_HTML)
+ .body(
+ format!(
+ "<form method=post><input name=text>{}<button>Submit</button></form>",
+ req.csrf_token().html_input()
+ ).into()
).unwrap()
}
@@ -54,7 +56,12 @@ async fn post_form(req: &mut Parts, body: Body) -> Result<Response, Error> {
async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, Infallible> {
let (mut parts, body) = req.into_parts();
match route(&mut parts, body).await {
- Ok(res) => Ok(res),
+ Ok(mut res) => {
+ for (k,v) in parts.response_headers().iter() {
+ res.headers_mut().append(k, v.clone());
+ }
+ Ok(res)
+ }
Err(err) => {
let (code, message) = render_error(err);
// you can easily wrap or log errors here
diff --git a/src/request.rs b/src/request.rs
index 947142c..e64f9dd 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -3,13 +3,13 @@
use cookie::Cookie;
use mime::Mime;
use serde::{Deserialize, de::DeserializeOwned};
-use hyper::{body::Bytes, header, http::{request::Parts, response::Builder}};
+use hyper::{HeaderMap, body::Bytes, header, http::request::Parts};
use time::Duration;
use std::{collections::HashMap, sync::Arc};
use rand::{Rng, distributions::Alphanumeric};
use async_trait::async_trait;
-use crate::response::SputnikBuilder;
+use crate::response::SputnikHeaders;
const CSRF_COOKIE_NAME : &str = "csrf";
@@ -23,13 +23,22 @@ pub trait SputnikParts {
/// Enforces a specific Content-Type.
fn enforce_content_type(&self, mime: Mime) -> Result<(), WrongContentTypeError>;
+ /// A map of response headers to allow methods of this trait to set response
+ /// headers without needing to take a [`Response`](hyper::http::response::Response) as an argument.
+ ///
+ /// You need to take care to append these headers to the response yourself.
+ /// This is intended to be done after your routing logic so that your
+ /// individual request handlers don't have to worry about it.
+ fn response_headers(&mut self) -> &mut HeaderMap;
+
/// Returns a CSRF token, either extracted from the `csrf` cookie or newly
- /// generated if the cookie wasn't sent (in which case the cookie is set).
+ /// generated if the cookie wasn't sent (in which case a set-cookie header is
+ /// appended to [`Self::response_headers`]).
///
/// If there is no cookie, calling this method multiple times only generates
/// a new token on the first call, further calls return the previously
/// generated token.
- fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken;
+ fn csrf_token(&mut self) -> CsrfToken;
}
impl SputnikParts for hyper::http::request::Parts {
@@ -37,6 +46,13 @@ impl SputnikParts for hyper::http::request::Parts {
serde_urlencoded::from_str::<T>(self.uri.query().unwrap_or("")).map_err(QueryError)
}
+ fn response_headers(&mut self) -> &mut HeaderMap {
+ if self.extensions.get::<HeaderMap>().is_none() {
+ self.extensions.insert(HeaderMap::new());
+ }
+ self.extensions.get_mut::<HeaderMap>().unwrap()
+ }
+
fn cookies(&mut self) -> Arc<HashMap<String, Cookie<'static>>> {
let cookies: Option<&Arc<HashMap<String, Cookie>>> = self.extensions.get();
if let Some(cookies) = cookies {
@@ -70,7 +86,7 @@ impl SputnikParts for hyper::http::request::Parts {
Err(WrongContentTypeError{expected: mime, received: self.headers.get(header::CONTENT_TYPE).as_ref().and_then(|h| h.to_str().ok().map(|s| s.to_owned()))})
}
- fn csrf_token(&mut self, builder: &mut Builder) -> CsrfToken {
+ fn csrf_token(&mut self) -> CsrfToken {
if let Some(token) = self.extensions.get::<CsrfToken>() {
return token.clone()
}
@@ -83,7 +99,8 @@ impl SputnikParts for hyper::http::request::Parts {
let mut c = Cookie::new(CSRF_COOKIE_NAME, token.clone());
c.set_secure(Some(true));
c.set_max_age(Some(Duration::hours(1)));
- builder.set_cookie(c);
+
+ self.response_headers().set_cookie(c);
let token = CsrfToken(token);
self.extensions.insert(token.clone());
token
@@ -240,10 +257,9 @@ mod tests {
#[test]
fn test_csrf_token() {
let mut parts = Request::new(hyper::Body::empty()).into_parts().0;
- let mut builder = Builder::new();
- let tok1 = parts.csrf_token(&mut builder);
- let tok2 = parts.csrf_token(&mut builder);
+ let tok1 = parts.csrf_token();
+ let tok2 = parts.csrf_token();
assert_eq!(tok1.to_string(), tok2.to_string());
- assert_eq!(builder.body(hyper::Body::empty()).unwrap().headers().len(), 1);
+ assert_eq!(parts.response_headers().len(), 1);
}
} \ No newline at end of file
diff --git a/src/response.rs b/src/response.rs
index 78e9a70..ceb0f61 100644
--- a/src/response.rs
+++ b/src/response.rs
@@ -3,31 +3,39 @@
use std::convert::TryInto;
use cookie::Cookie;
-use hyper::{StatusCode, header, http};
+use hyper::{HeaderMap, StatusCode, header, http};
use time::{Duration, OffsetDateTime};
use hyper::http::response::Builder;
pub trait SputnikBuilder {
- /// Appends a Set-Cookie header.
- fn set_cookie(&mut self, cookie: Cookie);
-
- /// Appends a Set-Cookie header to delete a cookie.
- fn delete_cookie(&mut self, name: &str);
-
/// Sets the Content-Type.
fn content_type(self, mime: mime::Mime) -> Builder;
}
-
pub fn redirect(location: &str, code: StatusCode) -> Builder {
Builder::new().status(code).header(header::LOCATION, location)
}
impl SputnikBuilder for Builder {
- fn set_cookie(&mut self, cookie: Cookie) {
+ fn content_type(mut self, mime: mime::Mime) -> Self {
if let Some(headers) = self.headers_mut() {
- headers.append(header::SET_COOKIE, cookie.encoded().to_string().try_into().unwrap());
+ headers.insert(header::CONTENT_TYPE, mime.to_string().try_into().unwrap());
}
+ self
+ }
+}
+
+pub trait SputnikHeaders {
+ /// Appends a Set-Cookie header.
+ fn set_cookie(&mut self, cookie: Cookie);
+
+ /// Appends a Set-Cookie header to delete a cookie.
+ fn delete_cookie(&mut self, name: &str);
+}
+
+impl SputnikHeaders for HeaderMap {
+ fn set_cookie(&mut self, cookie: Cookie) {
+ self.append(header::SET_COOKIE, cookie.encoded().to_string().try_into().unwrap());
}
fn delete_cookie(&mut self, name: &str) {
@@ -36,13 +44,6 @@ impl SputnikBuilder for Builder {
cookie.set_expires(OffsetDateTime::now_utc() - Duration::days(365));
self.set_cookie(cookie);
}
-
- fn content_type(mut self, mime: mime::Mime) -> Self {
- if let Some(headers) = self.headers_mut() {
- headers.insert(header::CONTENT_TYPE, mime.to_string().try_into().unwrap());
- }
- self
- }
}
pub trait EmptyBuilder<B> {
@@ -62,16 +63,10 @@ mod tests {
#[test]
fn test_set_cookie() {
- let mut builder = Builder::new();
- builder.set_cookie(Cookie::new("some", "cookie"));
- builder.set_cookie(Cookie::new("some", "cookie"));
- let resp = builder.body(hyper::Body::empty()).unwrap();
- assert_eq!(resp.headers().len(), 2);
-
- let mut builder = Builder::new()
- .header("foo", "invalid\r\n");
- // doesn't panic after invalid header
- builder.set_cookie(Cookie::new("some", "cookie"));
+ let mut map = HeaderMap::new();
+ map.set_cookie(Cookie::new("some", "cookie"));
+ map.set_cookie(Cookie::new("some", "cookie"));
+ assert_eq!(map.len(), 2);
}
#[test]