From df173a0096f00bffd7c959c6da4c9293c8442f1b Mon Sep 17 00:00:00 2001 From: Greg Shuflin Date: Tue, 26 Oct 2021 11:37:43 -0700 Subject: [PATCH] Variables in pattern match --- schala-lang/language/src/ast/mod.rs | 2 +- schala-lang/language/src/ast/visitor.rs | 8 +- schala-lang/language/src/reduced_ir/mod.rs | 5 +- schala-lang/language/src/reduced_ir/types.rs | 1 + schala-lang/language/src/symbol_table/mod.rs | 12 ++- .../language/src/symbol_table/resolver.rs | 74 ++++++++++++++++++- .../language/src/tree_walk_eval/mod.rs | 25 +++++-- .../language/src/tree_walk_eval/test.rs | 9 ++- 8 files changed, 114 insertions(+), 22 deletions(-) diff --git a/schala-lang/language/src/ast/mod.rs b/schala-lang/language/src/ast/mod.rs index 087e816..3bc12c3 100644 --- a/schala-lang/language/src/ast/mod.rs +++ b/schala-lang/language/src/ast/mod.rs @@ -8,7 +8,7 @@ mod visitor; mod operators; pub use operators::{PrefixOp, BinOp}; -pub use visitor::{walk_ast, walk_block, ASTVisitor, Recursion}; +pub use visitor::*; use crate::derivative::Derivative; use crate::tokenizing::Location; diff --git a/schala-lang/language/src/ast/visitor.rs b/schala-lang/language/src/ast/visitor.rs index b5c6744..6db62d9 100644 --- a/schala-lang/language/src/ast/visitor.rs +++ b/schala-lang/language/src/ast/visitor.rs @@ -42,7 +42,7 @@ pub fn walk_block(v: &mut V, block: &Block) { } } -fn walk_declaration(v: &mut V, decl: &Declaration, id: &ItemId) { +pub fn walk_declaration(v: &mut V, decl: &Declaration, id: &ItemId) { use Declaration::*; if let Recursion::Continue = v.declaration(decl, id) { @@ -63,7 +63,7 @@ fn walk_declaration(v: &mut V, decl: &Declaration, id: &ItemId) { } } -fn walk_expression(v: &mut V, expr: &Expression) { +pub fn walk_expression(v: &mut V, expr: &Expression) { use ExpressionKind::*; if let Recursion::Continue = v.expression(expr) { @@ -142,7 +142,7 @@ fn walk_expression(v: &mut V, expr: &Expression) { } } -fn walk_if_expr_body(v: &mut V, body: &IfExpressionBody) { +pub fn walk_if_expr_body(v: &mut V, body: &IfExpressionBody) { use IfExpressionBody::*; match body { @@ -189,7 +189,7 @@ fn walk_if_expr_body(v: &mut V, body: &IfExpressionBody) { } } -fn walk_pattern(v: &mut V, pat: &Pattern) { +pub fn walk_pattern(v: &mut V, pat: &Pattern) { use Pattern::*; if let Recursion::Continue = v.pattern(pat) { diff --git a/schala-lang/language/src/reduced_ir/mod.rs b/schala-lang/language/src/reduced_ir/mod.rs index a678830..4c62a93 100644 --- a/schala-lang/language/src/reduced_ir/mod.rs +++ b/schala-lang/language/src/reduced_ir/mod.rs @@ -332,11 +332,9 @@ impl ast::Pattern { ast::Pattern::TupleStruct(name, subpatterns) => { let symbol = symbol_table.lookup_symbol(&name.id).unwrap(); if let SymbolSpec::DataConstructor { index: tag, type_id, arity } = symbol.spec() { - let items: Vec<_> = subpatterns.iter().map(|pat| pat.reduce(symbol_table)).collect(); let items: Result, PatternError> = items.into_iter().collect(); let items = items?; - Pattern::Tuple { tag: Some(tag as u32), subpatterns: items, @@ -349,7 +347,8 @@ impl ast::Pattern { //TODO fix this symbol not existing let symbol = symbol_table.lookup_symbol(&name.id).unwrap(); println!("VarOrName symbol: {:?}", symbol); - Pattern::Ignored + let def_id = symbol.def_id().unwrap().clone(); + Pattern::Binding(def_id) }, ast::Pattern::Record(name, /*Vec<(Rc, Pattern)>*/ _) => { unimplemented!() diff --git a/schala-lang/language/src/reduced_ir/types.rs b/schala-lang/language/src/reduced_ir/types.rs index 03b7e8b..c31713f 100644 --- a/schala-lang/language/src/reduced_ir/types.rs +++ b/schala-lang/language/src/reduced_ir/types.rs @@ -131,6 +131,7 @@ pub enum Pattern { }, Literal(Literal), Ignored, + Binding(DefId) } /* diff --git a/schala-lang/language/src/symbol_table/mod.rs b/schala-lang/language/src/symbol_table/mod.rs index 9b5ff26..477a492 100644 --- a/schala-lang/language/src/symbol_table/mod.rs +++ b/schala-lang/language/src/symbol_table/mod.rs @@ -221,6 +221,15 @@ impl SymbolTable { self.id_to_symbol.iter().find(|(_, sym)| sym.def_id == *def) .map(|(_, sym)| sym.as_ref()) } + + #[allow(dead_code)] + pub fn debug(&self) { + println!("Symbol table:"); + println!("----------------"); + for (id, sym) in self.id_to_symbol.iter() { + println!("{} => {}", id, sym); + } + } } #[allow(dead_code)] @@ -247,7 +256,7 @@ impl Symbol { impl fmt::Display for Symbol { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "", self.local_name(), self.spec) + write!(f, "", self.local_name(), self.fully_qualified_name, self.spec) } } @@ -322,6 +331,7 @@ impl SymbolTable { spec, def_id, }); + println!("In add_symbol(), adding: {:?}", symbol); self.symbol_trie.insert(&fqsn); self.fqsn_to_symbol.insert(fqsn, symbol.clone()); self.id_to_symbol.insert(id.clone(), symbol.clone()); diff --git a/schala-lang/language/src/symbol_table/resolver.rs b/schala-lang/language/src/symbol_table/resolver.rs index e7cb174..8500565 100644 --- a/schala-lang/language/src/symbol_table/resolver.rs +++ b/schala-lang/language/src/symbol_table/resolver.rs @@ -18,11 +18,14 @@ enum ScopeType { name: Rc }, Lambda, + PatternMatch, //TODO add some notion of a let-like scope? } pub struct ScopeResolver<'a> { symbol_table: &'a mut super::SymbolTable, + //TODO maybe this shouldn't be a scope stack, b/c the recursion behavior comes from multiple + //instances of ScopeResolver lexical_scopes: ScopeStack<'a, Rc, NameType, ScopeType>, } @@ -225,6 +228,61 @@ impl<'a> ASTVisitor for ScopeResolver<'a> { walk_block(&mut new_resolver, body); return Recursion::Stop; } + IfExpression { discriminator, body } => { + if let Some(d) = discriminator.as_ref() { + walk_expression(self, &d); + } + let mut resolver = ScopeResolver { + lexical_scopes: self.lexical_scopes.new_scope(Some(ScopeType::PatternMatch)), + symbol_table: self.symbol_table + }; + let new_resolver = &mut resolver; + + match body.as_ref() { + IfExpressionBody::SimpleConditional { + then_case, + else_case, + } => { + walk_block(new_resolver, then_case); + if let Some(block) = else_case.as_ref() { + walk_block(new_resolver, block) + } + } + IfExpressionBody::SimplePatternMatch { + pattern, + then_case, + else_case, + } => { + walk_pattern(new_resolver, pattern); + walk_block(new_resolver, &then_case); + if let Some(ref block) = else_case.as_ref() { + walk_block(new_resolver, &block) + } + } + IfExpressionBody::CondList(arms) => { + for arm in arms { + match arm.condition { + Condition::Pattern(ref pat) => { + walk_pattern(new_resolver, pat); + } + Condition::TruncatedOp(ref _binop, ref expr) => { + walk_expression(new_resolver, expr); + } + Condition::Expression(ref expr) => { + walk_expression(new_resolver, expr); + } + Condition::Else => (), + } + if let Some(ref guard) = arm.guard { + walk_expression(new_resolver, &guard); + } + walk_block(new_resolver, &arm.body); + } + } + }; + + return Recursion::Stop; + }, _ => (), } Recursion::Continue @@ -234,12 +292,20 @@ impl<'a> ASTVisitor for ScopeResolver<'a> { use Pattern::*; match pat { - //TODO I think not handling TuplePattern is an oversight - TuplePattern(_) => (), - Literal(_) | Ignored => (), - TupleStruct(name, _) | Record(name, _) | VarOrName(name) => { + Literal(..) | Ignored | TuplePattern(..) => (), + TupleStruct(name, _) | Record(name, _) => { self.lookup_name_in_scope(name); } + //TODO this isn't really the right syntax for a VarOrName + VarOrName(ref name @ QualifiedName { id, components }) => { + //TODO need a better way to construct a FQSN from a QualifiedName + let local_name: Rc = components[0].clone(); + let lscope = Scope::Name(Rc::new("".to_string())); + let fqsn = Fqsn { scopes: vec![lscope, Scope::Name(local_name.clone())] }; + //let local_name = fqsn.local_name(); + self.symbol_table.add_symbol(id, fqsn, SymbolSpec::LocalVariable); + self.lexical_scopes.insert(local_name.clone(), NameType::LocalVariable(id.clone())); + }, }; Recursion::Continue } diff --git a/schala-lang/language/src/tree_walk_eval/mod.rs b/schala-lang/language/src/tree_walk_eval/mod.rs index 5a1aed4..c8d372c 100644 --- a/schala-lang/language/src/tree_walk_eval/mod.rs +++ b/schala-lang/language/src/tree_walk_eval/mod.rs @@ -276,9 +276,14 @@ impl<'a> State<'a> { } fn case_match_expression(&mut self, cond: Expression, alternatives: Vec) -> EvalResult { - fn matches(scrut: &Primitive, pat: &Pattern) -> bool { + fn matches(scrut: &Primitive, pat: &Pattern, scope: &mut ScopeStack) -> bool { match pat { Pattern::Ignored => true, + Pattern::Binding(ref def_id) => { + let mem = def_id.into(); + scope.insert(mem, MemoryValue::Primitive(scrut.clone())); //TODO make sure this doesn't cause problems with nesting + true + }, Pattern::Literal(pat_literal) => if let Primitive::Literal(scrut_literal) = scrut { pat_literal == scrut_literal } else { @@ -287,13 +292,13 @@ impl<'a> State<'a> { Pattern::Tuple { subpatterns, tag } => match tag { None => match scrut { Primitive::Tuple(items) if items.len() == subpatterns.len() => - items.iter().zip(subpatterns.iter()).all(|(item, subpat)| matches(item, subpat)), + items.iter().zip(subpatterns.iter()).all(|(item, subpat)| matches(item, subpat, scope)), _ => false //TODO should be a type error }, Some(pattern_tag) => match scrut { //TODO should test type_ids for runtime type checking, once those work Primitive::Object { tag, items, .. } if tag == pattern_tag && items.len() == subpatterns.len() => { - items.iter().zip(subpatterns.iter()).all(|(item, subpat)| matches(item, subpat)) + items.iter().zip(subpatterns.iter()).all(|(item, subpat)| matches(item, subpat, scope)) } _ => false } @@ -303,9 +308,13 @@ impl<'a> State<'a> { let cond = self.expression(cond)?; for alt in alternatives.into_iter() { - if matches(&cond, &alt.pattern) { - // Set up local vars - return self.block(alt.item) + let mut new_scope = self.environments.new_scope(None); + if matches(&cond, &alt.pattern, &mut new_scope) { + let mut new_state = State { + environments: new_scope + }; + + return new_state.block(alt.item) } } Err("No valid match in match expression".into()) @@ -382,7 +391,11 @@ impl<'a> State<'a> { }, /* Binops */ (binop, &[ref lhs, ref rhs]) => match (binop, lhs, rhs) { + // TODO need a better way of handling these literals (Add, Lit(Nat(l)), Lit(Nat(r))) => Nat(l + r).into(), + (Add, Lit(Int(l)), Lit(Int(r))) => Int(l + r).into(), + (Add, Lit(Nat(l)), Lit(Int(r))) => Int((*l as i64) + (*r as i64)).into(), + (Add, Lit(Int(l)), Lit(Nat(r))) => Int((*l as i64) + (*r as i64)).into(), (Concatenate, Lit(StringLit(ref s1)), Lit(StringLit(ref s2))) => StringLit(Rc::new(format!("{}{}", s1, s2))).into(), (Subtract, Lit(Nat(l)), Lit(Nat(r))) => Nat(l - r).into(), (Multiply, Lit(Nat(l)), Lit(Nat(r))) => Nat(l * r).into(), diff --git a/schala-lang/language/src/tree_walk_eval/test.rs b/schala-lang/language/src/tree_walk_eval/test.rs index 5deddff..c33cf2c 100644 --- a/schala-lang/language/src/tree_walk_eval/test.rs +++ b/schala-lang/language/src/tree_walk_eval/test.rs @@ -10,6 +10,8 @@ fn evaluate_input(input: &str) -> Result { symbol_table.process_ast(&ast).unwrap(); let reduced_ir = crate::reduced_ir::reduce(&ast, &symbol_table); reduced_ir.debug(&symbol_table); + println!("========"); + symbol_table.debug(); let mut state = State::new(); let mut outputs = state.evaluate(reduced_ir, true); outputs.pop().unwrap() @@ -134,16 +136,17 @@ println!("{}", source); fn if_is_patterns() { let source = r#" type Option = Some(T) | None +let q = "a string" let x = Option::Some(9); if x is Option::Some(q) then { q } else { 0 }"#; eval_assert(source, "9"); -/* let source = r#" type Option = Some(T) | None -let x = Option::None; if x is Option::Some(q) then { q } else { 0 }"#; +let q = "a string" +let outer = 2 +let x = Option::None; if x is Option::Some(q) then { q } else { -2 + outer }"#; eval_assert(source, "0"); -*/ }