use std::convert::TryFrom; use super::character_classes::{ HEXDIG, IPV_FUTURE_LAST_PART, REG_NAME_NOT_PCT_ENCODED, }; use super::context::Context; use super::error::Error; use super::percent_encoded_character_decoder::PercentEncodedCharacterDecoder; use super::validate_ipv6_address::validate_ipv6_address; struct Shared { host: Vec, host_is_reg_name: bool, ipv6_address: String, pec_decoder: PercentEncodedCharacterDecoder, port_string: String, } enum State { NotIpLiteral(Shared), PercentEncodedCharacter(Shared), Ipv6Address(Shared), IpvFutureNumber(Shared), IpvFutureBody(Shared), GarbageCheck(Shared), Port(Shared), } impl State{ fn finalize(self) -> Result<(Vec, Option), Error> { match self { Self::PercentEncodedCharacter(_) | Self::Ipv6Address(_) | Self::IpvFutureNumber(_) | Self::IpvFutureBody(_) => { // truncated or ended early Err(Error::TruncatedHost) }, Self::NotIpLiteral(state) | Self::GarbageCheck(state) | Self::Port(state) => { let mut state = state; if state.host_is_reg_name { state.host.make_ascii_lowercase(); } let port = if state.port_string.is_empty() { None } else { match state.port_string.parse::() { Ok(port) => { Some(port) }, Err(error) => { return Err(Error::IllegalPortNumber(error)); } } }; Ok((state.host, port)) }, } } fn new(host_port_string: &str) -> (Self, &str) { let mut shared = Shared{ host: Vec::::new(), host_is_reg_name: false, ipv6_address: String::new(), pec_decoder: PercentEncodedCharacterDecoder::new(), port_string: String::new(), }; let mut host_port_string = host_port_string; if host_port_string.starts_with("[v") { host_port_string = &host_port_string[2..]; shared.host.push(b'v'); ( Self::IpvFutureNumber(shared), host_port_string ) } else if host_port_string.starts_with('[') { host_port_string = &host_port_string[1..]; ( Self::Ipv6Address(shared), host_port_string ) } else { shared.host_is_reg_name = true; ( Self::NotIpLiteral(shared), host_port_string ) } } fn next(self, c: char) -> Result { match self { Self::NotIpLiteral(state) => Self::next_not_ip_literal(state, c), Self::PercentEncodedCharacter(state) => Self::next_percent_encoded_character(state, c), Self::Ipv6Address(state) => Self::next_ipv6_address(state, c), Self::IpvFutureNumber(state) => Self::next_ipv_future_number(state, c), Self::IpvFutureBody(state) => Self::next_ipv_future_body(state, c), Self::GarbageCheck(state) => Self::next_garbage_check(state, c), Self::Port(state) => Self::next_port(state, c), } } fn next_not_ip_literal(state: Shared, c: char) -> Result { let mut state = state; if c == '%' { Ok(Self::PercentEncodedCharacter(state)) } else if c == ':' { Ok(Self::Port(state)) } else if REG_NAME_NOT_PCT_ENCODED.contains(&c) { state.host.push(u8::try_from(c as u32).unwrap()); Ok(Self::NotIpLiteral(state)) } else { Err(Error::IllegalCharacter(Context::Host)) } } fn next_percent_encoded_character(state: Shared, c: char) -> Result { let mut state = state; // We can't use `Option::map_or` (or `Option::map_or_else`, for similar // reasons) in this case because the closure would take ownership of // `state`, preventing it from being used to construct the default // value. #[allow(clippy::option_if_let_else)] if let Some(ci) = state.pec_decoder.next(c)? { state.host.push(ci); Ok(Self::NotIpLiteral(state)) } else { Ok(Self::PercentEncodedCharacter(state)) } } fn next_ipv6_address(state: Shared, c: char) -> Result { let mut state = state; if c == ']' { validate_ipv6_address(&state.ipv6_address)?; state.host = state.ipv6_address.chars().map( |c| u8::try_from(c as u32).unwrap() ).collect(); Ok(Self::GarbageCheck(state)) } else { state.ipv6_address.push(c); Ok(Self::Ipv6Address(state)) } } fn next_ipv_future_number(state: Shared, c: char) -> Result { let mut state = state; if c == '.' { state.host.push(b'.'); Ok(Self::IpvFutureBody(state)) } else if c == ']' { Err(Error::TruncatedHost) } else if HEXDIG.contains(&c) { state.host.push(u8::try_from(c as u32).unwrap()); Ok(Self::IpvFutureNumber(state)) } else { Err(Error::IllegalCharacter(Context::IpvFuture)) } } fn next_ipv_future_body(state: Shared, c: char) -> Result { let mut state = state; if c == ']' { Ok(Self::GarbageCheck(state)) } else if IPV_FUTURE_LAST_PART.contains(&c) { state.host.push(u8::try_from(c as u32).unwrap()); Ok(Self::IpvFutureBody(state)) } else { Err(Error::IllegalCharacter(Context::IpvFuture)) } } fn next_garbage_check(state: Shared, c: char) -> Result { // illegal to have anything else, unless it's a colon, // in which case it's a port delimiter if c == ':' { Ok(Self::Port(state)) } else { Err(Error::IllegalCharacter(Context::Host)) } } fn next_port(state: Shared, c: char) -> Result { let mut state = state; state.port_string.push(c); Ok(Self::Port(state)) } } pub fn parse_host_port(host_port_string: T) -> Result<(Vec, Option), Error> where T: AsRef { let (machine, host_port_string) = State::new(host_port_string.as_ref()); host_port_string .chars() .try_fold(machine, |machine, c| { machine.next(c) })? .finalize() } #[cfg(test)] mod tests { use super::*; #[test] fn non_empty_port_number() { let result = parse_host_port("www.example.com:8080"); assert!(result.is_ok()); let (host, port) = result.unwrap(); assert_eq!(b"www.example.com", &host[..]); assert_eq!(Some(8080), port); } #[test] fn empty_port_number() { let result = parse_host_port("www.example.com:"); assert!(result.is_ok()); let (host, port) = result.unwrap(); assert_eq!(b"www.example.com", &host[..]); assert_eq!(None, port); } #[test] fn no_port_number() { let result = parse_host_port("www.example.com"); assert!(result.is_ok()); let (host, port) = result.unwrap(); assert_eq!(b"www.example.com", &host[..]); assert_eq!(None, port); } #[test] fn bad_port_number_purly_alphabetic() { let result = parse_host_port("www.example.com:spam"); assert!(result.is_err()); } #[test] fn bad_port_number_starts_numeric_ends_alphabetic() { let result = parse_host_port("www.example.com:8080spam"); assert!(result.is_err()); } #[test] fn largest_valid_port_number() { let result = parse_host_port("www.example.com:65535"); assert!(result.is_ok()); let (_, port) = result.unwrap(); assert_eq!(Some(65535), port); } #[test] fn bad_port_number_too_big() { let result = parse_host_port("www.example.com:65536"); assert!(matches!(result, Err(Error::IllegalPortNumber(_)))); } #[test] fn bad_port_number_negative() { let result = parse_host_port("www.example.com:-1234"); assert!(result.is_err()); } #[test] fn truncated_host() { let test_vectors = [ "[::ffff:1.2.3.4/", "[:]/", "[v]/", ]; for test_vector in &test_vectors { assert_eq!( Err(Error::TruncatedHost), parse_host_port(test_vector), "{}", test_vector ); } } #[test] fn ipv6_address_with_ipv4_part_missing_bracket() { assert!(matches!( parse_host_port("::ffff:1.2.3.4]"), Err(Error::IllegalPortNumber(_)) )); } }