diff options
Diffstat (limited to 'examples')
-rw-r--r-- | examples/csrf/src/main.rs | 40 |
1 files changed, 22 insertions, 18 deletions
diff --git a/examples/csrf/src/main.rs b/examples/csrf/src/main.rs index 915d063..e7e1bfa 100644 --- a/examples/csrf/src/main.rs +++ b/examples/csrf/src/main.rs @@ -1,10 +1,13 @@ use std::convert::Infallible; use hyper::service::{service_fn, make_service_fn}; -use hyper::{Method, Server, StatusCode}; +use hyper::{Method, Server, StatusCode, Body}; +use hyper::http::request::Parts; +use hyper::http::response::Builder; use serde::Deserialize; -use sputnik::security::CsrfToken; -use sputnik::{request::{Parts, Body}, response::Response}; -use sputnik::request::error::*; +use sputnik::{mime, request::{SputnikParts, SputnikBody}, response::SputnikBuilder}; +use sputnik::request::CsrfProtectedFormError; + +type Response = hyper::Response<Body>; #[derive(thiserror::Error, Debug)] enum Error { @@ -21,8 +24,8 @@ fn render_error(err: Error) -> (StatusCode, String) { } } -async fn route(req: &mut Parts, body: Body) -> Result<Response,Error> { - match (req.method(), req.uri().path()) { +async fn route(req: &mut Parts, body: Body) -> Result<Response, Error> { + match (&req.method, req.uri.path()) { (&Method::GET, "/form") => get_form(req).await, (&Method::POST, "/form") => post_form(req, body).await, _ => return Err(Error::NotFound("page not found".to_owned())) @@ -30,29 +33,30 @@ async fn route(req: &mut Parts, body: Body) -> Result<Response,Error> { } async fn get_form(req: &mut Parts) -> Result<Response, Error> { - let mut response = Response::new(); - let csrf_token = CsrfToken::from_parts(req, &mut response); - *response.body() = format!("<form method=post> - <input name=text>{}<button>Submit</button></form>", csrf_token.html_input()).into(); - Ok(response) + let mut response = Builder::new(); + let csrf_token = req.csrf_token(&mut response); + Ok(response.content_type(mime::TEXT_HTML).body( + format!("<form method=post> + <input name=text>{}<button>Submit</button></form>", csrf_token.html_input()).into() + ).unwrap()) } #[derive(Deserialize)] struct FormData {text: String} async fn post_form(req: &mut Parts, body: Body) -> Result<Response, Error> { - let mut response = Response::new(); - let csrf_token = CsrfToken::from_parts(req, &mut response); + let mut response = Builder::new(); + let csrf_token = req.csrf_token(&mut response); let msg: FormData = body.into_form_csrf(&csrf_token).await?; - *response.body() = format!("hello {}", msg.text).into(); - Ok(response) + Ok(response.body( + format!("hello {}", msg.text).into() + ).unwrap()) } -/// adapt between Hyper's types and Sputnik's convenience types async fn service(req: hyper::Request<hyper::Body>) -> Result<hyper::Response<hyper::Body>, Infallible> { - let (mut parts, body) = sputnik::request::adapt(req); + let (mut parts, body) = req.into_parts(); match route(&mut parts, body).await { - Ok(res) => Ok(res.into()), + Ok(res) => Ok(res), Err(err) => { let (code, message) = render_error(err); // you can easily wrap or log errors here |