//! Provides the [`Reader`] trait (and implementations) used by the tokenizer.

use std::collections::VecDeque;
use std::convert::Infallible;
use std::io::{self, BufReader, Read};

/// An object that provides characters to the tokenizer.
///
/// Patches are welcome for providing an efficient implementation over async streams,
/// iterators, files, etc., as long as any dependencies come behind feature flags.
pub trait Reader {
    /// The error returned by this reader.
    type Error: std::error::Error;

    /// Return the next character from the input stream, or `Ok(None)` once the input is
    /// exhausted.
    ///
    /// The input stream does **not** have to be preprocessed in any way; for example, it may
    /// have inconsistent newlines. (Standalone surrogates cannot appear, because a Rust `char`
    /// is never a surrogate code point.)
    ///
    /// # Errors
    ///
    /// Returns an error of type [`Self::Error`] if the reader fails to produce the next
    /// character; what counts as a failure is up to the implementation.
    fn read_char(&mut self) -> Result<Option<char>, Self::Error>;

    /// Attempt to read an entire string at once, either case-sensitively or not.
    ///
    /// `case_sensitive == false` means that characters of the input stream should be compared
    /// while ignoring ASCII case.
    ///
    /// It can be assumed that this function is never called with a string that contains `\r` or
    /// `\n`.
    ///
    /// If the next characters are equal to `s`, this function consumes the respective characters
    /// from the input stream and returns `true`. If not, it consumes nothing and returns `false`.
    ///
    /// # Errors
    ///
    /// Returns an error of type [`Self::Error`] if the reader fails while looking ahead; what
    /// counts as a failure is up to the implementation.
    // TODO: document a maximum s length that may be assumed (depends on longest named character reference ... which may change?)
    fn try_read_string(&mut self, s: &str, case_sensitive: bool) -> Result<bool, Self::Error>;

    /// Returns the number of bytes that the given character takes up in the current character
    /// encoding. Takes `&self`, so it never consumes input.
    fn len_of_char_in_current_encoding(&self, c: char) -> usize;
}

/// An object that can be converted into a [`Reader`].
///
/// The conversions provided in this module are: every [`Reader`] converts into itself, `&str`
/// and `&String` convert into a [`StringReader`], and a [`BufReader`] converts into a
/// [`BufReadReader`].
// TODO: , such that [give concrete examples of NaiveParser::new] work.
pub trait IntoReader<'a> {
    /// The reader type into which this type should be converted.
    ///
    /// The bound `Reader + 'a` requires the produced reader to be valid for the lifetime `'a`
    /// (for example, a reader that borrows a `&'a str`).
    type Reader: Reader + 'a;

    /// Convert self into some sort of reader.
    fn into_reader(self) -> Self::Reader;
}

/// Anything that already is a [`Reader`] converts into itself unchanged.
impl<'a, R> IntoReader<'a> for R
where
    R: Reader + 'a,
{
    type Reader = R;

    fn into_reader(self) -> R {
        self
    }
}

/// A [`Reader`] over an in-memory string slice. Used by the tokenizer to read HTML from strings.
///
/// Reading only ever moves forward through the input.
pub struct StringReader<'a> {
    /// The complete input.
    input: &'a str,
    /// Iterator over the characters of `input` that have not been read yet.
    cursor: std::str::Chars<'a>,
    /// Byte offset into `input` of the next character to read; `cursor` starts at this offset.
    pos: usize,
}

impl<'a> StringReader<'a> {
    fn new(input: &'a str) -> Self {
        let cursor = input.chars();
        StringReader {
            input,
            cursor,
            pos: 0,
        }
    }
}

impl<'a> Reader for StringReader<'a> {
    type Error = Infallible;

    /// Yields the next character of the input, or `None` once the whole string has been read.
    fn read_char(&mut self) -> Result<Option<char>, Self::Error> {
        let next = self.cursor.next();
        if let Some(c) = next {
            // Keep the byte offset in sync with the character iterator.
            self.pos += c.len_utf8();
        }
        Ok(next)
    }

    /// Consumes `s1` if the unread input starts with it (ASCII-case-insensitively when
    /// `case_sensitive` is `false`).
    fn try_read_string(&mut self, s1: &str, case_sensitive: bool) -> Result<bool, Self::Error> {
        // Compare against the same number of bytes of the remaining input. `get` yields `None`
        // both when the input is too short and when the range would split a character.
        let end = self.pos + s1.len();
        let matched = match self.input.get(self.pos..end) {
            Some(s2) if case_sensitive => s1 == s2,
            Some(s2) => s1.eq_ignore_ascii_case(s2),
            None => false,
        };

        if matched {
            self.pos = end;
            self.cursor = self.input[end..].chars();
        }
        Ok(matched)
    }

    /// Input is a Rust string, so characters are measured in UTF-8 bytes.
    fn len_of_char_in_current_encoding(&self, c: char) -> usize {
        c.len_utf8()
    }
}

/// Borrowed string slices are read through a [`StringReader`].
impl<'a> IntoReader<'a> for &'a str {
    type Reader = StringReader<'a>;

    fn into_reader(self) -> StringReader<'a> {
        StringReader::new(self)
    }
}

/// A borrowed `String` is read exactly like the `&str` it dereferences to.
impl<'a> IntoReader<'a> for &'a String {
    type Reader = StringReader<'a>;

    fn into_reader(self) -> StringReader<'a> {
        // `&'a String` deref-coerces to `&'a str`.
        StringReader::new(self)
    }
}

/// Capacity in bytes of the buffer that [`BufReadReader`] reads into.
///
/// Matches the default capacity of [`BufReader`] (documented as currently 8 KiB). The standard
/// library's own constant for that value is private, so it is repeated here.
const BUF_SIZE: usize = 8 * 1024;

/// A buffered [`Reader`] that decodes UTF-8 from any [`Read`] implementation.
///
/// Bytes are pulled from `reader` into a fixed-size `buffer`, decoded, and queued as `char`s.
/// Invalid UTF-8 is surfaced as an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
pub struct BufReadReader<R: Read> {
    /// The underlying byte source.
    reader: R,
    /// Raw bytes read from `reader`.
    buffer: [u8; BUF_SIZE],
    /// Offset into `buffer` of the first byte that has not been decoded into `chars` yet.
    pos: usize,
    /// Offset into `buffer` one past the last byte that was read from `reader`.
    read: usize,
    /// Once `true`, no further bytes are read from `reader`.
    eof: bool,
    /// Decoded characters waiting to be handed out, oldest first.
    chars: VecDeque<char>,
    /// A decoding error to report once the characters decoded before it have been consumed.
    error: Option<io::Error>,
}

impl<R: Read> BufReadReader<R> {
    /// Construct a new `BufReadReader` from any type that implements [`Read`].
    pub fn new(reader: R) -> Self {
        BufReadReader {
            reader,
            buffer: [0; BUF_SIZE],
            read: 0,
            pos: 0,
            chars: VecDeque::new(),
            error: None,
            eof: false,
        }
    }

    /// Decodes more input into `self.chars`, reading from the underlying reader if necessary.
    ///
    /// On `Ok(())` at least one of these holds: `self.chars` is non-empty, `self.error` is set,
    /// or `self.eof` is set.
    ///
    /// A multi-byte UTF-8 sequence split across two `Read::read` calls is kept in the buffer
    /// and completed by the next read, instead of being reported as invalid. A sequence that is
    /// still incomplete when the reader reports end of input is an error.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying reader except [`io::ErrorKind::Interrupted`],
    /// which is retried. Decoding errors are not returned here. They are stored in
    /// `self.error` as [`io::ErrorKind::InvalidData`] wrapping a [`std::str::Utf8Error`].
    fn read(&mut self) -> Result<(), io::Error> {
        debug_assert!(!self.eof);
        debug_assert!(self.error.is_none());

        loop {
            if self.pos < self.read && self.decode_buffered() {
                return Ok(());
            }

            // Anything still buffered is the (at most 3-byte) prefix of an incomplete UTF-8
            // sequence. Move it to the front so the next read can append the missing bytes.
            let leftover = self.read - self.pos;
            self.buffer.copy_within(self.pos..self.read, 0);
            self.pos = 0;
            self.read = leftover;

            // `leftover < BUF_SIZE`, so the target slice is non-empty and `Ok(0)` really
            // means end of input.
            let n = loop {
                match self.reader.read(&mut self.buffer[leftover..]) {
                    Ok(n) => break n,
                    Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                    Err(e) => return Err(e),
                }
            };

            if n == 0 {
                self.eof = true;
                if leftover > 0 {
                    // The input ended in the middle of a multi-byte sequence.
                    if let Err(err) = std::str::from_utf8(&self.buffer[..leftover]) {
                        self.error = Some(io::Error::new(io::ErrorKind::InvalidData, err));
                    }
                    self.pos = self.read;
                }
                return Ok(());
            }
            self.read += n;
        }
    }

    /// Decodes the bytes in `buffer[pos..read]` into `chars`.
    ///
    /// Returns `true` if it produced at least one char or recorded an error. Returns `false`
    /// if the remaining bytes only form an incomplete UTF-8 sequence that needs more input.
    fn decode_buffered(&mut self) -> bool {
        let unprocessed = &self.buffer[self.pos..self.read];
        let (valid_str, err) = match std::str::from_utf8(unprocessed) {
            Ok(s) => (s, None),
            Err(err) => (
                // SAFETY: `from_utf8` just verified that the first `valid_up_to()` bytes are
                // valid UTF-8.
                unsafe { std::str::from_utf8_unchecked(&unprocessed[..err.valid_up_to()]) },
                Some(err),
            ),
        };
        self.chars.extend(valid_str.chars());
        self.pos += valid_str.len();

        if let Some(err) = err {
            // `error_len() == None` means the tail is merely truncated: leave it buffered so
            // that more input can complete it.
            if let Some(error_len) = err.error_len() {
                // A genuinely invalid sequence: report it and skip the offending bytes.
                self.error = Some(io::Error::new(io::ErrorKind::InvalidData, err));
                self.pos += error_len;
            }
        }

        !valid_str.is_empty() || self.error.is_some()
    }
}

impl<R: Read> Reader for BufReadReader<R> {
    type Error = io::Error;

    /// Returns the next decoded character.
    ///
    /// A pending decoding error is returned (once) only after every character decoded before
    /// it has been handed out.
    fn read_char(&mut self) -> Result<Option<char>, Self::Error> {
        // Only touch the underlying reader when nothing is queued: no chars, no pending error,
        // and the input has not ended yet.
        if self.chars.is_empty() && self.error.is_none() && !self.eof {
            self.read()?;
        }

        if let Some(c) = self.chars.pop_front() {
            return Ok(Some(c));
        }
        if let Some(error) = self.error.take() {
            return Err(error);
        }
        debug_assert!(self.eof);
        Ok(None)
    }

    /// Consumes `s1` if the next decoded characters match it (ASCII-case-insensitively when
    /// `case_sensitive` is `false`); otherwise consumes nothing.
    ///
    /// Returns `Ok(false)` without consuming anything if a decoding error or the end of input
    /// comes before enough characters are available.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the underlying reader while looking ahead.
    fn try_read_string(&mut self, s1: &str, case_sensitive: bool) -> Result<bool, Self::Error> {
        debug_assert!(!s1.contains('\r'));
        debug_assert!(!s1.contains('\n'));

        // `chars` holds decoded characters, so look ahead by the *character* count of `s1`.
        // Its UTF-8 byte length (`s1.len()`) is larger for non-ASCII strings.
        let char_count = s1.chars().count();

        while self.chars.len() < char_count {
            if self.error.is_some() || self.eof {
                return Ok(false);
            }
            self.read()?;
        }

        let matches = self.chars.iter().zip(s1.chars()).all(|(&actual, expected)| {
            if case_sensitive {
                actual == expected
            } else {
                actual.eq_ignore_ascii_case(&expected)
            }
        });
        if matches {
            self.chars.drain(..char_count);
        }
        Ok(matches)
    }

    fn len_of_char_in_current_encoding(&self, c: char) -> usize {
        c.len_utf8()
    }
}

/// A [`BufReader`] is decoded as UTF-8 through a [`BufReadReader`].
impl<'a, R> IntoReader<'a> for BufReader<R>
where
    R: Read + 'a,
{
    type Reader = BufReadReader<BufReader<R>>;

    fn into_reader(self) -> BufReadReader<Self> {
        BufReadReader::<Self>::new(self)
    }
}

#[cfg(test)]
mod tests {
    use std::io::{BufReader, ErrorKind};
    use std::str::Utf8Error;

    use super::{IntoReader, Reader};

    /// An invalid byte yields an `InvalidData` error wrapping a `Utf8Error`, after which
    /// decoding continues with the following bytes.
    #[test]
    fn buf_read_reader_invalid_utf8() {
        let input: &[u8] = b" \xc3\x28";
        let mut reader = BufReader::new(input).into_reader();

        assert_eq!(reader.read_char().unwrap(), Some(' '));

        let error = reader.read_char().unwrap_err();
        assert_eq!(error.kind(), ErrorKind::InvalidData);
        let inner = error
            .into_inner()
            .expect("the error should wrap a source error");
        assert!(inner.is::<Utf8Error>());

        assert_eq!(reader.read_char().unwrap(), Some('('));
        assert_eq!(reader.read_char().unwrap(), None);
    }
}