use crate::ast::*;

/// Result of an [`ASTVisitor`] hook: tells the calling `walk_*` function whether
/// to descend into the children of the node that was just visited.
///
/// `Stop` only prunes the current node's subtree. The enclosing walker still
/// continues with the node's siblings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Recursion {
    /// Walk the node's children as well.
    Continue,
    /// Do not walk the node's children.
    Stop,
}

/// Callbacks invoked by the `walk_*` functions in this module while they traverse an AST.
///
/// Every hook has a default body that returns `Recursion::Continue`, so an implementor
/// only overrides the node kinds it is interested in. Returning `Recursion::Stop` from
/// a hook tells the walker not to descend into that node's children.
pub trait ASTVisitor: Sized {
    /// Called on an expression before its subexpressions are walked.
    fn expression(&mut self, _expr: &Expression) -> Recursion {
        Recursion::Continue
    }

    /// Called on a declaration statement, together with that statement's id,
    /// before the declaration's contents are walked.
    fn declaration(&mut self, _decl: &Declaration, _item_id: &ItemId) -> Recursion {
        Recursion::Continue
    }

    /// Called on an import statement. `walk_block` discards the returned value,
    /// because nothing beneath an import specifier is walked.
    fn import(&mut self, _spec: &ImportSpecifier) -> Recursion {
        Recursion::Continue
    }

    /// Called on a pattern before its sub-patterns are walked.
    fn pattern(&mut self, _pattern: &Pattern) -> Recursion {
        Recursion::Continue
    }
}

pub fn walk_ast<V: ASTVisitor>(v: &mut V, ast: &AST) {
    walk_block(v, &ast.statements);
}

pub fn walk_block<V: ASTVisitor>(v: &mut V, block: &Block) {
    use StatementKind::*;
    for statement in block.statements.iter() {
        match statement.kind {
            StatementKind::Expression(ref expr) => {
                walk_expression(v, expr);
            }
            Declaration(ref decl) => {
                walk_declaration(v, decl, &statement.id);
            }
            Import(ref import_spec) => {
                v.import(import_spec);
            }
            Flow(ref flow_control) =>
                if let FlowControl::Return(Some(ref retval)) = flow_control {
                    walk_expression(v, retval);
                },
        }
    }
}

pub fn walk_declaration<V: ASTVisitor>(v: &mut V, decl: &Declaration, id: &ItemId) {
    use Declaration::*;

    if let Recursion::Continue = v.declaration(decl, id) {
        match decl {
            FuncDecl(_sig, block) => {
                walk_block(v, block);
            }
            Binding { name: _, constant: _, type_anno: _, expr } => {
                walk_expression(v, expr);
            }
            Module { name: _, items } => {
                walk_block(v, items);
            }
            _ => (),
        };
    }
}

/// Walks `expr` and, unless the visitor stops it, all of its subexpressions.
///
/// `ASTVisitor::expression` is called on `expr` first. If it returns `Recursion::Stop`,
/// the children are skipped. Literals and `Value` references have no children.
/// Loop, `for` and lambda bodies are walked as blocks. `if` bodies are walked with
/// `walk_if_expr_body`, after the optional discriminator.
pub fn walk_expression<V: ASTVisitor>(v: &mut V, expr: &Expression) {
    use ExpressionKind::*;

    if let Recursion::Stop = v.expression(expr) {
        return;
    }

    match &expr.kind {
        NatLiteral(_) | FloatLiteral(_) | StringLiteral(_) | BoolLiteral(_) | Value(_) => {}
        BinExp(_, left, right) => {
            walk_expression(v, left);
            walk_expression(v, right);
        }
        PrefixExp(_, operand) => walk_expression(v, operand),
        TupleLiteral(members) => {
            for member in members {
                walk_expression(v, member);
            }
        }
        NamedStruct { fields, .. } => {
            for (_, field_value) in fields.iter() {
                walk_expression(v, field_value);
            }
        }
        Call { f, arguments } => {
            walk_expression(v, f);
            for argument in arguments.iter() {
                match argument {
                    InvocationArgument::Positional(arg_expr)
                    | InvocationArgument::Keyword { expr: arg_expr, .. } => {
                        walk_expression(v, arg_expr)
                    }
                    // Any other argument kind carries no expression to visit.
                    _ => {}
                }
            }
        }
        Index { indexee, indexers } => {
            walk_expression(v, indexee);
            for ix in indexers.iter() {
                walk_expression(v, ix);
            }
        }
        IfExpression { discriminator, body } => {
            if let Some(scrutinee) = discriminator.as_ref() {
                walk_expression(v, scrutinee);
            }
            walk_if_expr_body(v, body.as_ref());
        }
        WhileExpression { condition, body } => {
            if let Some(cond) = condition.as_ref() {
                walk_expression(v, cond);
            }
            walk_block(v, body);
        }
        ForExpression { enumerators, body } => {
            for enumerator in enumerators {
                walk_expression(v, &enumerator.generator);
            }
            match body.as_ref() {
                ForBody::MonadicReturn(yielded) => walk_expression(v, yielded),
                ForBody::StatementBlock(stmts) => walk_block(v, stmts),
            }
        }
        Lambda { body, .. } => walk_block(v, body),
        Access { expr: target, .. } => walk_expression(v, target),
        ListLiteral(elements) => {
            for element in elements {
                walk_expression(v, element);
            }
        }
    }
}

pub fn walk_if_expr_body<V: ASTVisitor>(v: &mut V, body: &IfExpressionBody) {
    use IfExpressionBody::*;

    match body {
        SimpleConditional { then_case, else_case } => {
            walk_block(v, then_case);
            if let Some(block) = else_case.as_ref() {
                walk_block(v, block)
            }
        }
        SimplePatternMatch { pattern, then_case, else_case } => {
            walk_pattern(v, pattern);
            walk_block(v, then_case);
            if let Some(block) = else_case.as_ref() {
                walk_block(v, block)
            }
        }
        CondList(arms) =>
            for arm in arms {
                match arm.condition {
                    Condition::Pattern(ref pat) => {
                        walk_pattern(v, pat);
                    }
                    Condition::TruncatedOp(ref _binop, ref expr) => {
                        walk_expression(v, expr);
                    }
                    Condition::Else => (),
                }
                if let Some(ref guard) = arm.guard {
                    walk_expression(v, guard);
                }
                walk_block(v, &arm.body);
            },
    }
}

pub fn walk_pattern<V: ASTVisitor>(v: &mut V, pat: &Pattern) {
    use Pattern::*;

    if let Recursion::Continue = v.pattern(pat) {
        match pat {
            TuplePattern(patterns) =>
                for pat in patterns {
                    walk_pattern(v, pat);
                },
            TupleStruct(_, patterns) =>
                for pat in patterns {
                    walk_pattern(v, pat);
                },
            Record(_, name_and_patterns) =>
                for (_, pat) in name_and_patterns {
                    walk_pattern(v, pat);
                },
            _ => (),
        };
    }
}