diff options
-rw-r--r-- | src/lib.rs | 165 |
1 files changed, 125 insertions, 40 deletions
@@ -154,9 +154,59 @@ lazy_static! { } #[derive(Debug, Clone, PartialEq)] +pub enum Context { + Fragment, + Host, + Ipv4Address, + Ipv6Address, + IpvFuture, + Path, + Query, + Scheme, + Userinfo, +} + +impl std::fmt::Display for Context { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Context::Fragment => { + write!(f, "fragment") + }, + Context::Host => { + write!(f, "host") + }, + Context::Ipv4Address => { + write!(f, "IPv4 address") + }, + Context::Ipv6Address => { + write!(f, "IPv6 address") + }, + Context::IpvFuture => { + write!(f, "IPvFuture") + }, + Context::Path => { + write!(f, "path") + }, + Context::Query => { + write!(f, "query") + }, + Context::Scheme => { + write!(f, "scheme") + }, + Context::Userinfo => { + write!(f, "user info") + }, + } + } +} + +#[derive(Debug, Clone, PartialEq)] pub enum Error { + CannotExpressAsUtf8(std::string::FromUtf8Error), EmptyScheme, - IllegalCharacter, + IllegalCharacter(Context), + IllegalPercentEncoding(percent_encoded_character_decoder::Error), + IllegalPortNumber(std::num::ParseIntError), InvalidDecimalOctet, TooFewAddressParts, TooManyAddressParts, @@ -168,12 +218,21 @@ pub enum Error { impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Error::CannotExpressAsUtf8(_) => { + write!(f, "URI contains non-UTF8 sequences") + }, Error::EmptyScheme => { write!(f, "scheme expected but missing") }, - Error::IllegalCharacter => { - write!(f, "illegal character") + Error::IllegalCharacter(context) => { + write!(f, "illegal character in {}", context) + }, + Error::IllegalPercentEncoding(_) => { + write!(f, "illegal percent encoding") }, + Error::IllegalPortNumber(_) => { + write!(f, "illegal port number") + } Error::TruncatedHost => { write!(f, "truncated host") }, @@ -198,7 +257,12 @@ impl std::fmt::Display for Error { impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - None + match self { + Error::CannotExpressAsUtf8(source) => Some(source), + Error::IllegalPercentEncoding(source) => Some(source), + Error::IllegalPortNumber(source) => Some(source), + _ => None, + } } } @@ -206,15 +270,15 @@ impl From<percent_encoded_character_decoder::Error> for Error { fn from(error: percent_encoded_character_decoder::Error) -> Self { match error { percent_encoded_character_decoder::Error::IllegalCharacter => { - Error::IllegalCharacter + Error::IllegalPercentEncoding(error) }, } } } impl From<std::string::FromUtf8Error> for Error { - fn from(_: std::string::FromUtf8Error) -> Self { - Error::IllegalCharacter + fn from(error: std::string::FromUtf8Error) -> Self { + Error::CannotExpressAsUtf8(error) } } @@ -264,7 +328,7 @@ fn validate_ipv4_address(address: &str) -> Result<(), Error> { }, State::NotInOctet => { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Ipv4Address)); }, State::ExpectDigitOrDot if c == '.' => { @@ -287,7 +351,7 @@ fn validate_ipv4_address(address: &str) -> Result<(), Error> { }, State::ExpectDigitOrDot => { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Ipv4Address)); }, }; } @@ -339,13 +403,13 @@ fn validate_ipv6_address(address: &str) -> Result<(), Error> { num_digits = 1; ValidationState::InGroupNotIpv4 } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Ipv6Address)); } }, ValidationState::ColonButNoGroupsYet => { if c != ':' { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Ipv6Address)); } double_colon_encountered = true; ValidationState::AfterDoubleColon @@ -362,7 +426,7 @@ fn validate_ipv6_address(address: &str) -> Result<(), Error> { } else if HEXDIG.contains(&c) { ValidationState::InGroupNotIpv4 } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Ipv6Address)); } }, @@ -378,7 +442,7 @@ fn validate_ipv6_address(address: &str) -> Result<(), Error> { } ValidationState::InGroupNotIpv4 } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Ipv6Address)); } }, @@ -400,7 +464,7 @@ fn validate_ipv6_address(address: &str) -> Result<(), Error> { } else if HEXDIG.contains(&c) { ValidationState::InGroupNotIpv4 } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Ipv6Address)); } } }, @@ -421,7 +485,7 @@ fn validate_ipv6_address(address: &str) -> Result<(), Error> { num_digits += 1; ValidationState::InGroupNotIpv4 } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Ipv6Address)); } }, }; @@ -548,7 +612,7 @@ impl Uri { &SCHEME_NOT_FIRST }; if !valid_characters.contains(&c) { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Scheme)); } is_first_character = false; } @@ -576,7 +640,8 @@ impl Uri { // TODO: look into making element type more flexible fn decode_element( element: &str, - allowed_characters: &'static HashSet<char> + allowed_characters: &'static HashSet<char>, + context: Context ) -> Result<Vec<u8>, Error> { let mut decoding_pec = false; let mut output = Vec::<u8>::new(); @@ -599,16 +664,20 @@ impl Uri { } else if allowed_characters.contains(&c) { output.push(c as u8); } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(context)); } } Ok(output) } - fn decode_query_or_fragment(query_or_fragment: &str) -> Result<Vec<u8>, Error> { + fn decode_query_or_fragment( + query_or_fragment: &str, + context: Context, + ) -> Result<Vec<u8>, Error> { Self::decode_element( query_or_fragment, - &QUERY_OR_FRAGMENT_NOT_PCT_ENCODED + &QUERY_OR_FRAGMENT_NOT_PCT_ENCODED, + context ) } @@ -758,7 +827,8 @@ impl Uri { Some( Self::decode_element( &authority_string[0..user_info_delimiter], - &USER_INFO_NOT_PCT_ENCODED + &USER_INFO_NOT_PCT_ENCODED, + Context::Userinfo )? ), &authority_string[user_info_delimiter+1..] @@ -795,7 +865,7 @@ impl Uri { host.push(u8::try_from(c as u32).unwrap()); host_parsing_state } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Host)); } }, @@ -827,7 +897,7 @@ impl Uri { } else if c == ']' { return Err(Error::TruncatedHost); } else if !HEXDIG.contains(&c) { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::IpvFuture)); } host.push(u8::try_from(c as u32).unwrap()); host_parsing_state @@ -840,7 +910,7 @@ impl Uri { host.push(u8::try_from(c as u32).unwrap()); host_parsing_state } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::IpvFuture)); } }, @@ -850,7 +920,7 @@ impl Uri { if c == ':' { HostParsingState::Port } else { - return Err(Error::IllegalCharacter); + return Err(Error::IllegalCharacter(Context::Host)); } }, @@ -879,10 +949,15 @@ impl Uri { } let port = if port_string.is_empty() { None - } else if let Ok(port) = port_string.parse::<u16>() { - Some(port) } else { - return Err(Error::IllegalCharacter); + match port_string.parse::<u16>() { + Ok(port) => { + Some(port) + }, + Err(error) => { + return Err(Error::IllegalPortNumber(error)); + } + } }; Ok(Authority{ userinfo, @@ -894,7 +969,8 @@ impl Uri { fn parse_fragment(query_and_or_fragment: &str) -> Result<(Option<Vec<u8>>, &str), Error> { if let Some(fragment_delimiter) = query_and_or_fragment.find('#') { let fragment = Self::decode_query_or_fragment( - &query_and_or_fragment[fragment_delimiter+1..] + &query_and_or_fragment[fragment_delimiter+1..], + Context::Fragment )?; Ok(( Some(fragment), @@ -939,7 +1015,11 @@ impl Uri { } path_encoded.into_iter().map( |segment| { - Self::decode_element(&segment, &PCHAR_NOT_PCT_ENCODED) + Self::decode_element( + &segment, + &PCHAR_NOT_PCT_ENCODED, + Context::Path + ) } ) .collect::<Result<Vec<Vec<u8>>, Error>>() @@ -949,7 +1029,10 @@ impl Uri { if query_and_or_fragment.is_empty() { Ok(None) } else { - let query = Self::decode_query_or_fragment(&query_and_or_fragment[1..])?; + let query = Self::decode_query_or_fragment( + &query_and_or_fragment[1..], + Context::Query + )?; Ok(Some(query)) } } @@ -1304,7 +1387,7 @@ mod tests { #[test] fn parse_from_string_bad_port_number_too_big() { let uri = Uri::parse("http://www.example.com:65536/foo/bar"); - assert!(uri.is_err()); + assert!(matches!(uri, Err(Error::IllegalPortNumber(_)))); } #[test] @@ -1965,19 +2048,16 @@ mod tests { }; let test_vectors = [ TestVector{ uri_string: "http://[::fFfF::1]", expected_error: Error::TooManyDoubleColons }, - TestVector{ uri_string: "http://[::ffff:1.2.x.4]/", expected_error: Error::IllegalCharacter }, + TestVector{ uri_string: "http://[::ffff:1.2.x.4]/", expected_error: Error::IllegalCharacter(Context::Ipv4Address) }, TestVector{ uri_string: "http://[::ffff:1.2.3.4.8]/", expected_error: Error::TooManyAddressParts }, TestVector{ uri_string: "http://[::ffff:1.2.3]/", expected_error: Error::TooFewAddressParts }, TestVector{ uri_string: "http://[::ffff:1.2.3.]/", expected_error: Error::TruncatedHost }, TestVector{ uri_string: "http://[::ffff:1.2.3.256]/", expected_error: Error::InvalidDecimalOctet }, - TestVector{ uri_string: "http://[::fxff:1.2.3.4]/", expected_error: Error::IllegalCharacter }, - TestVector{ uri_string: "http://[::ffff:1.2.3.-4]/", expected_error: Error::IllegalCharacter }, - TestVector{ uri_string: "http://[::ffff:1.2.3. 4]/", expected_error: Error::IllegalCharacter }, - TestVector{ uri_string: "http://[::ffff:1.2.3.4 ]/", expected_error: Error::IllegalCharacter }, + TestVector{ uri_string: "http://[::fxff:1.2.3.4]/", expected_error: Error::IllegalCharacter(Context::Ipv6Address) }, + TestVector{ uri_string: "http://[::ffff:1.2.3.-4]/", expected_error: Error::IllegalCharacter(Context::Ipv4Address) }, + TestVector{ uri_string: "http://[::ffff:1.2.3. 4]/", expected_error: Error::IllegalCharacter(Context::Ipv4Address) }, + TestVector{ uri_string: "http://[::ffff:1.2.3.4 ]/", expected_error: Error::IllegalCharacter(Context::Ipv4Address) }, TestVector{ uri_string: "http://[::ffff:1.2.3.4/", expected_error: Error::TruncatedHost }, - TestVector{ uri_string: "http://::ffff:1.2.3.4]/", expected_error: Error::IllegalCharacter }, - TestVector{ uri_string: "http://::ffff:a.2.3.4]/", expected_error: Error::IllegalCharacter }, - TestVector{ uri_string: "http://::ffff:1.a.3.4]/", expected_error: Error::IllegalCharacter }, TestVector{ uri_string: "http://[2001:db8:85a3:8d3:1319:8a2e:370:7348:0000]/", expected_error: Error::TooManyAddressParts }, TestVector{ uri_string: "http://[2001:db8:85a3:8d3:1319:8a2e:370:7348::1]/", expected_error: Error::TooManyAddressParts }, TestVector{ uri_string: "http://[2001:db8:85a3::8a2e:0:]/", expected_error: Error::TruncatedHost }, @@ -1995,6 +2075,11 @@ mod tests { test_vector.uri_string ); } + + // This is a special case because std::num doesn't trust that we're + // good enough to make our own ParseIntError values. FeelsBadMan + let uri = Uri::parse("http://::ffff:1.2.3.4]/"); + assert!(matches!(uri, Err(Error::IllegalPortNumber(_)))); } #[test] |