aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs106
1 files changed, 69 insertions, 37 deletions
diff --git a/src/main.rs b/src/main.rs
index d3c88f6..513f59e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -28,6 +28,7 @@ use std::env;
use std::path::Component;
use std::path::Path;
use std::path::PathBuf;
+use std::str::FromStr;
use std::sync::Arc;
use url::Url;
@@ -80,9 +81,9 @@ struct Args {
#[clap(short, default_value = "8000")]
port: u16,
- /// Enforce the given HTTP Origin header value to prevent CSRF attacks.
- #[clap(long, validator = validate_origin)]
- origin: Option<String>,
+ /// e.g. https://example.com (used to enforce Host and Origin headers)
+ #[clap(long)]
+ origin: Option<HttpOrigin>,
/// Serve via the given Unix domain socket path.
#[cfg(unix)]
@@ -90,33 +91,12 @@ struct Args {
socket: Option<String>,
}
-fn validate_origin(input: &str) -> Result<(), String> {
- let url = Url::parse(input).map_err(|e| e.to_string())?;
- if url.scheme() != "http" && url.scheme() != "https" {
- return Err("must start with http:// or https://".into());
- }
- if url.path() != "/" {
- return Err("must not have a path".into());
- }
- if input.ends_with('/') {
- return Err("must not end with a trailing slash".into());
- }
- Ok(())
-}
-
#[tokio::main]
async fn main() {
let args = Args::parse();
let repo = Repository::open_bare(env::current_dir().unwrap())
.expect("expected current directory to be a bare Git repository");
- if args.origin.is_none() {
- eprintln!(
- "[warning] Running gitpad without --origin might \
- make you vulnerable to CSRF attacks."
- );
- }
-
if args.multiuser {
serve(MultiUserController::new(&repo), args).await;
} else {
@@ -136,22 +116,55 @@ async fn main() {
}
}
+#[derive(Clone, Debug)]
+struct HttpOrigin {
+ origin: String,
+ host_idx: usize,
+}
+
+impl HttpOrigin {
+ /// Returns the Host header value (e.g. `example.com` for the origin `https://example.com`).
+ fn host(&self) -> &str {
+ &self.origin[self.host_idx..]
+ }
+}
+
+impl FromStr for HttpOrigin {
+ type Err = &'static str;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let url = Url::parse(s).map_err(|_| "invalid URL")?;
+ if url.scheme() != "http" && url.scheme() != "https" {
+ return Err("expected http:// or https:// scheme");
+ }
+ if url.path() != "/" {
+ return Err("path must be /".into());
+ }
+ Ok(HttpOrigin {
+ origin: url.origin().ascii_serialization(),
+ host_idx: url.scheme().len() + "://".len(),
+ })
+ }
+}
+
async fn serve<C: Controller + Send + Sync + 'static>(controller: C, args: Args) {
let controller = Arc::new(controller);
- let args = Arc::new(args);
- let server_args = args.clone();
#[cfg(unix)]
- if let Some(socket_path) = &server_args.socket {
+ if let Some(socket_path) = &args.socket {
+ let origin: &'static HttpOrigin = Box::leak(Box::new(
+ args.origin
+ .expect("if you use --socket, you must specify an --origin"),
+ ));
+
// TODO: get rid of code duplication
// we somehow need to specify the closure type or it gets too specific
let service = make_service_fn(move |_| {
let controller = controller.clone();
- let args = args.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req| {
- service(controller.clone(), args.clone(), req)
+ service(origin, controller.clone(), req)
}))
}
});
@@ -192,25 +205,29 @@ async fn serve<C: Controller + Send + Sync + 'static>(controller: C, args: Args)
);
eprintln!();
+ let addr = ([127, 0, 0, 1], args.port).into();
+ let url = format!("http://{}", addr);
+ let origin: &'static HttpOrigin = Box::leak(Box::new(
+ args.origin.unwrap_or_else(|| url.parse().unwrap()),
+ ));
+
let service = make_service_fn(move |_| {
let controller = controller.clone();
- let args = args.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req| {
- service(controller.clone(), args.clone(), req)
+ service(&origin, controller.clone(), req)
}))
}
});
- let addr = ([127, 0, 0, 1], server_args.port).into();
let server = Server::bind(&addr).serve(service);
- println!("Listening on http://{}", addr);
+ println!("Listening on {}", url);
server.await.expect("server error");
}
async fn service<C: Controller>(
+ origin: &HttpOrigin,
controller: Arc<C>,
- args: Arc<Args>,
request: Request,
) -> Result<HyperResponse, Infallible> {
let (mut parts, body) = request.into_parts();
@@ -218,7 +235,7 @@ async fn service<C: Controller>(
let mut script_csp = "'none'".into();
let mut frame_csp = "'none'".into();
- let mut resp = build_response(args, &*controller, &mut parts, body)
+ let mut resp = build_response(origin, &*controller, &mut parts, body)
.await
.map(|resp| match resp {
Response::Raw(resp) => resp,
@@ -344,11 +361,26 @@ impl Branch {
}
async fn build_response<C: Controller>(
- args: Arc<Args>,
+ origin: &HttpOrigin,
controller: &C,
parts: &mut Parts,
body: Body,
) -> Result<Response, Error> {
+ let host = parts
+ .headers
+ .get("Host")
+ .ok_or_else(|| Error::BadRequest("Host header required".into()))?
+ .to_str()
+ .unwrap();
+ if host != origin.host() {
+ // We enforce an exact Host header to prevent DNS rebinding attacks.
+ return Err(Error::BadRequest(format!("<h1>Bad Request: Unknown Host header</h1>\
+ Received the header <pre>Host: {}</pre>
+ But expected the header <pre>Host: {}</pre> \
+ <p>If you want to serve GitPad under a different hostname you need to specify it on startup with <code>--origin</code>.</p>",
+ html_escape(host), html_escape(origin.host()))));
+ }
+
let unsanitized_path = percent_decode_str(parts.uri.path())
.decode_utf8()
.map_err(|_| Error::BadRequest("failed to percent-decode path as UTF-8".into()))?
@@ -392,7 +424,7 @@ async fn build_response<C: Controller>(
}
if parts.method == Method::POST {
- return post_routes::build_response(&args, &params, controller, ctx, body, parts).await;
+ return post_routes::build_response(origin, &params, controller, ctx, body, parts).await;
}
let mut tree = ctx