From 851fd9885fffe9dc379a44056e5ac4e8ef5c1759 Mon Sep 17 00:00:00 2001 From: Greg Shuflin Date: Tue, 26 Oct 2021 14:05:27 -0700 Subject: [PATCH] Make a distinct Block type --- schala-lang/language/src/ast/mod.rs | 29 ++++++- schala-lang/language/src/ast/visitor.rs | 2 +- schala-lang/language/src/parsing/mod.rs | 14 ++-- schala-lang/language/src/parsing/test.rs | 90 +++++++++++----------- schala-lang/language/src/reduced_ir/mod.rs | 8 +- schala-lang/language/src/typechecking.rs | 4 +- 6 files changed, 85 insertions(+), 62 deletions(-) diff --git a/schala-lang/language/src/ast/mod.rs b/schala-lang/language/src/ast/mod.rs index 3bc12c3..a1db480 100644 --- a/schala-lang/language/src/ast/mod.rs +++ b/schala-lang/language/src/ast/mod.rs @@ -3,6 +3,7 @@ use std::rc::Rc; use std::fmt; +use std::convert::{AsRef, From}; mod visitor; mod operators; @@ -53,7 +54,7 @@ impl ItemIdStore { pub struct AST { #[derivative(PartialEq="ignore")] pub id: ItemId, - pub statements: Vec + pub statements: Block, } #[derive(Derivative, Debug, Clone)] @@ -74,7 +75,29 @@ pub enum StatementKind { Module(ModuleSpecifier), } -pub type Block = Vec; +#[derive(Debug, Clone, PartialEq, Default)] +pub struct Block { + pub statements: Vec +} + +impl From> for Block { + fn from(statements: Vec) -> Self { + Self { statements } + } +} + +impl From for Block { + fn from(statement: Statement) -> Self { + Self { statements: vec![statement] } + } +} + +impl AsRef<[Statement]> for Block { + fn as_ref(&self) -> &[Statement] { + self.statements.as_ref() + } +} + pub type ParamName = Rc; #[derive(Debug, Derivative, Clone)] @@ -319,6 +342,6 @@ pub enum ImportedNames { #[derive(Debug, PartialEq, Clone)] pub struct ModuleSpecifier { pub name: Rc, - pub contents: Vec, + pub contents: Block, } diff --git a/schala-lang/language/src/ast/visitor.rs b/schala-lang/language/src/ast/visitor.rs index 1165a42..64725ac 100644 --- a/schala-lang/language/src/ast/visitor.rs +++ b/schala-lang/language/src/ast/visitor.rs @@ -22,7 +22,7 @@ pub fn walk_ast(v: &mut V, ast: &AST) { pub fn walk_block(v: &mut V, block: &Block) { use StatementKind::*; - for statement in block.iter() { + for statement in block.statements.iter() { match statement.kind { StatementKind::Expression(ref expr) => { walk_expression(v, expr); diff --git a/schala-lang/language/src/parsing/mod.rs b/schala-lang/language/src/parsing/mod.rs index c05a83f..6b1ae1d 100644 --- a/schala-lang/language/src/parsing/mod.rs +++ b/schala-lang/language/src/parsing/mod.rs @@ -362,7 +362,7 @@ impl Parser { ), } } - Ok(AST { id: self.id_store.fresh(), statements }) + Ok(AST { id: self.id_store.fresh(), statements: statements.into() }) } /// `statement := expression | declaration` @@ -478,7 +478,7 @@ impl Parser { fn func_declaration(&mut self) -> ParseResult { let signature = self.func_signature()?; if let LCurlyBrace = self.token_handler.peek_kind() { - let statements = self.nonempty_func_body()?; + let statements = self.nonempty_func_body()?.into(); Ok(Declaration::FuncDecl(signature, statements)) } else { Ok(Declaration::FuncSig(signature)) @@ -771,7 +771,7 @@ impl Parser { Colon => Some(self.type_anno()?), _ => None, }; - let body = self.nonempty_func_body()?; + let body = self.nonempty_func_body()?.into(); Ok(Expression::new(self.id_store.fresh(), ExpressionKind::Lambda { params, type_anno, body })) //TODO need to handle types somehow } @@ -1047,7 +1047,7 @@ impl Parser { #[recursive_descent_method] fn block(&mut self) -> ParseResult { let block = delimited!(self, LCurlyBrace, statement, Newline | Semicolon, RCurlyBrace, nonstrict); - Ok(block) + Ok(block.into()) } #[recursive_descent_method] @@ -1058,7 +1058,7 @@ impl Parser { _ => { let expr = self.expression()?; let s = Statement { id: self.id_store.fresh(), location: tok.location, kind: StatementKind::Expression(expr) }; - Ok(vec![s]) + Ok(s.into()) } } } @@ -1118,7 +1118,7 @@ impl Parser { Ok(match tok.get_kind() { LCurlyBrace => { let statements = delimited!(self, LCurlyBrace, statement, Newline | Semicolon, RCurlyBrace, nonstrict); - StatementBlock(statements) + StatementBlock(statements.into()) }, Keyword(Kw::Return) => { self.token_handler.next(); @@ -1279,7 +1279,7 @@ impl Parser { expect!(self, Keyword(Kw::Module)); let name = self.identifier()?; let contents = delimited!(self, LCurlyBrace, statement, Newline | Semicolon, RCurlyBrace, nonstrict); - Ok(ModuleSpecifier { name, contents }) + Ok(ModuleSpecifier { name, contents: contents.into() }) } } diff --git a/schala-lang/language/src/parsing/test.rs b/schala-lang/language/src/parsing/test.rs index 68c390f..768eddd 100644 --- a/schala-lang/language/src/parsing/test.rs +++ b/schala-lang/language/src/parsing/test.rs @@ -48,7 +48,7 @@ macro_rules! parse_test { }; } macro_rules! parse_test_wrap_ast { - ($string:expr, $correct:expr) => { parse_test!($string, AST { id: Default::default(), statements: vec![$correct] }) } + ($string:expr, $correct:expr) => { parse_test!($string, AST { id: Default::default(), statements: vec![$correct].into() }) } } macro_rules! parse_error { ($string:expr) => { assert!(parse($string).is_err()) } @@ -142,7 +142,7 @@ fn parsing_number_literals_and_binexps() { AST { id: Default::default(), statements: vec![exst!(NatLiteral(3)), exst!(NatLiteral(4)), - exst!(FloatLiteral(4.3))] + exst!(FloatLiteral(4.3))].into() } }; @@ -271,10 +271,10 @@ fn parsing_functions() { parse_test_wrap_ast!("fn a(x) { x() }", decl!( FuncDecl(Signature { name: rc!(a), operator: false, params: vec![FormalParam { name: rc!(x), anno: None, default: None }], type_anno: None }, - vec![exst!(Call { f: bx!(ex!(val!("x"))), arguments: vec![] })]))); + vec![exst!(Call { f: bx!(ex!(val!("x"))), arguments: vec![] })].into()))); parse_test_wrap_ast!("fn a(x) {\n x() }", decl!( FuncDecl(Signature { name: rc!(a), operator: false, params: vec![FormalParam { name: rc!(x), anno: None, default: None }], type_anno: None }, - vec![exst!(Call { f: bx!(ex!(val!("x"))), arguments: vec![] })]))); + vec![exst!(Call { f: bx!(ex!(val!("x"))), arguments: vec![] })].into()))); let multiline = r#" fn a(x) { @@ -283,7 +283,7 @@ x() "#; parse_test_wrap_ast!(multiline, decl!( FuncDecl(Signature { name: rc!(a), operator: false, params: vec![FormalParam { name: rc!(x), default: None, anno: None }], type_anno: None }, - vec![exst!(Call { f: bx!(ex!(val!("x"))), arguments: vec![] })]))); + vec![exst!(Call { f: bx!(ex!(val!("x"))), arguments: vec![] })].into()))); let multiline2 = r#" fn a(x) { @@ -293,7 +293,7 @@ x() "#; parse_test_wrap_ast!(multiline2, decl!( FuncDecl(Signature { name: rc!(a), operator: false, params: vec![FormalParam { name: rc!(x), default: None, anno: None }], type_anno: None }, - vec![exst!(s "x()")]))); + exst!(s "x()").into()))); } #[test] @@ -304,7 +304,7 @@ fn functions_with_default_args() { FuncDecl(Signature { name: rc!(func), operator: false, type_anno: None, params: vec![ FormalParam { name: rc!(x), default: None, anno: Some(ty!("Int")) }, FormalParam { name: rc!(y), default: Some(ex!(s "4")), anno: Some(ty!("Int")) } - ]}, vec![]) + ]}, vec![].into()) ) }; } @@ -383,7 +383,7 @@ fn parsing_block_expressions() { }), body: bx! { IfExpressionBody::SimpleConditional { - then_case: vec![exst!(Call { f: bx!(ex!(val!("b"))), arguments: vec![]}), exst!(Call { f: bx!(ex!(val!("c"))), arguments: vec![] })], + then_case: vec![exst!(Call { f: bx!(ex!(val!("b"))), arguments: vec![]}), exst!(Call { f: bx!(ex!(val!("c"))), arguments: vec![] })].into(), else_case: None, } } @@ -399,8 +399,8 @@ fn parsing_block_expressions() { }), body: bx! { IfExpressionBody::SimpleConditional { - then_case: vec![exst!(Call { f: bx!(ex!(val!("b"))), arguments: vec![]}), exst!(Call { f: bx!(ex!(val!("c"))), arguments: vec![] })], - else_case: Some(vec![exst!(val!("q"))]), + then_case: vec![exst!(Call { f: bx!(ex!(val!("b"))), arguments: vec![]}), exst!(Call { f: bx!(ex!(val!("c"))), arguments: vec![] })].into(), + else_case: Some(vec![exst!(val!("q"))].into()), } } } @@ -527,7 +527,7 @@ fn parsing_type_annotations() { #[test] fn parsing_lambdas() { parse_test_wrap_ast! { r#"\(x) { x + 1}"#, exst!( - Lambda { params: vec![FormalParam { name: rc!(x), anno: None, default: None } ], type_anno: None, body: vec![exst!(s "x + 1")] } + Lambda { params: vec![FormalParam { name: rc!(x), anno: None, default: None } ], type_anno: None, body: exst!(s "x + 1").into() } ) } @@ -538,7 +538,7 @@ fn parsing_lambdas() { FormalParam { name: rc!(y), anno: None, default: None } ], type_anno: None, - body: vec![exst!(s "a"), exst!(s "b"), exst!(s "c")] + body: vec![exst!(s "a"), exst!(s "b"), exst!(s "c")].into() }) ); @@ -549,7 +549,7 @@ fn parsing_lambdas() { FormalParam { name: rc!(x), anno: None, default: None } ], type_anno: None, - body: vec![exst!(s "y")] } + body: exst!(s "y").into() } )), arguments: vec![inv!(ex!(NatLiteral(1)))] }) }; @@ -561,7 +561,7 @@ fn parsing_lambdas() { FormalParam { name: rc!(x), anno: Some(ty!("Int")), default: None }, ], type_anno: Some(ty!("String")), - body: vec![exst!(s r#""q""#)] + body: exst!(s r#""q""#).into() }) } } @@ -573,7 +573,7 @@ fn single_param_lambda() { exst!(Lambda { params: vec![FormalParam { name: rc!(x), anno: None, default: None }], type_anno: None, - body: vec![exst!(s r"x + 10")] + body: exst!(s r"x + 10").into() }) } @@ -582,7 +582,7 @@ fn single_param_lambda() { exst!(Lambda { params: vec![FormalParam { name: rc!(x), anno: Some(ty!("Nat")), default: None }], type_anno: None, - body: vec![exst!(s r"x + 10")] + body: exst!(s r"x + 10").into() }) } } @@ -602,7 +602,7 @@ fn more_advanced_lambdas() { arguments: vec![inv!(ex!(NatLiteral(3)))], } } - ] + ].into() } } } @@ -619,12 +619,12 @@ fn more_advanced_lambdas() { fn while_expr() { parse_test_wrap_ast! { "while { }", - exst!(WhileExpression { condition: None, body: vec![] }) + exst!(WhileExpression { condition: None, body: Block::default() }) } parse_test_wrap_ast! { "while a == b { }", - exst!(WhileExpression { condition: Some(bx![ex![binexp!("==", val!("a"), val!("b"))]]), body: vec![] }) + exst!(WhileExpression { condition: Some(bx![ex![binexp!("==", val!("a"), val!("b"))]]), body: Block::default() }) } } @@ -641,7 +641,7 @@ fn for_expr() { parse_test_wrap_ast! { "for n <- someRange { f(n); }", exst!(ForExpression { enumerators: vec![Enumerator { id: rc!(n), generator: ex!(val!("someRange"))}], - body: bx!(ForBody::StatementBlock(vec![exst!(s "f(n)")])) + body: bx!(ForBody::StatementBlock(vec![exst!(s "f(n)")].into())) }) } } @@ -654,8 +654,8 @@ fn patterns() { discriminator: Some(bx!(ex!(s "x"))), body: bx!(IfExpressionBody::SimplePatternMatch { pattern: Pattern::TupleStruct(qname!(Some), vec![Pattern::VarOrName(qname!(a))]), - then_case: vec![exst!(s "4")], - else_case: Some(vec![exst!(s "9")]) }) + then_case: vec![exst!(s "4")].into(), + else_case: Some(vec![exst!(s "9")].into()) }) } ) } @@ -666,8 +666,8 @@ fn patterns() { discriminator: Some(bx!(ex!(s "x"))), body: bx!(IfExpressionBody::SimplePatternMatch { pattern: Pattern::TupleStruct(qname!(Some), vec![Pattern::VarOrName(qname!(a))]), - then_case: vec![exst!(s "4")], - else_case: Some(vec![exst!(s "9")]) } + then_case: vec![exst!(s "4")].into(), + else_case: Some(vec![exst!(s "9")].into()) } ) } ) @@ -682,8 +682,8 @@ fn patterns() { (rc!(a),Pattern::Literal(PatternLiteral::StringPattern(rc!(a)))), (rc!(b),Pattern::VarOrName(qname!(x))) ]), - then_case: vec![exst!(s "4")], - else_case: Some(vec![exst!(s "9")]) + then_case: vec![exst!(s "4")].into(), + else_case: Some(vec![exst!(s "9")].into()) } ) } @@ -700,8 +700,8 @@ fn pattern_literals() { discriminator: Some(bx!(ex!(s "x"))), body: bx!(IfExpressionBody::SimplePatternMatch { pattern: Pattern::Literal(PatternLiteral::NumPattern { neg: true, num: NatLiteral(1) }), - then_case: vec![exst!(NatLiteral(1))], - else_case: Some(vec![exst!(NatLiteral(2))]), + then_case: vec![exst!(NatLiteral(1))].into(), + else_case: Some(vec![exst!(NatLiteral(2))].into()), }) } ) @@ -714,8 +714,8 @@ fn pattern_literals() { discriminator: Some(bx!(ex!(s "x"))), body: bx!(IfExpressionBody::SimplePatternMatch { pattern: Pattern::Literal(PatternLiteral::NumPattern { neg: false, num: NatLiteral(1) }), - then_case: vec![exst!(s "1")], - else_case: Some(vec![exst!(s "2")]), + then_case: vec![exst!(s "1")].into(), + else_case: Some(vec![exst!(s "2")].into()), }) } ) @@ -729,8 +729,8 @@ fn pattern_literals() { body: bx!( IfExpressionBody::SimplePatternMatch { pattern: Pattern::Literal(PatternLiteral::BoolPattern(true)), - then_case: vec![exst!(NatLiteral(1))], - else_case: Some(vec![exst!(NatLiteral(2))]), + then_case: vec![exst!(NatLiteral(1))].into(), + else_case: Some(vec![exst!(NatLiteral(2))].into()), }) } ) @@ -743,8 +743,8 @@ fn pattern_literals() { discriminator: Some(bx!(ex!(s "x"))), body: bx!(IfExpressionBody::SimplePatternMatch { pattern: Pattern::Literal(PatternLiteral::StringPattern(rc!(gnosticism))), - then_case: vec![exst!(s "1")], - else_case: Some(vec![exst!(s "2")]), + then_case: vec![exst!(s "1")].into(), + else_case: Some(vec![exst!(s "2")].into()), }) } ) @@ -816,12 +816,12 @@ fn if_expr() { ConditionArm { condition: Condition::Pattern(Pattern::Literal(PatternLiteral::NumPattern { neg: false, num: NatLiteral(1)})), guard: None, - body: vec![exst!(s "5")], + body: vec![exst!(s "5")].into(), }, ConditionArm { condition: Condition::Else, guard: None, - body: vec![exst!(s "20")], + body: vec![exst!(s "20")].into(), }, ] )) @@ -854,7 +854,7 @@ if (45, "panda", false, 2.2) { ] )), guard: None, - body: vec![exst!(s r#""no""#)], + body: vec![exst!(s r#""no""#)].into(), }, ConditionArm { condition: Condition::Pattern(Pattern::TuplePattern( @@ -866,12 +866,12 @@ if (45, "panda", false, 2.2) { ] )), guard: None, - body: vec![exst!(s r#""yes""#)], + body: vec![exst!(s r#""yes""#)].into(), }, ConditionArm { condition: Condition::Pattern(Pattern::Ignored), guard: None, - body: vec![exst!(s r#""maybe""#)], + body: vec![exst!(s r#""maybe""#)].into(), }, ])) } @@ -891,8 +891,8 @@ r#" module!( ModuleSpecifier { name: rc!(ephraim), contents: vec![ decl!(Binding { name: rc!(a), constant: true, type_anno: None, expr: ex!(s "10") }), - decl!(FuncDecl(Signature { name: rc!(nah), operator: false, params: vec![], type_anno: None }, vec![exst!(NatLiteral(33))])), - ] } + decl!(FuncDecl(Signature { name: rc!(nah), operator: false, params: vec![], type_anno: None }, vec![exst!(NatLiteral(33))].into())), + ].into() } ) } } @@ -911,10 +911,10 @@ fn annotations() { decl!(Annotation { name: rc!(test_annotation), arguments: vec![] }), decl!(FuncDecl( Signature { name: rc!(some_function), operator: false, params: vec![], type_anno: None } - , vec![] + , vec![].into() ) ) - ] + ].into() } }; parse_test! { @@ -932,10 +932,10 @@ fn annotations() { ] }), decl!(FuncDecl( Signature { name: rc!(some_function), operator: false, params: vec![], type_anno: None } - , vec![] + , vec![].into() ) ) - ] + ].into() } }; } diff --git a/schala-lang/language/src/reduced_ir/mod.rs b/schala-lang/language/src/reduced_ir/mod.rs index 2fa8e57..7ae7269 100644 --- a/schala-lang/language/src/reduced_ir/mod.rs +++ b/schala-lang/language/src/reduced_ir/mod.rs @@ -31,14 +31,14 @@ impl<'a> Reducer<'a> { fn reduce(mut self, ast: &ast::AST) -> ReducedIR { // First reduce all functions // TODO once this works, maybe rewrite it using the Visitor - for statement in ast.statements.iter() { + for statement in ast.statements.statements.iter() { self.top_level_statement(statement); } // Then compute the entrypoint statements (which may reference previously-computed // functions by ID) let mut entrypoint = vec![]; - for statement in ast.statements.iter() { + for statement in ast.statements.statements.iter() { let ast::Statement { id: item_id, kind, .. } = statement; match &kind { ast::StatementKind::Expression(expr) => { @@ -224,8 +224,8 @@ impl<'a> Reducer<'a> { } } - fn function_internal_block(&mut self, statements: &ast::Block) -> Vec { - statements.iter().filter_map(|stmt| self.function_internal_statement(stmt)).collect() + fn function_internal_block(&mut self, block: &ast::Block) -> Vec { + block.statements.iter().filter_map(|stmt| self.function_internal_statement(stmt)).collect() } fn prefix(&mut self, prefix: &ast::PrefixOp, arg: &ast::Expression) -> Expression { diff --git a/schala-lang/language/src/typechecking.rs b/schala-lang/language/src/typechecking.rs index 8aa61f4..28dc1e4 100644 --- a/schala-lang/language/src/typechecking.rs +++ b/schala-lang/language/src/typechecking.rs @@ -298,7 +298,7 @@ impl<'a> TypeContext<'a> { /// the AST to ReducedAST pub fn typecheck(&mut self, ast: &AST) -> Result { let mut returned_type = Type::Const(TypeConst::Unit); - for statement in ast.statements.iter() { + for statement in ast.statements.statements.iter() { returned_type = self.statement(statement)?; } Ok(returned_type) @@ -444,7 +444,7 @@ impl<'a> TypeContext<'a> { #[allow(clippy::ptr_arg)] fn block(&mut self, block: &Block) -> InferResult { let mut output = ty!(Unit); - for statement in block.iter() { + for statement in block.statements.iter() { output = self.statement(statement)?; } Ok(output)