diff options
author | Martin Fischer <martin@push-f.com> | 2021-07-04 00:25:01 +0200 |
---|---|---|
committer | Martin Fischer <martin@push-f.com> | 2021-07-04 11:19:41 +0200 |
commit | c0b8d6a9876a95bc5d8fd8a30333e65949f5c9d1 (patch) | |
tree | 40bfa7846a11fae026f45aefb3b1a87dbfa59ec6 | |
parent | 17d98a70a76efc02f643e9cfe13836120a2c5114 (diff) |
strictly enforce Host and Origin headers
Previously the Origin header was only checked if you specified an origin
with --origin on startup and when you didn't we just printed a warning
that this might make you vulnerable to CSRF attacks.
I implemented it this way since I wanted GitPad to be runnable without
any command-line options, but such warnings are of course suboptimal
for security since they can simply be ignored.
This commit changes this behavior so that the Origin header is always
checked for POST requests. If you just run "gitpad" the enforced origin
defaults to http://127.0.0.1:<port>. Additionally this commit also
enforces an exact Host header (extracted from the Origin) to prevent DNS
rebinding attacks.
-rw-r--r-- | src/main.rs | 106 | ||||
-rw-r--r-- | src/post_routes.rs | 27 |
2 files changed, 82 insertions, 51 deletions
diff --git a/src/main.rs b/src/main.rs index d3c88f6..513f59e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,6 +28,7 @@ use std::env; use std::path::Component; use std::path::Path; use std::path::PathBuf; +use std::str::FromStr; use std::sync::Arc; use url::Url; @@ -80,9 +81,9 @@ struct Args { #[clap(short, default_value = "8000")] port: u16, - /// Enforce the given HTTP Origin header value to prevent CSRF attacks. - #[clap(long, validator = validate_origin)] - origin: Option<String>, + /// e.g. https://example.com (used to enforce Host and Origin headers) + #[clap(long)] + origin: Option<HttpOrigin>, /// Serve via the given Unix domain socket path. #[cfg(unix)] @@ -90,33 +91,12 @@ struct Args { socket: Option<String>, } -fn validate_origin(input: &str) -> Result<(), String> { - let url = Url::parse(input).map_err(|e| e.to_string())?; - if url.scheme() != "http" && url.scheme() != "https" { - return Err("must start with http:// or https://".into()); - } - if url.path() != "/" { - return Err("must not have a path".into()); - } - if input.ends_with('/') { - return Err("must not end with a trailing slash".into()); - } - Ok(()) -} - #[tokio::main] async fn main() { let args = Args::parse(); let repo = Repository::open_bare(env::current_dir().unwrap()) .expect("expected current directory to be a bare Git repository"); - if args.origin.is_none() { - eprintln!( - "[warning] Running gitpad without --origin might \ - make you vulnerable to CSRF attacks." - ); - } - if args.multiuser { serve(MultiUserController::new(&repo), args).await; } else { @@ -136,22 +116,55 @@ async fn main() { } } +#[derive(Clone, Debug)] +struct HttpOrigin { + origin: String, + host_idx: usize, +} + +impl HttpOrigin { + /// Returns the Host header value (e.g. `example.com` for the origin `https://example.com`). + fn host(&self) -> &str { + &self.origin[self.host_idx..] + } +} + +impl FromStr for HttpOrigin { + type Err = &'static str; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let url = Url::parse(s).map_err(|_| "invalid URL")?; + if url.scheme() != "http" && url.scheme() != "https" { + return Err("expected http:// or https:// scheme"); + } + if url.path() != "/" { + return Err("path must be /".into()); + } + Ok(HttpOrigin { + origin: url.origin().ascii_serialization(), + host_idx: url.scheme().len() + "://".len(), + }) + } +} + async fn serve<C: Controller + Send + Sync + 'static>(controller: C, args: Args) { let controller = Arc::new(controller); - let args = Arc::new(args); - let server_args = args.clone(); #[cfg(unix)] - if let Some(socket_path) = &server_args.socket { + if let Some(socket_path) = &args.socket { + let origin: &'static HttpOrigin = Box::leak(Box::new( + args.origin + .expect("if you use --socket, you must specify an --origin"), + )); + // TODO: get rid of code duplication // we somehow need to specify the closure type or it gets too specific let service = make_service_fn(move |_| { let controller = controller.clone(); - let args = args.clone(); async move { Ok::<_, hyper::Error>(service_fn(move |req| { - service(controller.clone(), args.clone(), req) + service(origin, controller.clone(), req) })) } }); @@ -192,25 +205,29 @@ async fn serve<C: Controller + Send + Sync + 'static>(controller: C, args: Args) ); eprintln!(); + let addr = ([127, 0, 0, 1], args.port).into(); + let url = format!("http://{}", addr); + let origin: &'static HttpOrigin = Box::leak(Box::new( + args.origin.unwrap_or_else(|| url.parse().unwrap()), + )); + let service = make_service_fn(move |_| { let controller = controller.clone(); - let args = args.clone(); async move { Ok::<_, hyper::Error>(service_fn(move |req| { - service(controller.clone(), args.clone(), req) + service(&origin, controller.clone(), req) })) } }); - let addr = ([127, 0, 0, 1], server_args.port).into(); let server = Server::bind(&addr).serve(service); - println!("Listening on http://{}", addr); + println!("Listening on {}", url); server.await.expect("server error"); } async fn service<C: Controller>( + origin: &HttpOrigin, controller: Arc<C>, - args: Arc<Args>, request: Request, ) -> Result<HyperResponse, Infallible> { let (mut parts, body) = request.into_parts(); @@ -218,7 +235,7 @@ async fn service<C: Controller>( let mut script_csp = "'none'".into(); let mut frame_csp = "'none'".into(); - let mut resp = build_response(args, &*controller, &mut parts, body) + let mut resp = build_response(origin, &*controller, &mut parts, body) .await .map(|resp| match resp { Response::Raw(resp) => resp, @@ -344,11 +361,26 @@ impl Branch { } async fn build_response<C: Controller>( - args: Arc<Args>, + origin: &HttpOrigin, controller: &C, parts: &mut Parts, body: Body, ) -> Result<Response, Error> { + let host = parts + .headers + .get("Host") + .ok_or_else(|| Error::BadRequest("Host header required".into()))? + .to_str() + .unwrap(); + if host != origin.host() { + // We enforce an exact Host header to prevent DNS rebinding attacks. + return Err(Error::BadRequest(format!("<h1>Bad Request: Unknown Host header</h1>\ + Received the header <pre>Host: {}</pre> + But expected the header <pre>Host: {}</pre> \ + <p>If you want to serve GitPad under a different hostname you need to specify it on startup with <code>--origin</code>.</p>", + html_escape(host), html_escape(origin.host())))); + } + let unsanitized_path = percent_decode_str(parts.uri.path()) .decode_utf8() .map_err(|_| Error::BadRequest("failed to percent-decode path as UTF-8".into()))? @@ -392,7 +424,7 @@ async fn build_response<C: Controller>( } if parts.method == Method::POST { - return post_routes::build_response(&args, ¶ms, controller, ctx, body, parts).await; + return post_routes::build_response(origin, ¶ms, controller, ctx, body, parts).await; } let mut tree = ctx diff --git a/src/post_routes.rs b/src/post_routes.rs index 7588abc..1b5d615 100644 --- a/src/post_routes.rs +++ b/src/post_routes.rs @@ -18,32 +18,31 @@ use crate::forms::EditForm; use crate::forms::MoveForm; use crate::get_renderer; use crate::ActionParam; -use crate::Args; use crate::Context; +use crate::HttpOrigin; use crate::RenderMode; use crate::Response; use crate::{controller::Controller, Error}; pub(crate) async fn build_response<C: Controller>( - args: &Args, + host: &HttpOrigin, params: &ActionParam, controller: &C, ctx: Context, body: Body, parts: &mut Parts, ) -> Result<Response, Error> { - if let Some(ref enforced_origin) = args.origin { - if parts - .headers - .get(header::ORIGIN) - .filter(|h| h.as_bytes() == enforced_origin.as_bytes()) - .is_none() - { - return Err(Error::BadRequest(format!( - "POST requests must be sent with the header Origin: {}", - enforced_origin - ))); - } + if parts + .headers + .get(header::ORIGIN) + .filter(|h| h.as_bytes() == host.origin.as_bytes()) + .is_none() + { + // This check prevents cross-site request forgery (CSRF). + return Err(Error::BadRequest(format!( + "POST requests must be sent with the header Origin: {}", + host.origin + ))); } match params.action.as_ref() { "edit" => return update_blob(body, controller, ctx, parts).await, |