aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Fischer <martin@push-f.com>2021-07-04 00:25:01 +0200
committerMartin Fischer <martin@push-f.com>2021-07-04 11:19:41 +0200
commitc0b8d6a9876a95bc5d8fd8a30333e65949f5c9d1 (patch)
tree40bfa7846a11fae026f45aefb3b1a87dbfa59ec6
parent17d98a70a76efc02f643e9cfe13836120a2c5114 (diff)
strictly enforce Host and Origin headers
Previously the Origin header was only checked if you specified an origin with --origin on startup and when you didn't we just printed a warning that this might make you vulnerable to CSRF attacks. I implemented it this way since I wanted GitPad to be runnable without any command-line options, but such warnings are of course suboptimal for security since they can simply be ignored. This commit changes this behavior so that the Origin header is always checked for POST requests. If you just run "gitpad" the enforced origin defaults to http://127.0.0.1:<port>. Additionally this commit also enforces an exact Host header (extracted from the Origin) to prevent DNS rebinding attacks.
-rw-r--r--src/main.rs106
-rw-r--r--src/post_routes.rs27
2 files changed, 82 insertions, 51 deletions
diff --git a/src/main.rs b/src/main.rs
index d3c88f6..513f59e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -28,6 +28,7 @@ use std::env;
use std::path::Component;
use std::path::Path;
use std::path::PathBuf;
+use std::str::FromStr;
use std::sync::Arc;
use url::Url;
@@ -80,9 +81,9 @@ struct Args {
#[clap(short, default_value = "8000")]
port: u16,
- /// Enforce the given HTTP Origin header value to prevent CSRF attacks.
- #[clap(long, validator = validate_origin)]
- origin: Option<String>,
+ /// e.g. https://example.com (used to enforce Host and Origin headers)
+ #[clap(long)]
+ origin: Option<HttpOrigin>,
/// Serve via the given Unix domain socket path.
#[cfg(unix)]
@@ -90,33 +91,12 @@ struct Args {
socket: Option<String>,
}
-fn validate_origin(input: &str) -> Result<(), String> {
- let url = Url::parse(input).map_err(|e| e.to_string())?;
- if url.scheme() != "http" && url.scheme() != "https" {
- return Err("must start with http:// or https://".into());
- }
- if url.path() != "/" {
- return Err("must not have a path".into());
- }
- if input.ends_with('/') {
- return Err("must not end with a trailing slash".into());
- }
- Ok(())
-}
-
#[tokio::main]
async fn main() {
let args = Args::parse();
let repo = Repository::open_bare(env::current_dir().unwrap())
.expect("expected current directory to be a bare Git repository");
- if args.origin.is_none() {
- eprintln!(
- "[warning] Running gitpad without --origin might \
- make you vulnerable to CSRF attacks."
- );
- }
-
if args.multiuser {
serve(MultiUserController::new(&repo), args).await;
} else {
@@ -136,22 +116,55 @@ async fn main() {
}
}
+#[derive(Clone, Debug)]
+struct HttpOrigin {
+ origin: String,
+ host_idx: usize,
+}
+
+impl HttpOrigin {
+ /// Returns the Host header value (e.g. `example.com` for the origin `https://example.com`).
+ fn host(&self) -> &str {
+ &self.origin[self.host_idx..]
+ }
+}
+
+impl FromStr for HttpOrigin {
+ type Err = &'static str;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let url = Url::parse(s).map_err(|_| "invalid URL")?;
+ if url.scheme() != "http" && url.scheme() != "https" {
+ return Err("expected http:// or https:// scheme");
+ }
+ if url.path() != "/" {
+ return Err("path must be /".into());
+ }
+ Ok(HttpOrigin {
+ origin: url.origin().ascii_serialization(),
+ host_idx: url.scheme().len() + "://".len(),
+ })
+ }
+}
+
async fn serve<C: Controller + Send + Sync + 'static>(controller: C, args: Args) {
let controller = Arc::new(controller);
- let args = Arc::new(args);
- let server_args = args.clone();
#[cfg(unix)]
- if let Some(socket_path) = &server_args.socket {
+ if let Some(socket_path) = &args.socket {
+ let origin: &'static HttpOrigin = Box::leak(Box::new(
+ args.origin
+ .expect("if you use --socket, you must specify an --origin"),
+ ));
+
// TODO: get rid of code duplication
// we somehow need to specify the closure type or it gets too specific
let service = make_service_fn(move |_| {
let controller = controller.clone();
- let args = args.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req| {
- service(controller.clone(), args.clone(), req)
+ service(origin, controller.clone(), req)
}))
}
});
@@ -192,25 +205,29 @@ async fn serve<C: Controller + Send + Sync + 'static>(controller: C, args: Args)
);
eprintln!();
+ let addr = ([127, 0, 0, 1], args.port).into();
+ let url = format!("http://{}", addr);
+ let origin: &'static HttpOrigin = Box::leak(Box::new(
+ args.origin.unwrap_or_else(|| url.parse().unwrap()),
+ ));
+
let service = make_service_fn(move |_| {
let controller = controller.clone();
- let args = args.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req| {
- service(controller.clone(), args.clone(), req)
+ service(&origin, controller.clone(), req)
}))
}
});
- let addr = ([127, 0, 0, 1], server_args.port).into();
let server = Server::bind(&addr).serve(service);
- println!("Listening on http://{}", addr);
+ println!("Listening on {}", url);
server.await.expect("server error");
}
async fn service<C: Controller>(
+ origin: &HttpOrigin,
controller: Arc<C>,
- args: Arc<Args>,
request: Request,
) -> Result<HyperResponse, Infallible> {
let (mut parts, body) = request.into_parts();
@@ -218,7 +235,7 @@ async fn service<C: Controller>(
let mut script_csp = "'none'".into();
let mut frame_csp = "'none'".into();
- let mut resp = build_response(args, &*controller, &mut parts, body)
+ let mut resp = build_response(origin, &*controller, &mut parts, body)
.await
.map(|resp| match resp {
Response::Raw(resp) => resp,
@@ -344,11 +361,26 @@ impl Branch {
}
async fn build_response<C: Controller>(
- args: Arc<Args>,
+ origin: &HttpOrigin,
controller: &C,
parts: &mut Parts,
body: Body,
) -> Result<Response, Error> {
+ let host = parts
+ .headers
+ .get("Host")
+ .ok_or_else(|| Error::BadRequest("Host header required".into()))?
+ .to_str()
+ .unwrap();
+ if host != origin.host() {
+ // We enforce an exact Host header to prevent DNS rebinding attacks.
+ return Err(Error::BadRequest(format!("<h1>Bad Request: Unknown Host header</h1>\
+ Received the header <pre>Host: {}</pre>
+ But expected the header <pre>Host: {}</pre> \
+ <p>If you want to serve GitPad under a different hostname you need to specify it on startup with <code>--origin</code>.</p>",
+ html_escape(host), html_escape(origin.host()))));
+ }
+
let unsanitized_path = percent_decode_str(parts.uri.path())
.decode_utf8()
.map_err(|_| Error::BadRequest("failed to percent-decode path as UTF-8".into()))?
@@ -392,7 +424,7 @@ async fn build_response<C: Controller>(
}
if parts.method == Method::POST {
- return post_routes::build_response(&args, &params, controller, ctx, body, parts).await;
+ return post_routes::build_response(origin, &params, controller, ctx, body, parts).await;
}
let mut tree = ctx
diff --git a/src/post_routes.rs b/src/post_routes.rs
index 7588abc..1b5d615 100644
--- a/src/post_routes.rs
+++ b/src/post_routes.rs
@@ -18,32 +18,31 @@ use crate::forms::EditForm;
use crate::forms::MoveForm;
use crate::get_renderer;
use crate::ActionParam;
-use crate::Args;
use crate::Context;
+use crate::HttpOrigin;
use crate::RenderMode;
use crate::Response;
use crate::{controller::Controller, Error};
pub(crate) async fn build_response<C: Controller>(
- args: &Args,
+ host: &HttpOrigin,
params: &ActionParam,
controller: &C,
ctx: Context,
body: Body,
parts: &mut Parts,
) -> Result<Response, Error> {
- if let Some(ref enforced_origin) = args.origin {
- if parts
- .headers
- .get(header::ORIGIN)
- .filter(|h| h.as_bytes() == enforced_origin.as_bytes())
- .is_none()
- {
- return Err(Error::BadRequest(format!(
- "POST requests must be sent with the header Origin: {}",
- enforced_origin
- )));
- }
+ if parts
+ .headers
+ .get(header::ORIGIN)
+ .filter(|h| h.as_bytes() == host.origin.as_bytes())
+ .is_none()
+ {
+ // This check prevents cross-site request forgery (CSRF).
+ return Err(Error::BadRequest(format!(
+ "POST requests must be sent with the header Origin: {}",
+ host.origin
+ )));
}
match params.action.as_ref() {
"edit" => return update_blob(body, controller, ctx, parts).await,