use std::convert::TryFrom; use super::{ character_classes::{ HEXDIG, IPV_FUTURE_LAST_PART, REG_NAME_NOT_PCT_ENCODED, }, context::Context, error::Error, percent_encoded_character_decoder::PercentEncodedCharacterDecoder, validate_ipv6_address::validate_ipv6_address, }; struct Shared { host: Vec, host_is_reg_name: bool, ipv6_address: String, pec_decoder: PercentEncodedCharacterDecoder, port_string: String, } enum State { NotIpLiteral(Shared), PercentEncodedCharacter(Shared), Ipv6Address(Shared), IpvFutureNumber(Shared), IpvFutureBody(Shared), GarbageCheck(Shared), Port(Shared), } impl State { fn finalize(self) -> Result<(Vec, Option), Error> { match self { Self::PercentEncodedCharacter(_) | Self::Ipv6Address(_) | Self::IpvFutureNumber(_) | Self::IpvFutureBody(_) => { // truncated or ended early Err(Error::TruncatedHost) }, Self::NotIpLiteral(state) | Self::GarbageCheck(state) | Self::Port(state) => { let mut state = state; if state.host_is_reg_name { state.host.make_ascii_lowercase(); } let port = if state.port_string.is_empty() { None } else { match state.port_string.parse::() { Ok(port) => Some(port), Err(error) => { return Err(Error::IllegalPortNumber(error)); }, } }; Ok((state.host, port)) }, } } fn new(host_port_string: &str) -> (Self, &str) { let mut shared = Shared { host: Vec::::new(), host_is_reg_name: false, ipv6_address: String::new(), pec_decoder: PercentEncodedCharacterDecoder::new(), port_string: String::new(), }; let mut host_port_string = host_port_string; if host_port_string.starts_with("[v") { host_port_string = &host_port_string[2..]; shared.host.push(b'v'); (Self::IpvFutureNumber(shared), host_port_string) } else if host_port_string.starts_with('[') { host_port_string = &host_port_string[1..]; (Self::Ipv6Address(shared), host_port_string) } else { shared.host_is_reg_name = true; (Self::NotIpLiteral(shared), host_port_string) } } fn next( self, c: char, ) -> Result { match self { Self::NotIpLiteral(state) => Self::next_not_ip_literal(state, c), Self::PercentEncodedCharacter(state) => { Self::next_percent_encoded_character(state, c) }, Self::Ipv6Address(state) => Self::next_ipv6_address(state, c), Self::IpvFutureNumber(state) => { Self::next_ipv_future_number(state, c) }, Self::IpvFutureBody(state) => Self::next_ipv_future_body(state, c), Self::GarbageCheck(state) => Self::next_garbage_check(state, c), Self::Port(state) => Self::next_port(state, c), } } fn next_not_ip_literal( state: Shared, c: char, ) -> Result { let mut state = state; if c == '%' { Ok(Self::PercentEncodedCharacter(state)) } else if c == ':' { Ok(Self::Port(state)) } else if REG_NAME_NOT_PCT_ENCODED.contains(&c) { state.host.push(u8::try_from(c as u32).unwrap()); Ok(Self::NotIpLiteral(state)) } else { Err(Error::IllegalCharacter(Context::Host)) } } fn next_percent_encoded_character( state: Shared, c: char, ) -> Result { let mut state = state; // We can't use `Option::map_or` (or `Option::map_or_else`, for similar // reasons) in this case because the closure would take ownership of // `state`, preventing it from being used to construct the default // value. #[allow(clippy::option_if_let_else)] if let Some(ci) = state.pec_decoder.next(c)? { state.host.push(ci); Ok(Self::NotIpLiteral(state)) } else { Ok(Self::PercentEncodedCharacter(state)) } } fn next_ipv6_address( state: Shared, c: char, ) -> Result { let mut state = state; if c == ']' { validate_ipv6_address(&state.ipv6_address)?; state.host = state .ipv6_address .chars() .map(|c| u8::try_from(c as u32).unwrap()) .collect(); Ok(Self::GarbageCheck(state)) } else { state.ipv6_address.push(c); Ok(Self::Ipv6Address(state)) } } fn next_ipv_future_number( state: Shared, c: char, ) -> Result { let mut state = state; if c == '.' { state.host.push(b'.'); Ok(Self::IpvFutureBody(state)) } else if c == ']' { Err(Error::TruncatedHost) } else if HEXDIG.contains(&c) { state.host.push(u8::try_from(c as u32).unwrap()); Ok(Self::IpvFutureNumber(state)) } else { Err(Error::IllegalCharacter(Context::IpvFuture)) } } fn next_ipv_future_body( state: Shared, c: char, ) -> Result { let mut state = state; if c == ']' { Ok(Self::GarbageCheck(state)) } else if IPV_FUTURE_LAST_PART.contains(&c) { state.host.push(u8::try_from(c as u32).unwrap()); Ok(Self::IpvFutureBody(state)) } else { Err(Error::IllegalCharacter(Context::IpvFuture)) } } fn next_garbage_check( state: Shared, c: char, ) -> Result { // illegal to have anything else, unless it's a colon, // in which case it's a port delimiter if c == ':' { Ok(Self::Port(state)) } else { Err(Error::IllegalCharacter(Context::Host)) } } fn next_port( state: Shared, c: char, ) -> Result { let mut state = state; state.port_string.push(c); Ok(Self::Port(state)) } } pub fn parse_host_port( host_port_string: T ) -> Result<(Vec, Option), Error> where T: AsRef, { let (machine, host_port_string) = State::new(host_port_string.as_ref()); host_port_string .chars() .try_fold(machine, |machine, c| machine.next(c))? .finalize() } #[cfg(test)] mod tests { use super::*; #[test] fn non_empty_port_number() { let result = parse_host_port("www.example.com:8080"); assert!(result.is_ok()); let (host, port) = result.unwrap(); assert_eq!(b"www.example.com", &host[..]); assert_eq!(Some(8080), port); } #[test] fn empty_port_number() { let result = parse_host_port("www.example.com:"); assert!(result.is_ok()); let (host, port) = result.unwrap(); assert_eq!(b"www.example.com", &host[..]); assert_eq!(None, port); } #[test] fn no_port_number() { let result = parse_host_port("www.example.com"); assert!(result.is_ok()); let (host, port) = result.unwrap(); assert_eq!(b"www.example.com", &host[..]); assert_eq!(None, port); } #[test] fn bad_port_number_purly_alphabetic() { let result = parse_host_port("www.example.com:spam"); assert!(result.is_err()); } #[test] fn bad_port_number_starts_numeric_ends_alphabetic() { let result = parse_host_port("www.example.com:8080spam"); assert!(result.is_err()); } #[test] fn largest_valid_port_number() { let result = parse_host_port("www.example.com:65535"); assert!(result.is_ok()); let (_, port) = result.unwrap(); assert_eq!(Some(65535), port); } #[test] fn bad_port_number_too_big() { let result = parse_host_port("www.example.com:65536"); assert!(matches!(result, Err(Error::IllegalPortNumber(_)))); } #[test] fn bad_port_number_negative() { let result = parse_host_port("www.example.com:-1234"); assert!(result.is_err()); } #[test] fn truncated_host() { let test_vectors = ["[::ffff:1.2.3.4/", "[:]/", "[v]/"]; for test_vector in &test_vectors { assert_eq!( Err(Error::TruncatedHost), parse_host_port(test_vector), "{}", test_vector ); } } #[test] fn ipv6_address_with_ipv4_part_missing_bracket() { assert!(matches!( parse_host_port("::ffff:1.2.3.4]"), Err(Error::IllegalPortNumber(_)) )); } }