use std::iter;

use syn::{
    punctuated::Punctuated, GenericArgument, Lifetime, Path, PathArguments, ReturnType, TraitBound,
    Type, TypeParamBound,
};

/// Yields only the trait bounds (e.g. `Clone` in `T: Clone + 'a`) from `bounds`,
/// skipping any lifetime bounds. `T` is the punctuation type of the list.
pub fn trait_bounds<T>(
    bounds: &Punctuated<TypeParamBound, T>,
) -> impl Iterator<Item = &TraitBound> {
    bounds.iter().filter_map(|bound| {
        if let TypeParamBound::Trait(trait_bound) = bound {
            Some(trait_bound)
        } else {
            None
        }
    })
}

/// Yields only the lifetime bounds (e.g. `'a` in `T: Clone + 'a`) from `bounds`,
/// skipping any trait bounds. `T` is the punctuation type of the list.
pub fn lifetime_bounds<T>(
    bounds: &Punctuated<TypeParamBound, T>,
) -> impl Iterator<Item = &Lifetime> {
    bounds.iter().filter_map(|bound| {
        if let TypeParamBound::Lifetime(lifetime) = bound {
            Some(lifetime)
        } else {
            None
        }
    })
}

/// Yields mutable references to the type arguments in `args`.
///
/// Every other kind of generic argument (lifetimes, const arguments, bindings,
/// constraints) is skipped. `P` is the punctuation type of the list.
pub fn type_arguments_mut<P>(
    args: &mut Punctuated<GenericArgument, P>,
) -> impl Iterator<Item = &mut Type> {
    args.iter_mut().filter_map(|arg| {
        if let GenericArgument::Type(ty) = arg {
            Some(ty)
        } else {
            None
        }
    })
}

/// Sum type over the per-kind iterators built inside `iter_type`.
///
/// Each arm of the `match` in `iter_type` builds a differently-typed iterator.
/// Wrapping them in this enum gives all arms one shared concrete type. `Empty`
/// is used for type kinds that have no nested types and never yields anything.
enum IterTypes<A, B, C, D, E, F> {
    Single(A),
    Function(B),
    Tuple(C),
    ImplTrait(D),
    TraitObject(E),
    Path(F),
    Empty,
}

impl<A, B, C, D, E, F> Iterator for IterTypes<A, B, C, D, E, F>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
    C: Iterator<Item = A::Item>,
    D: Iterator<Item = A::Item>,
    E: Iterator<Item = A::Item>,
    F: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    /// Forwards to whichever iterator is currently wrapped.
    fn next(&mut self) -> Option<A::Item> {
        match self {
            Self::Empty => None,
            Self::Single(inner) => inner.next(),
            Self::Function(inner) => inner.next(),
            Self::Tuple(inner) => inner.next(),
            Self::ImplTrait(inner) => inner.next(),
            Self::TraitObject(inner) => inner.next(),
            Self::Path(inner) => inner.next(),
        }
    }
}

pub fn iter_path(path: &Path) -> impl Iterator<Item = &Type> {
    types_in_path(path).flat_map(|t| iter_type(t))
}

/// Returns a depth-first, pre-order iterator over `t` and every type nested in it.
///
/// The first item yielded is always `t` itself. Nested types are reached through:
/// - the element type of arrays, groups, parentheses, pointers, references and slices;
/// - the elements of tuples;
/// - the argument types and explicit return type of bare `fn` types;
/// - for path types, the self type of a qualified path (`T` in `<T as Trait>::Assoc`)
///   followed by the types `types_in_path` extracts from the path's segments;
/// - the types `types_in_path` extracts from the trait bounds of `impl Trait` and
///   `dyn Trait` types (lifetime bounds are skipped).
///
/// Any other kind of type (e.g. macros, `!`, `_`) yields only itself.
///
/// The result is boxed because the function is recursive, and a recursive
/// `impl Iterator` return type cannot name itself.
pub fn iter_type<'a>(t: &'a Type) -> Box<dyn Iterator<Item = &'a Type> + 'a> {
    Box::new(
        iter::once(t).chain(match t {
            Type::Array(array) => IterTypes::Single(iter_type(&array.elem)),
            Type::Group(group) => IterTypes::Single(iter_type(&group.elem)),
            Type::Paren(paren) => IterTypes::Single(iter_type(&paren.elem)),
            Type::Ptr(ptr) => IterTypes::Single(iter_type(&ptr.elem)),
            Type::Reference(r) => IterTypes::Single(iter_type(&r.elem)),
            Type::Slice(slice) => IterTypes::Single(iter_type(&slice.elem)),

            Type::Tuple(tuple) => IterTypes::Tuple(tuple.elems.iter().flat_map(iter_type)),

            Type::BareFn(fun) => IterTypes::Function(
                fun.inputs
                    .iter()
                    .flat_map(|arg| iter_type(&arg.ty))
                    .chain(match &fun.output {
                        ReturnType::Default => IterEnum::Left(iter::empty()),
                        ReturnType::Type(_, ty) => IterEnum::Right(iter_type(ty.as_ref())),
                    }),
            ),
            Type::ImplTrait(impl_trait) => IterTypes::ImplTrait(
                trait_bounds(&impl_trait.bounds)
                    .flat_map(|b| types_in_path(&b.path))
                    .flat_map(iter_type),
            ),

            // For a qualified path such as `<T as Trait>::Assoc`, the self type `T`
            // is stored in `qself`, not in `path.path`, so visit it separately.
            Type::Path(path) => IterTypes::Path(
                path.qself
                    .iter()
                    .flat_map(|qself| iter_type(&qself.ty))
                    .chain(iter_path(&path.path)),
            ),
            Type::TraitObject(trait_obj) => IterTypes::TraitObject(
                trait_bounds(&trait_obj.bounds)
                    .flat_map(|b| types_in_path(&b.path))
                    .flat_map(iter_type),
            ),
            _other => IterTypes::Empty,
        }),
    )
}

/// One of two iterators that yield the same item type.
///
/// This lets two branches of a `match` that build differently-typed iterators
/// return a single concrete type without boxing.
enum IterEnum<L, R> {
    Left(L),
    Right(R),
}

impl<L, R> Iterator for IterEnum<L, R>
where
    L: Iterator,
    R: Iterator<Item = L::Item>,
{
    type Item = L::Item;

    /// Forwards to whichever side is currently wrapped.
    fn next(&mut self) -> Option<L::Item> {
        match self {
            Self::Left(left) => left.next(),
            Self::Right(right) => right.next(),
        }
    }
}

fn types_in_path(p: &Path) -> impl Iterator<Item = &Type> {
    p.segments.iter().flat_map(|s| match &s.arguments {
        PathArguments::AngleBracketed(ang) => {
            IterEnum::Left(ang.args.iter().flat_map(|a| match a {
                GenericArgument::Type(t) => Some(t),
                GenericArgument::Binding(b) => Some(&b.ty),
                // TODO: handle GenericArgument::Constraint
                _other => None,
            }))
        }
        _other => IterEnum::Right(iter::empty()),
    })
}