#![warn(clippy::pedantic)] use super::character_classes::{ DIGIT, }; use super::context::Context; use super::error::Error; struct Shared { num_groups: usize, octet_buffer: String, } enum State { NotInOctet(Shared), ExpectDigitOrDot(Shared), } impl State { fn finalize(self) -> Result<(), Error> { match self { Self::NotInOctet(_) => Err(Error::TruncatedHost), Self::ExpectDigitOrDot(state) => Self::finalize_expect_digit_or_dot(state), } } fn finalize_expect_digit_or_dot(state: Shared) -> Result<(), Error> { let mut state = state; if !state.octet_buffer.is_empty() { state.num_groups += 1; if state.octet_buffer.parse::().is_err() { return Err(Error::InvalidDecimalOctet); } } match state.num_groups { 4 => Ok(()), n if n < 4 => Err(Error::TooFewAddressParts), _ => Err(Error::TooManyAddressParts), } } fn new() -> Self { Self::NotInOctet(Shared{ num_groups: 0, octet_buffer: String::new(), }) } fn next(self, c: char) -> Result { match self { Self::NotInOctet(state) => Self::next_not_in_octet(state, c), Self::ExpectDigitOrDot(state) => Self::next_expect_digit_or_dot(state, c), } } fn next_not_in_octet(state: Shared, c: char) -> Result { let mut state = state; if DIGIT.contains(&c) { state.octet_buffer.push(c); Ok(Self::ExpectDigitOrDot(state)) } else { Err(Error::IllegalCharacter(Context::Ipv4Address)) } } fn next_expect_digit_or_dot(state: Shared, c: char)-> Result { let mut state = state; if c == '.' { state.num_groups += 1; if state.num_groups > 4 { return Err(Error::TooManyAddressParts); } if state.octet_buffer.parse::().is_err() { return Err(Error::InvalidDecimalOctet); } state.octet_buffer.clear(); Ok(Self::NotInOctet(state)) } else if DIGIT.contains(&c) { state.octet_buffer.push(c); Ok(Self::ExpectDigitOrDot(state)) } else { Err(Error::IllegalCharacter(Context::Ipv4Address)) } } } pub fn validate_ipv4_address(address: T) -> Result<(), Error> where T: AsRef { address.as_ref() .chars() .try_fold(State::new(), |machine, c| { machine.next(c) })? .finalize() } #[cfg(test)] mod tests { use super::*; #[test] fn good() { let test_vectors = [ "0.0.0.0", "1.2.3.0", "1.2.3.4", "1.2.3.255", "1.2.255.4", "1.255.3.4", "255.2.3.4", "255.255.255.255", ]; for test_vector in &test_vectors { assert!(validate_ipv4_address(*test_vector).is_ok()); } } #[test] fn bad() { named_tuple!( struct TestVector { address_string: &'static str, expected_error: Error, } ); let test_vectors: &[TestVector] = &[ ("1.2.x.4", Error::IllegalCharacter(Context::Ipv4Address)).into(), ("1.2.3.4.8", Error::TooManyAddressParts).into(), ("1.2.3", Error::TooFewAddressParts).into(), ("1.2.3.", Error::TruncatedHost).into(), ("1.2.3.256", Error::InvalidDecimalOctet).into(), ("1.2.3.-4", Error::IllegalCharacter(Context::Ipv4Address)).into(), ("1.2.3. 4", Error::IllegalCharacter(Context::Ipv4Address)).into(), ("1.2.3.4 ", Error::IllegalCharacter(Context::Ipv4Address)).into(), ]; for test_vector in test_vectors { let result = validate_ipv4_address(test_vector.address_string()); assert!(result.is_err(), "{}", test_vector.address_string()); assert_eq!( *test_vector.expected_error(), result.unwrap_err(), "{}", test_vector.address_string() ); } } }