diff options
author | Richard Walters <rwalters@digitalstirling.com> | 2020-10-13 01:09:18 -0700 |
---|---|---|
committer | Richard Walters <rwalters@digitalstirling.com> | 2020-10-13 01:09:18 -0700 |
commit | dc2a011598f4aa9e9de927333e467e623276d5ec (patch) | |
tree | 4b5c71634af516cdc96c512f28a02370d48c25b3 /src | |
parent | 4accf8c296ef7a1f6bd10a90b7a06b3b499ccda6 (diff) |
Rust refactoring
* Move Context, Error, and character classes to their own modules.
* Move host/port parsing and IP address validation to their
own modules, and break the code up into different functions
to process their state machines.
Diffstat (limited to 'src')
-rw-r--r-- | src/character_classes.rs | 131 | ||||
-rw-r--r-- | src/context.rs | 48 | ||||
-rw-r--r-- | src/error.rs | 40 | ||||
-rw-r--r-- | src/lib.rs | 645 | ||||
-rw-r--r-- | src/parse_host_port.rs | 202 | ||||
-rw-r--r-- | src/percent_encoded_character_decoder.rs | 8 | ||||
-rw-r--r-- | src/validate_ipv4_address.rs | 96 | ||||
-rw-r--r-- | src/validate_ipv6_address.rs | 212 |
8 files changed, 794 insertions, 588 deletions
diff --git a/src/character_classes.rs b/src/character_classes.rs new file mode 100644 index 0000000..4b13f01 --- /dev/null +++ b/src/character_classes.rs @@ -0,0 +1,131 @@ +#![warn(clippy::pedantic)] + +use once_cell::sync::Lazy; +use std::collections::HashSet; + +// This is the character set containing just the alphabetic characters +// from the ASCII character set. +pub static ALPHA: Lazy<HashSet<char>> = Lazy::new(|| + ('a'..='z') + .chain('A'..='Z') + .collect() +); + +// This is the character set containing just numbers. +pub static DIGIT: Lazy<HashSet<char>> = Lazy::new(|| + ('0'..='9') + .collect() +); + +// This is the character set containing just the characters allowed +// in a hexadecimal digit. +pub static HEXDIG: Lazy<HashSet<char>> = Lazy::new(|| + ('0'..='9') + .chain('A'..='F') + .chain('a'..='f') + .collect() +); + +// This is the character set corresponds to the "unreserved" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986). +pub static UNRESERVED: Lazy<HashSet<char>> = Lazy::new(|| + ALPHA.iter() + .chain(DIGIT.iter()) + .chain(['-', '.', '_', '~'].iter()) + .copied() + .collect() +); + +// This is the character set corresponds to the "sub-delims" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986). +pub static SUB_DELIMS: Lazy<HashSet<char>> = Lazy::new(|| + [ + '!', '$', '&', '\'', '(', ')', + '*', '+', ',', ';', '=' + ] + .iter() + .copied() + .collect() +); + +// This is the character set corresponds to the second part +// of the "scheme" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986). +pub static SCHEME_NOT_FIRST: Lazy<HashSet<char>> = Lazy::new(|| + ALPHA.iter() + .chain(DIGIT.iter()) + .chain(['+', '-', '.'].iter()) + .copied() + .collect() +); + +// This is the character set corresponds to the "pchar" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), +// leaving out "pct-encoded". +pub static PCHAR_NOT_PCT_ENCODED: Lazy<HashSet<char>> = Lazy::new(|| + UNRESERVED.iter() + .chain(SUB_DELIMS.iter()) + .chain([':', '@'].iter()) + .copied() + .collect() +); + +// This is the character set corresponds to the "query" syntax +// and the "fragment" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), +// leaving out "pct-encoded". +pub static QUERY_OR_FRAGMENT_NOT_PCT_ENCODED: Lazy<HashSet<char>> = Lazy::new(|| + PCHAR_NOT_PCT_ENCODED.iter() + .chain(['/', '?'].iter()) + .copied() + .collect() +); + +// This is the character set almost corresponds to the "query" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), +// leaving out "pct-encoded", except that '+' is also excluded, because +// for some web services (e.g. AWS S3) a '+' is treated as +// synonymous with a space (' ') and thus gets misinterpreted. +pub static QUERY_NOT_PCT_ENCODED_WITHOUT_PLUS: Lazy<HashSet<char>> = Lazy::new(|| + UNRESERVED.iter() + .chain([ + '!', '$', '&', '\'', '(', ')', + '*', ',', ';', '=', + ':', '@', + '/', '?' + ].iter()) + .copied() + .collect() +); + +// This is the character set corresponds to the "userinfo" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), +// leaving out "pct-encoded". +pub static USER_INFO_NOT_PCT_ENCODED: Lazy<HashSet<char>> = Lazy::new(|| + UNRESERVED.iter() + .chain(SUB_DELIMS.iter()) + .chain([':'].iter()) + .copied() + .collect() +); + +// This is the character set corresponds to the "reg-name" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), +// leaving out "pct-encoded". +pub static REG_NAME_NOT_PCT_ENCODED: Lazy<HashSet<char>> = Lazy::new(|| + UNRESERVED.iter() + .chain(SUB_DELIMS.iter()) + .copied() + .collect() +); + +// This is the character set corresponds to the last part of +// the "IPvFuture" syntax +// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986). +pub static IPV_FUTURE_LAST_PART: Lazy<HashSet<char>> = Lazy::new(|| + UNRESERVED.iter() + .chain(SUB_DELIMS.iter()) + .chain([':'].iter()) + .copied() + .collect() +); diff --git a/src/context.rs b/src/context.rs new file mode 100644 index 0000000..ef8c7cb --- /dev/null +++ b/src/context.rs @@ -0,0 +1,48 @@ +#![warn(clippy::pedantic)] + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Context { + Fragment, + Host, + Ipv4Address, + Ipv6Address, + IpvFuture, + Path, + Query, + Scheme, + Userinfo, +} + +impl std::fmt::Display for Context { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Context::Fragment => { + write!(f, "fragment") + }, + Context::Host => { + write!(f, "host") + }, + Context::Ipv4Address => { + write!(f, "IPv4 address") + }, + Context::Ipv6Address => { + write!(f, "IPv6 address") + }, + Context::IpvFuture => { + write!(f, "IPvFuture") + }, + Context::Path => { + write!(f, "path") + }, + Context::Query => { + write!(f, "query") + }, + Context::Scheme => { + write!(f, "scheme") + }, + Context::Userinfo => { + write!(f, "user info") + }, + } + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..e37a02f --- /dev/null +++ b/src/error.rs @@ -0,0 +1,40 @@ +#![warn(clippy::pedantic)] + +use super::context::Context; + +#[derive(Debug, Clone, thiserror::Error, PartialEq)] +pub enum Error { + #[error("URI contains non-UTF8 sequences")] + CannotExpressAsUtf8(#[from] std::string::FromUtf8Error), + + #[error("scheme expected but missing")] + EmptyScheme, + + #[error("illegal character in {0}")] + IllegalCharacter(Context), + + #[error("illegal percent encoding")] + IllegalPercentEncoding, + + #[error("illegal port number")] + IllegalPortNumber(#[source] std::num::ParseIntError), + + #[error("octet group expected")] + InvalidDecimalOctet, + + #[error("too few address parts")] + TooFewAddressParts, + + #[error("too many address parts")] + TooManyAddressParts, + + #[error("too many digits in IPv6 address part")] + TooManyDigits, + + #[error("too many double-colons in IPv6 address")] + TooManyDoubleColons, + + #[error("truncated host")] + TruncatedHost, +} + @@ -6,222 +6,32 @@ #[macro_use] extern crate named_tuple; -mod percent_encoded_character_decoder; -use percent_encoded_character_decoder::PercentEncodedCharacterDecoder; - use std::collections::HashSet; use std::convert::TryFrom; -use once_cell::sync::Lazy; - -// This is the character set containing just the alphabetic characters -// from the ASCII character set. -static ALPHA: Lazy<HashSet<char>> = Lazy::new(|| - ('a'..='z') - .chain('A'..='Z') - .collect() -); - -// This is the character set containing just numbers. -static DIGIT: Lazy<HashSet<char>> = Lazy::new(|| - ('0'..='9') - .collect() -); - -// This is the character set containing just the characters allowed -// in a hexadecimal digit. -static HEXDIG: Lazy<HashSet<char>> = Lazy::new(|| - ('0'..='9') - .chain('A'..='F') - .chain('a'..='f') - .collect() -); - -// This is the character set corresponds to the "unreserved" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986). -static UNRESERVED: Lazy<HashSet<char>> = Lazy::new(|| - ALPHA.iter() - .chain(DIGIT.iter()) - .chain(['-', '.', '_', '~'].iter()) - .copied() - .collect() -); - -// This is the character set corresponds to the "sub-delims" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986). -static SUB_DELIMS: Lazy<HashSet<char>> = Lazy::new(|| - [ - '!', '$', '&', '\'', '(', ')', - '*', '+', ',', ';', '=' - ] - .iter() - .copied() - .collect() -); - -// This is the character set corresponds to the second part -// of the "scheme" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986). -static SCHEME_NOT_FIRST: Lazy<HashSet<char>> = Lazy::new(|| - ALPHA.iter() - .chain(DIGIT.iter()) - .chain(['+', '-', '.'].iter()) - .copied() - .collect() -); - -// This is the character set corresponds to the "pchar" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), -// leaving out "pct-encoded". -static PCHAR_NOT_PCT_ENCODED: Lazy<HashSet<char>> = Lazy::new(|| - UNRESERVED.iter() - .chain(SUB_DELIMS.iter()) - .chain([':', '@'].iter()) - .copied() - .collect() -); - -// This is the character set corresponds to the "query" syntax -// and the "fragment" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), -// leaving out "pct-encoded". -static QUERY_OR_FRAGMENT_NOT_PCT_ENCODED: Lazy<HashSet<char>> = Lazy::new(|| - PCHAR_NOT_PCT_ENCODED.iter() - .chain(['/', '?'].iter()) - .copied() - .collect() -); - -// This is the character set almost corresponds to the "query" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), -// leaving out "pct-encoded", except that '+' is also excluded, because -// for some web services (e.g. AWS S3) a '+' is treated as -// synonymous with a space (' ') and thus gets misinterpreted. -static QUERY_NOT_PCT_ENCODED_WITHOUT_PLUS: Lazy<HashSet<char>> = Lazy::new(|| - UNRESERVED.iter() - .chain([ - '!', '$', '&', '\'', '(', ')', - '*', ',', ';', '=', - ':', '@', - '/', '?' - ].iter()) - .copied() - .collect() -); - -// This is the character set corresponds to the "userinfo" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), -// leaving out "pct-encoded". -static USER_INFO_NOT_PCT_ENCODED: Lazy<HashSet<char>> = Lazy::new(|| - UNRESERVED.iter() - .chain(SUB_DELIMS.iter()) - .chain([':'].iter()) - .copied() - .collect() -); - -// This is the character set corresponds to the "reg-name" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986), -// leaving out "pct-encoded". -static REG_NAME_NOT_PCT_ENCODED: Lazy<HashSet<char>> = Lazy::new(|| - UNRESERVED.iter() - .chain(SUB_DELIMS.iter()) - .copied() - .collect() -); - -// This is the character set corresponds to the last part of -// the "IPvFuture" syntax -// specified in RFC 3986 (https://tools.ietf.org/html/rfc3986). -static IPV_FUTURE_LAST_PART: Lazy<HashSet<char>> = Lazy::new(|| - UNRESERVED.iter() - .chain(SUB_DELIMS.iter()) - .chain([':'].iter()) - .copied() - .collect() -); - -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum Context { - Fragment, - Host, - Ipv4Address, - Ipv6Address, - IpvFuture, - Path, - Query, - Scheme, - Userinfo, -} - -impl std::fmt::Display for Context { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Context::Fragment => { - write!(f, "fragment") - }, - Context::Host => { - write!(f, "host") - }, - Context::Ipv4Address => { - write!(f, "IPv4 address") - }, - Context::Ipv6Address => { - write!(f, "IPv6 address") - }, - Context::IpvFuture => { - write!(f, "IPvFuture") - }, - Context::Path => { - write!(f, "path") - }, - Context::Query => { - write!(f, "query") - }, - Context::Scheme => { - write!(f, "scheme") - }, - Context::Userinfo => { - write!(f, "user info") - }, - } - } -} - -#[derive(Debug, Clone, thiserror::Error, PartialEq)] -pub enum Error { - #[error("URI contains non-UTF8 sequences")] - CannotExpressAsUtf8(#[from] std::string::FromUtf8Error), - - #[error("scheme expected but missing")] - EmptyScheme, - - #[error("illegal character in {0}")] - IllegalCharacter(Context), - #[error("illegal percent encoding")] - IllegalPercentEncoding(#[from] percent_encoded_character_decoder::Error), - - #[error("illegal port number")] - IllegalPortNumber(#[source] std::num::ParseIntError), - - #[error("octet group expected")] - InvalidDecimalOctet, - - #[error("too few address parts")] - TooFewAddressParts, - - #[error("too many address parts")] - TooManyAddressParts, - - #[error("too many digits in IPv6 address part")] - TooManyDigits, - - #[error("too many double-colons in IPv6 address")] - TooManyDoubleColons, +mod context; +mod error; +mod parse_host_port; +mod percent_encoded_character_decoder; +mod validate_ipv4_address; +mod validate_ipv6_address; - #[error("truncated host")] - TruncatedHost, -} +use context::Context; +use error::Error; +use parse_host_port::parse_host_port; +use percent_encoded_character_decoder::PercentEncodedCharacterDecoder; +use validate_ipv6_address::validate_ipv6_address; + +mod character_classes; +use character_classes::{ + ALPHA, + SCHEME_NOT_FIRST, + PCHAR_NOT_PCT_ENCODED, + QUERY_OR_FRAGMENT_NOT_PCT_ENCODED, + QUERY_NOT_PCT_ENCODED_WITHOUT_PLUS, + USER_INFO_NOT_PCT_ENCODED, + REG_NAME_NOT_PCT_ENCODED, +}; fn decode_element<T>( element: T, @@ -271,213 +81,6 @@ fn encode_element( .collect::<String>() } -fn validate_ipv4_address<T>(address: T) -> Result<(), Error> - where T: AsRef<str> -{ - #[derive(PartialEq)] - enum State { - NotInOctet, - ExpectDigitOrDot, - } - let mut num_groups = 0; - let mut state = State::NotInOctet; - let mut octet_buffer = String::new(); - for c in address.as_ref().chars() { - state = match state { - State::NotInOctet if DIGIT.contains(&c) => { - octet_buffer.push(c); - State::ExpectDigitOrDot - }, - - State::NotInOctet => { - return Err(Error::IllegalCharacter(Context::Ipv4Address)); - }, - - State::ExpectDigitOrDot if c == '.' => { - num_groups += 1; - if num_groups > 4 { - return Err(Error::TooManyAddressParts); - } - if octet_buffer.parse::<u8>().is_err() { - return Err(Error::InvalidDecimalOctet); - } - octet_buffer.clear(); - State::NotInOctet - }, - - State::ExpectDigitOrDot if DIGIT.contains(&c) => { - octet_buffer.push(c); - State::ExpectDigitOrDot - }, - - State::ExpectDigitOrDot => { - return Err(Error::IllegalCharacter(Context::Ipv4Address)); - }, - }; - } - if state == State::NotInOctet { - return Err(Error::TruncatedHost); - } - if !octet_buffer.is_empty() { - num_groups += 1; - if octet_buffer.parse::<u8>().is_err() { - return Err(Error::InvalidDecimalOctet); - } - } - match num_groups { - 4 => Ok(()), - n if n < 4 => Err(Error::TooFewAddressParts), - _ => Err(Error::TooManyAddressParts), - } -} - -// TODO: Clippy correctly advises us that this function needs refactoring -// because it has too many lines. We'll get back to that. -#[allow(clippy::too_many_lines)] -fn validate_ipv6_address<T>(address: T) -> Result<(), Error> - where T: AsRef<str> -{ - #[derive(PartialEq)] - enum ValidationState { - NoGroupsYet, - ColonButNoGroupsYet, - AfterDoubleColon, - InGroupNotIpv4, - InGroupCouldBeIpv4, - ColonAfterGroup, - } - let mut state = ValidationState::NoGroupsYet; - let mut num_groups = 0; - let mut num_digits = 0; - let mut double_colon_encountered = false; - let mut potential_ipv4_address_start = 0; - let mut ipv4_address_encountered = false; - let address = address.as_ref(); - for (i, c) in address.char_indices() { - state = match state { - ValidationState::NoGroupsYet => { - if c == ':' { - ValidationState::ColonButNoGroupsYet - } else if DIGIT.contains(&c) { - potential_ipv4_address_start = i; - num_digits = 1; - ValidationState::InGroupCouldBeIpv4 - } else if HEXDIG.contains(&c) { - num_digits = 1; - ValidationState::InGroupNotIpv4 - } else { - return Err(Error::IllegalCharacter(Context::Ipv6Address)); - } - }, - - ValidationState::ColonButNoGroupsYet => { - if c != ':' { - return Err(Error::IllegalCharacter(Context::Ipv6Address)); - } - double_colon_encountered = true; - ValidationState::AfterDoubleColon - }, - - ValidationState::AfterDoubleColon => { - num_digits += 1; - if num_digits > 4 { - return Err(Error::TooManyDigits); - } - if DIGIT.contains(&c) { - potential_ipv4_address_start = i; - ValidationState::InGroupCouldBeIpv4 - } else if HEXDIG.contains(&c) { - ValidationState::InGroupNotIpv4 - } else { - return Err(Error::IllegalCharacter(Context::Ipv6Address)); - } - }, - - ValidationState::InGroupNotIpv4 => { - if c == ':' { - num_digits = 0; - num_groups += 1; - ValidationState::ColonAfterGroup - } else if HEXDIG.contains(&c) { - num_digits += 1; - if num_digits > 4 { - return Err(Error::TooManyDigits); - } - ValidationState::InGroupNotIpv4 - } else { - return Err(Error::IllegalCharacter(Context::Ipv6Address)); - } - }, - - ValidationState::InGroupCouldBeIpv4 => { - if c == ':' { - num_digits = 0; - num_groups += 1; - ValidationState::ColonAfterGroup - } else if c == '.' { - ipv4_address_encountered = true; - break; - } else { - num_digits += 1; - if num_digits > 4 { - return Err(Error::TooManyDigits); - } - if DIGIT.contains(&c) { - ValidationState::InGroupCouldBeIpv4 - } else if HEXDIG.contains(&c) { - ValidationState::InGroupNotIpv4 - } else { - return Err(Error::IllegalCharacter(Context::Ipv6Address)); - } - } - }, - - ValidationState::ColonAfterGroup => { - if c == ':' { - if double_colon_encountered { - return Err(Error::TooManyDoubleColons); - } else { - double_colon_encountered = true; - ValidationState::AfterDoubleColon - } - } else if DIGIT.contains(&c) { - potential_ipv4_address_start = i; - num_digits += 1; - ValidationState::InGroupCouldBeIpv4 - } else if HEXDIG.contains(&c) { - num_digits += 1; - ValidationState::InGroupNotIpv4 - } else { - return Err(Error::IllegalCharacter(Context::Ipv6Address)); - } - }, - }; - } - if - (state == ValidationState::InGroupNotIpv4) - || (state == ValidationState::InGroupCouldBeIpv4) - { - // count trailing group - num_groups += 1; - } - if - (state == ValidationState::ColonButNoGroupsYet) - || (state == ValidationState::ColonAfterGroup) - { // trailing single colon - return Err(Error::TruncatedHost); - } - if ipv4_address_encountered { - validate_ipv4_address(&address[potential_ipv4_address_start..])?; - num_groups += 2; - } - match (double_colon_encountered, num_groups) { - (true, n) if n <= 7 => Ok(()), - (false, 8) => Ok(()), - (false, n) if n < 8 => Err(Error::TooFewAddressParts), - (_, _) => Err(Error::TooManyAddressParts), - } -} - #[derive(Clone, Debug, Default, PartialEq)] pub struct Authority { userinfo: Option<Vec<u8>>, @@ -516,6 +119,43 @@ impl Authority { pub fn userinfo(&self) -> Option<&[u8]> { self.userinfo.as_deref() } + + #[must_use = "you parsed it; don't you want the results?"] + pub fn parse<T>(authority_string: T) -> Result<Self, Error> + where T: AsRef<str> + { + // First, check if there is a UserInfo, and if so, extract it. + let (userinfo, host_port_string) = Self::parse_userinfo(authority_string.as_ref())?; + + // Next, parsing host and port from authority and path. + let (host, port) = parse_host_port(host_port_string)?; + + // Assemble authority from its parts. + Ok(Self{ + userinfo, + host, + port, + }) + } + + fn parse_userinfo(authority: &str) -> Result<(Option<Vec<u8>>, &str), Error> { + Ok(match authority.find('@') { + Some(delimiter) => ( + Some( + decode_element( + &authority[0..delimiter], + &USER_INFO_NOT_PCT_ENCODED, + Context::Userinfo + )? + ), + &authority[delimiter+1..] + ), + None => ( + None, + authority + ) + }) + } } impl std::fmt::Display for Authority { @@ -527,7 +167,7 @@ impl std::fmt::Display for Authority { match host_as_string { Ok(host_as_string) if validate_ipv6_address(&host_as_string).is_ok() => { write!(f, "[{}]", host_as_string.to_ascii_lowercase())?; - }, + }, _ => { write!(f, "{}", encode_element(&self.host, ®_NAME_NOT_PCT_ENCODED))?; } @@ -717,7 +357,7 @@ impl Uri { normalized_path } - pub fn parse<T>(uri_string: T) -> Result<Uri, Error> + pub fn parse<T>(uri_string: T) -> Result<Self, Error> where T: AsRef<str> { let (scheme, rest) = Self::parse_scheme(uri_string.as_ref())?; @@ -729,7 +369,7 @@ impl Uri { let (authority, path) = Self::split_authority_from_path_and_parse_them(authority_and_path_string)?; let (fragment, possible_query) = Self::parse_fragment(query_and_or_fragment)?; let query = Self::parse_query(possible_query)?; - Ok(Uri{ + Ok(Self{ scheme, authority, path, @@ -738,165 +378,6 @@ impl Uri { }) } - // TODO: Needs refactoring, as Clippy dutifully told us. - #[allow(clippy::too_many_lines)] - fn parse_authority<T>(authority_string: T) -> Result<Authority, Error> - where T: AsRef<str> - { - // These are the various states for the state machine implemented - // below to correctly split up and validate the URI substring - // containing the host and potentially a port number as well. - #[derive(PartialEq)] - enum HostParsingState { - NotIpLiteral, - PercentEncodedCharacter, - Ipv6Address, - IpvFutureNumber, - IpvFutureBody, - GarbageCheck, - Port, - }; - - // First, check if there is a UserInfo, and if so, extract it. - let authority_string = authority_string.as_ref(); - let (userinfo, mut host_port_string) = match authority_string.find('@') { - Some(user_info_delimiter) => ( - Some( - decode_element( - &authority_string[0..user_info_delimiter], - &USER_INFO_NOT_PCT_ENCODED, - Context::Userinfo - )? - ), - &authority_string[user_info_delimiter+1..] - ), - None => ( - None, - authority_string - ) - }; - - // Next, parsing host and port from authority and path. - let mut port_string = String::new(); - let mut host = Vec::<u8>::new(); - let (mut host_parsing_state, host_is_reg_name) = if host_port_string.starts_with("[v") { - host_port_string = &host_port_string[2..]; - host.push(b'v'); - (HostParsingState::IpvFutureNumber, false) - } else if host_port_string.starts_with('[') { - host_port_string = &host_port_string[1..]; - (HostParsingState::Ipv6Address, false) - } else { - (HostParsingState::NotIpLiteral, true) - }; - let mut ipv6_address = String::new(); - let mut pec_decoder = PercentEncodedCharacterDecoder::new(); - for c in host_port_string.chars() { - host_parsing_state = match host_parsing_state { - HostParsingState::NotIpLiteral => { - if c == '%' { - HostParsingState::PercentEncodedCharacter - } else if c == ':' { - HostParsingState::Port - } else if REG_NAME_NOT_PCT_ENCODED.contains(&c) { - host.push(u8::try_from(c as u32).unwrap()); - host_parsing_state - } else { - return Err(Error::IllegalCharacter(Context::Host)); - } - }, - - HostParsingState::PercentEncodedCharacter => { - if let Some(ci) = pec_decoder.next(c)? { - host.push(ci); - HostParsingState::NotIpLiteral - } else { - host_parsing_state - } - }, - - HostParsingState::Ipv6Address => { - if c == ']' { - validate_ipv6_address(&ipv6_address)?; - host = ipv6_address.chars().map( - |c| u8::try_from(c as u32).unwrap() - ).collect(); - HostParsingState::GarbageCheck - } else { - ipv6_address.push(c); - host_parsing_state - } - }, - - HostParsingState::IpvFutureNumber => { - if c == '.' { - host_parsing_state = HostParsingState::IpvFutureBody - } else if c == ']' { - return Err(Error::TruncatedHost); - } else if !HEXDIG.contains(&c) { - return Err(Error::IllegalCharacter(Context::IpvFuture)); - } - host.push(u8::try_from(c as u32).unwrap()); - host_parsing_state - }, - - HostParsingState::IpvFutureBody => { - if c == ']' { - HostParsingState::GarbageCheck - } else if IPV_FUTURE_LAST_PART.contains(&c) { - host.push(u8::try_from(c as u32).unwrap()); - host_parsing_state - } else { - return Err(Error::IllegalCharacter(Context::IpvFuture)); - } - }, - - HostParsingState::GarbageCheck => { - // illegal to have anything else, unless it's a colon, - // in which case it's a port delimiter - if c == ':' { - HostParsingState::Port - } else { - return Err(Error::IllegalCharacter(Context::Host)); - } - }, - - HostParsingState::Port => { - port_string.push(c); - host_parsing_state - }, - } - } - if - (host_parsing_state != HostParsingState::NotIpLiteral) - && (host_parsing_state != HostParsingState::GarbageCheck) - && (host_parsing_state != HostParsingState::Port) - { - // truncated or ended early - return Err(Error::TruncatedHost); - } - if host_is_reg_name { - host.make_ascii_lowercase(); - } - let port = if port_string.is_empty() { - None - } else { - match port_string.parse::<u16>() { - Ok(port) => { - Some(port) - }, - Err(error) => { - return Err(Error::IllegalPortNumber(error)); - } - } - }; - Ok(Authority{ - userinfo, - host, - port, - }) - } - fn parse_fragment(query_and_or_fragment: &str) -> Result<(Option<Vec<u8>>, &str), Error> { if let Some(fragment_delimiter) = query_and_or_fragment.find('#') { let fragment = Self::decode_query_or_fragment( @@ -1176,7 +657,7 @@ impl Uri { let path_string = &authority_and_path_string[authority_end..]; // Parse the elements inside the authority string. - let authority = Self::parse_authority(authority_string)?; + let authority = Authority::parse(authority_string)?; let path = if path_string.is_empty() { vec![vec![]] } else { diff --git a/src/parse_host_port.rs b/src/parse_host_port.rs new file mode 100644 index 0000000..2dcdde9 --- /dev/null +++ b/src/parse_host_port.rs @@ -0,0 +1,202 @@ +#![warn(clippy::pedantic)] + +use std::convert::TryFrom; + +use super::character_classes::{ + HEXDIG, + IPV_FUTURE_LAST_PART, + REG_NAME_NOT_PCT_ENCODED, +}; +use super::context::Context; +use super::error::Error; +use super::percent_encoded_character_decoder::PercentEncodedCharacterDecoder; +use super::validate_ipv6_address::validate_ipv6_address; + +struct Shared { + host: Vec<u8>, + host_is_reg_name: bool, + ipv6_address: String, + pec_decoder: PercentEncodedCharacterDecoder, + port_string: String, +} + +enum State { + NotIpLiteral(Shared), + PercentEncodedCharacter(Shared), + Ipv6Address(Shared), + IpvFutureNumber(Shared), + IpvFutureBody(Shared), + GarbageCheck(Shared), + Port(Shared), +} + +impl State{ + fn finalize(self) -> Result<(Vec<u8>, Option<u16>), Error> { + match self { + Self::PercentEncodedCharacter(_) + | Self::Ipv6Address(_) + | Self::IpvFutureNumber(_) + | Self::IpvFutureBody(_) => { + // truncated or ended early + Err(Error::TruncatedHost) + }, + Self::NotIpLiteral(state) + | Self::GarbageCheck(state) + | Self::Port(state) => { + let mut state = state; + if state.host_is_reg_name { + state.host.make_ascii_lowercase(); + } + let port = if state.port_string.is_empty() { + None + } else { + match state.port_string.parse::<u16>() { + Ok(port) => { + Some(port) + }, + Err(error) => { + return Err(Error::IllegalPortNumber(error)); + } + } + }; + Ok((state.host, port)) + }, + } + } + + fn new(host_port_string: &str) -> (Self, &str) { + let mut shared = Shared{ + host: Vec::<u8>::new(), + host_is_reg_name: false, + ipv6_address: String::new(), + pec_decoder: PercentEncodedCharacterDecoder::new(), + port_string: String::new(), + }; + let mut host_port_string = host_port_string; + if host_port_string.starts_with("[v") { + host_port_string = &host_port_string[2..]; + shared.host.push(b'v'); + ( + Self::IpvFutureNumber(shared), + host_port_string + ) + } else if host_port_string.starts_with('[') { + host_port_string = &host_port_string[1..]; + ( + Self::Ipv6Address(shared), + host_port_string + ) + } else { + shared.host_is_reg_name = true; + ( + Self::NotIpLiteral(shared), + host_port_string + ) + } + } + + fn next(self, c: char) -> Result<Self, Error> { + match self { + Self::NotIpLiteral(state) => Self::next_not_ip_literal(state, c), + Self::PercentEncodedCharacter(state) => Self::next_percent_encoded_character(state, c), + Self::Ipv6Address(state) => Self::next_ipv6_address(state, c), + Self::IpvFutureNumber(state) => Self::next_ipv_future_number(state, c), + Self::IpvFutureBody(state) => Self::next_ipv_future_body(state, c), + Self::GarbageCheck(state) => Self::next_garbage_check(state, c), + Self::Port(state) => Self::next_port(state, c), + } + } + + fn next_not_ip_literal(state: Shared, c: char) -> Result<Self, Error> { + let mut state = state; + if c == '%' { + Ok(Self::PercentEncodedCharacter(state)) + } else if c == ':' { + Ok(Self::Port(state)) + } else if REG_NAME_NOT_PCT_ENCODED.contains(&c) { + state.host.push(u8::try_from(c as u32).unwrap()); + Ok(Self::NotIpLiteral(state)) + } else { + Err(Error::IllegalCharacter(Context::Host)) + } + } + + fn next_percent_encoded_character(state: Shared, c: char) -> Result<Self, Error> { + let mut state = state; + if let Some(ci) = state.pec_decoder.next(c)? { + state.host.push(ci); + Ok(Self::NotIpLiteral(state)) + } else { + Ok(Self::PercentEncodedCharacter(state)) + } + } + + fn next_ipv6_address(state: Shared, c: char) -> Result<Self, Error> { + let mut state = state; + if c == ']' { + validate_ipv6_address(&state.ipv6_address)?; + state.host = state.ipv6_address.chars().map( + |c| u8::try_from(c as u32).unwrap() + ).collect(); + Ok(Self::GarbageCheck(state)) + } else { + state.ipv6_address.push(c); + Ok(Self::Ipv6Address(state)) + } + } + + fn next_ipv_future_number(state: Shared, c: char) -> Result<Self, Error> { + let mut state = state; + if c == '.' { + state.host.push(b'.'); + Ok(Self::IpvFutureBody(state)) + } else if c == ']' { + Err(Error::TruncatedHost) + } else if HEXDIG.contains(&c) { + state.host.push(u8::try_from(c as u32).unwrap()); + Ok(Self::IpvFutureNumber(state)) + } else { + Err(Error::IllegalCharacter(Context::IpvFuture)) + } + } + + fn next_ipv_future_body(state: Shared, c: char) -> Result<Self, Error> { + let mut state = state; + if c == ']' { + Ok(Self::GarbageCheck(state)) + } else if IPV_FUTURE_LAST_PART.contains(&c) { + state.host.push(u8::try_from(c as u32).unwrap()); + Ok(Self::IpvFutureBody(state)) + } else { + Err(Error::IllegalCharacter(Context::IpvFuture)) + } + } + + fn next_garbage_check(state: Shared, c: char) -> Result<Self, Error> { + // illegal to have anything else, unless it's a colon, + // in which case it's a port delimiter + if c == ':' { + Ok(Self::Port(state)) + } else { + Err(Error::IllegalCharacter(Context::Host)) + } + } + + fn next_port(state: Shared, c: char) -> Result<Self, Error> { + let mut state = state; + state.port_string.push(c); + Ok(Self::Port(state)) + } +} + +pub fn parse_host_port<T>(host_port_string: T) -> Result<(Vec<u8>, Option<u16>), Error> + where T: AsRef<str> +{ + let (machine, host_port_string) = State::new(host_port_string.as_ref()); + host_port_string + .chars() + .try_fold(machine, |machine, c| { + machine.next(c) + })? + .finalize() +} diff --git a/src/percent_encoded_character_decoder.rs b/src/percent_encoded_character_decoder.rs index b07be50..8a11c8d 100644 --- a/src/percent_encoded_character_decoder.rs +++ b/src/percent_encoded_character_decoder.rs @@ -2,11 +2,7 @@ use std::convert::TryFrom; -#[derive(Debug, Clone, thiserror::Error, PartialEq)] -pub enum Error { - #[error("illegal character")] - IllegalCharacter, -} +use super::error::Error; pub struct PercentEncodedCharacterDecoder { decoded_character: u8, @@ -50,7 +46,7 @@ impl PercentEncodedCharacterDecoder { self.decoded_character += u8::try_from(ci).unwrap(); } else { self.reset(); - return Err(Error::IllegalCharacter); + return Err(Error::IllegalPercentEncoding); } Ok(()) } diff --git a/src/validate_ipv4_address.rs b/src/validate_ipv4_address.rs new file mode 100644 index 0000000..1b9f7b2 --- /dev/null +++ b/src/validate_ipv4_address.rs @@ -0,0 +1,96 @@ +#![warn(clippy::pedantic)] + +use super::character_classes::{ + DIGIT, +}; +use super::context::Context; +use super::error::Error; + +struct Shared { + num_groups: usize, + octet_buffer: String, +} + +enum State { + NotInOctet(Shared), + ExpectDigitOrDot(Shared), +} + +impl State { + fn finalize(self) -> Result<(), Error> { + match self { + Self::NotInOctet(_) => Err(Error::TruncatedHost), + Self::ExpectDigitOrDot(state) => Self::finalize_expect_digit_or_dot(state), + } + } + + fn finalize_expect_digit_or_dot(state: Shared) -> Result<(), Error> { + let mut state = state; + if !state.octet_buffer.is_empty() { + state.num_groups += 1; + if state.octet_buffer.parse::<u8>().is_err() { + return Err(Error::InvalidDecimalOctet); + } + } + match state.num_groups { + 4 => Ok(()), + n if n < 4 => Err(Error::TooFewAddressParts), + _ => Err(Error::TooManyAddressParts), + } + } + + fn new() -> Self { + Self::NotInOctet(Shared{ + num_groups: 0, + octet_buffer: String::new(), + }) + } + + fn next(self, c: char) -> Result<Self, Error> { + match self { + Self::NotInOctet(state) => Self::next_not_in_octet(state, c), + Self::ExpectDigitOrDot(state) => Self::next_expect_digit_or_dot(state, c), + } + } + + fn next_not_in_octet(state: Shared, c: char) -> Result<Self, Error> { + let mut state = state; + if DIGIT.contains(&c) { + state.octet_buffer.push(c); + Ok(Self::ExpectDigitOrDot(state)) + } else { + Err(Error::IllegalCharacter(Context::Ipv4Address)) + } + } + + fn next_expect_digit_or_dot(state: Shared, c: char)-> Result<Self, Error> { + let mut state = state; + if c == '.' { + state.num_groups += 1; + if state.num_groups > 4 { + return Err(Error::TooManyAddressParts); + } + if state.octet_buffer.parse::<u8>().is_err() { + return Err(Error::InvalidDecimalOctet); + } + state.octet_buffer.clear(); + Ok(Self::NotInOctet(state)) + } else if DIGIT.contains(&c) { + state.octet_buffer.push(c); + Ok(Self::ExpectDigitOrDot(state)) + } else { + Err(Error::IllegalCharacter(Context::Ipv4Address)) + } + } +} + +pub fn validate_ipv4_address<T>(address: T) -> Result<(), Error> + where T: AsRef<str> +{ + address.as_ref() + .chars() + .try_fold(State::new(), |machine, c| { + machine.next(c) + })? + .finalize() +} diff --git a/src/validate_ipv6_address.rs b/src/validate_ipv6_address.rs new file mode 100644 index 0000000..eb3900d --- /dev/null +++ b/src/validate_ipv6_address.rs @@ -0,0 +1,212 @@ +#![warn(clippy::pedantic)] + +use super::character_classes::{ + DIGIT, + HEXDIG, +}; +use super::context::Context; +use super::error::Error; +use super::validate_ipv4_address::validate_ipv4_address; + +enum MachineExitStatus<'a> { + Error(Error), + Ipv4Trailer(Shared<'a>), +} + +impl<'a> From<Error> for MachineExitStatus<'a> { + fn from(error: Error) -> Self { + MachineExitStatus::Error(error) + } +} + +struct Shared<'a> { + address: &'a str, + num_groups: usize, + num_digits: usize, + double_colon_encountered: bool, + potential_ipv4_address_start: usize, +} + +enum State<'a> { + NoGroupsYet(Shared<'a>), + ColonButNoGroupsYet(Shared<'a>), + AfterDoubleColon(Shared<'a>), + InGroupNotIpv4(Shared<'a>), + InGroupCouldBeIpv4(Shared<'a>), + InGroupIpv4(Shared<'a>), + ColonAfterGroup(Shared<'a>), +} + +impl<'a> State<'a> { + fn finalize(mut self) -> Result<(), Error> { + match &mut self { + Self::InGroupNotIpv4(state) + | Self::InGroupCouldBeIpv4(state) => { + // count trailing group + state.num_groups += 1; + }, + Self::InGroupIpv4(state) => { + validate_ipv4_address(&state.address[state.potential_ipv4_address_start..])?; + state.num_groups += 2; + }, + _ => {}, + }; + match self { + Self::ColonButNoGroupsYet(_) + | Self::ColonAfterGroup(_) => Err(Error::TruncatedHost), + + Self::AfterDoubleColon(state) + | Self::InGroupNotIpv4(state) + | Self::InGroupCouldBeIpv4(state) + | Self::InGroupIpv4(state) + | Self::NoGroupsYet(state) => { + match (state.double_colon_encountered, state.num_groups) { + (true, n) if n <= 7 => Ok(()), + (false, 8) => Ok(()), + (false, n) if n < 8 => Err(Error::TooFewAddressParts), + (_, _) => Err(Error::TooManyAddressParts), + } + } + } + } + + fn new(address: &'a str) -> Self { + Self::NoGroupsYet(Shared{ + address, + num_groups: 0, + num_digits: 0, + double_colon_encountered: false, + potential_ipv4_address_start: 0, + }) + } + + fn next(self, i: usize, c: char) -> Result<Self, MachineExitStatus<'a>> { + match self { + Self::NoGroupsYet(state) => Self::next_no_groups_yet(state, i, c), + Self::ColonButNoGroupsYet(state) => Self::next_colon_but_no_groups_yet(state, c), + Self::AfterDoubleColon(state) => Self::next_after_double_colon(state, i, c), + Self::InGroupNotIpv4(state) => Self::next_in_group_not_ipv4(state, c), + Self::InGroupCouldBeIpv4(state) => Self::next_in_group_could_be_ipv4(state, c), + Self::InGroupIpv4(state) => Ok(Self::InGroupIpv4(state)), + Self::ColonAfterGroup(state) => Self::next_colon_after_group(state, i, c), + } + } + + fn next_no_groups_yet(state: Shared<'a>, i: usize, c: char) -> Result<Self, MachineExitStatus> { + let mut state = state; + if c == ':' { + Ok(Self::ColonButNoGroupsYet(state)) + } else if DIGIT.contains(&c) { + state.potential_ipv4_address_start = i; + state.num_digits = 1; + Ok(Self::InGroupCouldBeIpv4(state)) + } else if HEXDIG.contains(&c) { + state.num_digits = 1; + Ok(Self::InGroupNotIpv4(state)) + } else { + Err(Error::IllegalCharacter(Context::Ipv6Address).into()) + } + } + + fn next_colon_but_no_groups_yet(state: Shared<'a>, c: char) -> Result<Self, MachineExitStatus> { + let mut state = state; + if c == ':' { + state.double_colon_encountered = true; + Ok(Self::AfterDoubleColon(state)) + } else { + Err(Error::IllegalCharacter(Context::Ipv6Address).into()) + } + } + + fn next_after_double_colon(state: Shared<'a>, i: usize, c: char) -> Result<Self, MachineExitStatus> { + let mut state = state; + state.num_digits += 1; + if state.num_digits > 4 { + Err(Error::TooManyDigits.into()) + } else if DIGIT.contains(&c) { + state.potential_ipv4_address_start = i; + Ok(Self::InGroupCouldBeIpv4(state)) + } else if HEXDIG.contains(&c) { + Ok(Self::InGroupNotIpv4(state)) + } else { + Err(Error::IllegalCharacter(Context::Ipv6Address).into()) + } + } + + fn next_in_group_not_ipv4(state: Shared<'a>, c: char) -> Result<Self, MachineExitStatus> { + let mut state = state; + if c == ':' { + state.num_digits = 0; + state.num_groups += 1; + Ok(Self::ColonAfterGroup(state)) + } else if HEXDIG.contains(&c) { + state.num_digits += 1; + if state.num_digits > 4 { + Err(Error::TooManyDigits.into()) + } else { + Ok(Self::InGroupNotIpv4(state)) + } + } else { + Err(Error::IllegalCharacter(Context::Ipv6Address).into()) + } + } + + fn next_in_group_could_be_ipv4(state: Shared<'a>, c: char) -> Result<Self, MachineExitStatus> { + let mut state = state; + if c == ':' { + state.num_digits = 0; + state.num_groups += 1; + Ok(Self::ColonAfterGroup(state)) + } else if c == '.' { + Err(MachineExitStatus::Ipv4Trailer(state)) + } else { + state.num_digits += 1; + if state.num_digits > 4 { + Err(Error::TooManyDigits.into()) + } else if DIGIT.contains(&c) { + Ok(Self::InGroupCouldBeIpv4(state)) + } else if HEXDIG.contains(&c) { + Ok(Self::InGroupNotIpv4(state)) + } else { + Err(Error::IllegalCharacter(Context::Ipv6Address).into()) + } + } + } + + fn next_colon_after_group(state: Shared<'a>, i: usize, c: char) -> Result<Self, MachineExitStatus> { + let mut state = state; + if c == ':' { + if state.double_colon_encountered { + Err(Error::TooManyDoubleColons.into()) + } else { + state.double_colon_encountered = true; + Ok(Self::AfterDoubleColon(state)) + } + } else if DIGIT.contains(&c) { + state.potential_ipv4_address_start = i; + state.num_digits += 1; + Ok(Self::InGroupCouldBeIpv4(state)) + } else if HEXDIG.contains(&c) { + state.num_digits += 1; + Ok(Self::InGroupNotIpv4(state)) + } else { + Err(Error::IllegalCharacter(Context::Ipv6Address).into()) + } + } +} + +pub fn validate_ipv6_address<T>(address: T) -> Result<(), Error> + where T: AsRef<str> +{ + let address = address.as_ref(); + address + .char_indices() + .try_fold(State::new(address), |machine, (i, c)| { + machine.next(i, c) + }) + .or_else(|machine_exit_status| match machine_exit_status { + MachineExitStatus::Ipv4Trailer(state) => Ok(State::InGroupIpv4(state)), + MachineExitStatus::Error(error) => Err(error) + })? + .finalize() +} |