diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 106 |
1 files changed, 69 insertions, 37 deletions
diff --git a/src/main.rs b/src/main.rs index d3c88f6..513f59e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,6 +28,7 @@ use std::env; use std::path::Component; use std::path::Path; use std::path::PathBuf; +use std::str::FromStr; use std::sync::Arc; use url::Url; @@ -80,9 +81,9 @@ struct Args { #[clap(short, default_value = "8000")] port: u16, - /// Enforce the given HTTP Origin header value to prevent CSRF attacks. - #[clap(long, validator = validate_origin)] - origin: Option<String>, + /// e.g. https://example.com (used to enforce Host and Origin headers) + #[clap(long)] + origin: Option<HttpOrigin>, /// Serve via the given Unix domain socket path. #[cfg(unix)] @@ -90,33 +91,12 @@ struct Args { socket: Option<String>, } -fn validate_origin(input: &str) -> Result<(), String> { - let url = Url::parse(input).map_err(|e| e.to_string())?; - if url.scheme() != "http" && url.scheme() != "https" { - return Err("must start with http:// or https://".into()); - } - if url.path() != "/" { - return Err("must not have a path".into()); - } - if input.ends_with('/') { - return Err("must not end with a trailing slash".into()); - } - Ok(()) -} - #[tokio::main] async fn main() { let args = Args::parse(); let repo = Repository::open_bare(env::current_dir().unwrap()) .expect("expected current directory to be a bare Git repository"); - if args.origin.is_none() { - eprintln!( - "[warning] Running gitpad without --origin might \ - make you vulnerable to CSRF attacks." - ); - } - if args.multiuser { serve(MultiUserController::new(&repo), args).await; } else { @@ -136,22 +116,55 @@ async fn main() { } } +#[derive(Clone, Debug)] +struct HttpOrigin { + origin: String, + host_idx: usize, +} + +impl HttpOrigin { + /// Returns the Host header value (e.g. `example.com` for the origin `https://example.com`). + fn host(&self) -> &str { + &self.origin[self.host_idx..] + } +} + +impl FromStr for HttpOrigin { + type Err = &'static str; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let url = Url::parse(s).map_err(|_| "invalid URL")?; + if url.scheme() != "http" && url.scheme() != "https" { + return Err("expected http:// or https:// scheme"); + } + if url.path() != "/" { + return Err("path must be /".into()); + } + Ok(HttpOrigin { + origin: url.origin().ascii_serialization(), + host_idx: url.scheme().len() + "://".len(), + }) + } +} + async fn serve<C: Controller + Send + Sync + 'static>(controller: C, args: Args) { let controller = Arc::new(controller); - let args = Arc::new(args); - let server_args = args.clone(); #[cfg(unix)] - if let Some(socket_path) = &server_args.socket { + if let Some(socket_path) = &args.socket { + let origin: &'static HttpOrigin = Box::leak(Box::new( + args.origin + .expect("if you use --socket, you must specify an --origin"), + )); + // TODO: get rid of code duplication // we somehow need to specify the closure type or it gets too specific let service = make_service_fn(move |_| { let controller = controller.clone(); - let args = args.clone(); async move { Ok::<_, hyper::Error>(service_fn(move |req| { - service(controller.clone(), args.clone(), req) + service(origin, controller.clone(), req) })) } }); @@ -192,25 +205,29 @@ async fn serve<C: Controller + Send + Sync + 'static>(controller: C, args: Args) ); eprintln!(); + let addr = ([127, 0, 0, 1], args.port).into(); + let url = format!("http://{}", addr); + let origin: &'static HttpOrigin = Box::leak(Box::new( + args.origin.unwrap_or_else(|| url.parse().unwrap()), + )); + let service = make_service_fn(move |_| { let controller = controller.clone(); - let args = args.clone(); async move { Ok::<_, hyper::Error>(service_fn(move |req| { - service(controller.clone(), args.clone(), req) + service(&origin, controller.clone(), req) })) } }); - let addr = ([127, 0, 0, 1], server_args.port).into(); let server = Server::bind(&addr).serve(service); - println!("Listening on http://{}", addr); + println!("Listening on {}", url); server.await.expect("server error"); } async fn service<C: Controller>( + origin: &HttpOrigin, controller: Arc<C>, - args: Arc<Args>, request: Request, ) -> Result<HyperResponse, Infallible> { let (mut parts, body) = request.into_parts(); @@ -218,7 +235,7 @@ async fn service<C: Controller>( let mut script_csp = "'none'".into(); let mut frame_csp = "'none'".into(); - let mut resp = build_response(args, &*controller, &mut parts, body) + let mut resp = build_response(origin, &*controller, &mut parts, body) .await .map(|resp| match resp { Response::Raw(resp) => resp, @@ -344,11 +361,26 @@ impl Branch { } async fn build_response<C: Controller>( - args: Arc<Args>, + origin: &HttpOrigin, controller: &C, parts: &mut Parts, body: Body, ) -> Result<Response, Error> { + let host = parts + .headers + .get("Host") + .ok_or_else(|| Error::BadRequest("Host header required".into()))? + .to_str() + .unwrap(); + if host != origin.host() { + // We enforce an exact Host header to prevent DNS rebinding attacks. + return Err(Error::BadRequest(format!("<h1>Bad Request: Unknown Host header</h1>\ + Received the header <pre>Host: {}</pre> + But expected the header <pre>Host: {}</pre> \ + <p>If you want to serve GitPad under a different hostname you need to specify it on startup with <code>--origin</code>.</p>", + html_escape(host), html_escape(origin.host())))); + } + let unsanitized_path = percent_decode_str(parts.uri.path()) .decode_utf8() .map_err(|_| Error::BadRequest("failed to percent-decode path as UTF-8".into()))? @@ -392,7 +424,7 @@ async fn build_response<C: Controller>( } if parts.method == Method::POST { - return post_routes::build_response(&args, ¶ms, controller, ctx, body, parts).await; + return post_routes::build_response(origin, ¶ms, controller, ctx, body, parts).await; } let mut tree = ctx |