From 68506571a8e37e049d7e439e0fab19a979580f24 Mon Sep 17 00:00:00 2001 From: Greg Shuflin Date: Fri, 29 Oct 2021 22:03:34 -0700 Subject: [PATCH] Implement records --- schala-lang/language/src/reduced_ir/mod.rs | 55 ++++++++++++++----- schala-lang/language/src/reduced_ir/test.rs | 2 +- schala-lang/language/src/reduced_ir/types.rs | 2 +- schala-lang/language/src/schala.rs | 2 +- .../language/src/tree_walk_eval/evaluator.rs | 22 ++++++-- .../language/src/tree_walk_eval/mod.rs | 18 +++++- .../language/src/tree_walk_eval/test.rs | 30 +++++++++- .../language/src/type_inference/mod.rs | 16 +++++- 8 files changed, 119 insertions(+), 28 deletions(-) diff --git a/schala-lang/language/src/reduced_ir/mod.rs b/schala-lang/language/src/reduced_ir/mod.rs index b2923da..5d9bc25 100644 --- a/schala-lang/language/src/reduced_ir/mod.rs +++ b/schala-lang/language/src/reduced_ir/mod.rs @@ -4,6 +4,7 @@ use crate::{ ast, builtin::Builtin, symbol_table::{DefId, SymbolSpec, SymbolTable}, + type_inference::TypeContext, }; mod test; @@ -11,19 +12,20 @@ mod types; pub use types::*; -pub fn reduce(ast: &ast::AST, symbol_table: &SymbolTable) -> ReducedIR { - let reducer = Reducer::new(symbol_table); +pub fn reduce(ast: &ast::AST, symbol_table: &SymbolTable, type_context: &TypeContext) -> ReducedIR { + let reducer = Reducer::new(symbol_table, type_context); reducer.reduce(ast) } -struct Reducer<'a> { +struct Reducer<'a, 'b> { symbol_table: &'a SymbolTable, functions: HashMap, + type_context: &'b TypeContext, } -impl<'a> Reducer<'a> { - fn new(symbol_table: &'a SymbolTable) -> Self { - Self { symbol_table, functions: HashMap::new() } +impl<'a, 'b> Reducer<'a, 'b> { + fn new(symbol_table: &'a SymbolTable, type_context: &'b TypeContext) -> Self { + Self { symbol_table, functions: HashMap::new(), type_context } } fn reduce(mut self, ast: &ast::AST) -> ReducedIR { @@ -132,20 +134,45 @@ impl<'a> Reducer<'a> { body: self.function_internal_block(body), }), NamedStruct { name, fields } => { + self.symbol_table.debug(); let symbol = self.symbol_table.lookup_symbol(&name.id).unwrap(); - let constructor = match symbol.spec() { - SymbolSpec::RecordConstructor { tag, members: _, type_id } => - Expression::Callable(Callable::RecordConstructor { type_id, tag }), + let (tag, type_id) = match symbol.spec() { + SymbolSpec::RecordConstructor { tag, members: _, type_id } => (tag, type_id), e => return Expression::ReductionError(format!("Bad symbol for NamedStruct: {:?}", e)), }; - //TODO need to order the fields correctly, which needs symbol table information - // Until this happens, NamedStructs won't work - let mut ordered_args = vec![]; - for (_name, _type_id) in fields { - unimplemented!() + // Eventually, the ReducedIR should decide what field ordering is optimal. + // For now, just do it alphabetically. + let mut field_order: Vec = self + .type_context + .lookup_record_members(&type_id, tag) + .unwrap() + .iter() + .map(|(field, _type_id)| field) + .cloned() + .collect(); + field_order.sort_unstable(); + + let mut field_map = HashMap::new(); + for (name, expr) in fields.iter() { + field_map.insert(name.as_ref(), expr); } + let mut ordered_args = vec![]; + for field in field_order.iter() { + let expr = match field_map.get(&field) { + Some(expr) => expr, + None => + return Expression::ReductionError(format!( + "Field {} not specified for record {}", + field, name + )), + }; + ordered_args.push(self.expression(expr)); + } + + let constructor = + Expression::Callable(Callable::RecordConstructor { type_id, tag, field_order }); Expression::Call { f: Box::new(constructor), args: ordered_args } } Index { .. } => Expression::ReductionError("Index expr not implemented".to_string()), diff --git a/schala-lang/language/src/reduced_ir/test.rs b/schala-lang/language/src/reduced_ir/test.rs index a2b603f..55eb4a7 100644 --- a/schala-lang/language/src/reduced_ir/test.rs +++ b/schala-lang/language/src/reduced_ir/test.rs @@ -11,7 +11,7 @@ fn build_ir(input: &str) -> ReducedIR { symbol_table.process_ast(&ast, &mut type_context).unwrap(); - let reduced = reduce(&ast, &symbol_table); + let reduced = reduce(&ast, &symbol_table, &type_context); reduced.debug(&symbol_table); reduced } diff --git a/schala-lang/language/src/reduced_ir/types.rs b/schala-lang/language/src/reduced_ir/types.rs index 5ef73aa..47fd4b7 100644 --- a/schala-lang/language/src/reduced_ir/types.rs +++ b/schala-lang/language/src/reduced_ir/types.rs @@ -74,7 +74,7 @@ pub enum Callable { UserDefined(DefId), Lambda { arity: u8, body: Vec }, DataConstructor { type_id: TypeId, tag: u32 }, - RecordConstructor { type_id: TypeId, tag: u32 }, + RecordConstructor { type_id: TypeId, tag: u32, field_order: Vec }, } #[derive(Debug, Clone)] diff --git a/schala-lang/language/src/schala.rs b/schala-lang/language/src/schala.rs index c104f32..5751137 100644 --- a/schala-lang/language/src/schala.rs +++ b/schala-lang/language/src/schala.rs @@ -91,7 +91,7 @@ impl<'a> Schala<'a> { // TODO typechecking not working //let _overall_type = self.type_context.typecheck(&ast).map_err(SchalaError::from_type_error); - let reduced_ir = reduced_ir::reduce(&ast, &self.symbol_table); + let reduced_ir = reduced_ir::reduce(&ast, &self.symbol_table, &self.type_context); let evaluation_outputs = self.eval_state.evaluate(reduced_ir, &self.type_context, true); let text_output: Result, String> = evaluation_outputs.into_iter().collect(); diff --git a/schala-lang/language/src/tree_walk_eval/evaluator.rs b/schala-lang/language/src/tree_walk_eval/evaluator.rs index b1e0a29..1d6adf8 100644 --- a/schala-lang/language/src/tree_walk_eval/evaluator.rs +++ b/schala-lang/language/src/tree_walk_eval/evaluator.rs @@ -109,7 +109,7 @@ impl<'a, 'b> Evaluator<'a, 'b> { Expression::Callable(Callable::DataConstructor { type_id, tag }) => { let arity = self.type_context.lookup_variant_arity(&type_id, tag).unwrap(); if arity == 0 { - Primitive::Object { type_id, tag, items: vec![] } + Primitive::Object { type_id, tag, items: vec![], ordered_fields: None } } else { Primitive::Callable(Callable::DataConstructor { type_id, tag }) } @@ -222,14 +222,24 @@ impl<'a, 'b> Evaluator<'a, 'b> { .into()); } - let mut evaluated_args: Vec = vec![]; + let mut items: Vec = vec![]; for arg in args.into_iter() { - evaluated_args.push(self.expression(arg)?); + items.push(self.expression(arg)?); } - Ok(Primitive::Object { type_id, tag, items: evaluated_args }) + Ok(Primitive::Object { type_id, tag, items, ordered_fields: None }) } - Callable::RecordConstructor { type_id: _, tag: _ } => { - unimplemented!() + Callable::RecordConstructor { type_id, tag, field_order } => { + //TODO maybe I'll want to do a runtime check of the evaluated fields + /* + let record_members = self.type_context.lookup_record_members(type_id, tag) + .ok_or(format!("Runtime record lookup for: {} {} not found", type_id, tag).into())?; + */ + + let mut items: Vec = vec![]; + for arg in args.into_iter() { + items.push(self.expression(arg)?); + } + Ok(Primitive::Object { type_id, tag, items, ordered_fields: Some(field_order) }) } } } diff --git a/schala-lang/language/src/tree_walk_eval/mod.rs b/schala-lang/language/src/tree_walk_eval/mod.rs index 568622b..af443d8 100644 --- a/schala-lang/language/src/tree_walk_eval/mod.rs +++ b/schala-lang/language/src/tree_walk_eval/mod.rs @@ -121,21 +121,33 @@ enum Primitive { Tuple(Vec), Literal(Literal), Callable(Callable), - Object { type_id: TypeId, tag: u32, items: Vec }, + Object { type_id: TypeId, tag: u32, ordered_fields: Option>, items: Vec }, } impl Primitive { fn to_repl(&self, type_context: &TypeContext) -> String { match self { - Primitive::Object { type_id, items, tag } if items.is_empty() => + Primitive::Object { type_id, items, tag, ordered_fields: _ } if items.is_empty() => type_context.variant_local_name(type_id, *tag).unwrap().to_string(), - Primitive::Object { type_id, items, tag } => { + Primitive::Object { type_id, items, tag, ordered_fields: None } => { format!( "{}{}", type_context.variant_local_name(type_id, *tag).unwrap(), paren_wrapped(items.iter().map(|item| item.to_repl(type_context))) ) } + Primitive::Object { type_id, items, tag, ordered_fields: Some(fields) } => { + let mut buf = format!("{}", type_context.variant_local_name(type_id, *tag).unwrap()); + write!(buf, " {{ ").unwrap(); + for item in fields.iter().zip(items.iter()).map(Some).intersperse(None) { + match item { + Some((name, val)) => write!(buf, "{}: {}", name, val.to_repl(type_context)).unwrap(), + None => write!(buf, ", ").unwrap(), + } + } + write!(buf, " }}").unwrap(); + buf + } Primitive::Literal(lit) => match lit { Literal::Nat(n) => format!("{}", n), Literal::Int(i) => format!("{}", i), diff --git a/schala-lang/language/src/tree_walk_eval/test.rs b/schala-lang/language/src/tree_walk_eval/test.rs index e49d1e6..7f60eb7 100644 --- a/schala-lang/language/src/tree_walk_eval/test.rs +++ b/schala-lang/language/src/tree_walk_eval/test.rs @@ -14,7 +14,7 @@ fn evaluate_input(input: &str) -> Result { symbol_table.process_ast(&ast, &mut type_context).unwrap(); - let reduced_ir = crate::reduced_ir::reduce(&ast, &symbol_table); + let reduced_ir = crate::reduced_ir::reduce(&ast, &symbol_table, &type_context); reduced_ir.debug(&symbol_table); println!("========"); symbol_table.debug(); @@ -29,6 +29,10 @@ fn eval_assert(input: &str, expected: &str) { assert_eq!(evaluate_input(input), Ok(expected.to_string())); } +fn eval_assert_failure(input: &str, expected: &str) { + assert_eq!(evaluate_input(input), Err(expected.to_string())); +} + #[test] fn test_basic_eval() { eval_assert("1 + 2", "3"); @@ -85,6 +89,30 @@ let b = Option::Some(10) eval_assert(source, "(Some(10), None)"); } +#[test] +fn adt_output_2() { + let source = r#" +type Gobble = Unknown | Rufus { a: Int, torrid: Nat } +let b = Gobble::Rufus { a: 3, torrid: 99 } +b + "#; + eval_assert(source, "Rufus { a: 3, torrid: 99 }"); + + let source = r#" +type Gobble = Unknown | Rufus { a: Int, torrid: Nat } +let b = Gobble::Rufus { torrid: 3, a: 84 } +b + "#; + eval_assert(source, "Rufus { a: 84, torrid: 3 }"); + + let source = r#" +type Gobble = Unknown | Rufus { a: Int, torrid: Nat } +let b = Gobble::Rufus { a: 84 } +b + "#; + eval_assert_failure(source, "Field torrid not specified for record Gobble::Rufus"); +} + #[test] fn basic_if_statement() { let source = r#" diff --git a/schala-lang/language/src/type_inference/mod.rs b/schala-lang/language/src/type_inference/mod.rs index 9a8dd0a..1c89277 100644 --- a/schala-lang/language/src/type_inference/mod.rs +++ b/schala-lang/language/src/type_inference/mod.rs @@ -26,7 +26,7 @@ impl TypeContext { let members = variant_builder.members; if members.is_empty() { pending_variants.push(Variant { name: variant_builder.name, members: VariantMembers::Unit }); - break; + continue; } let record_variant = matches!(members.get(0).unwrap(), VariantMemberBuilder::KeyVal(..)); @@ -84,6 +84,15 @@ impl TypeContext { ) } + pub fn lookup_record_members(&self, type_id: &TypeId, tag: u32) -> Option<&[(String, TypeId)]> { + self.defined_types.get(type_id).and_then(|defined| defined.variants.get(tag as usize)).and_then( + |variant| match &variant.members { + VariantMembers::Record(items) => Some(items.as_ref()), + _ => None, + }, + ) + } + pub fn lookup_type(&self, type_id: &TypeId) -> Option<&DefinedType> { self.defined_types.get(type_id) } @@ -91,6 +100,7 @@ impl TypeContext { /// A type defined in program source code, as opposed to a builtin. #[allow(dead_code)] +#[derive(Debug)] pub struct DefinedType { pub name: String, @@ -98,11 +108,13 @@ pub struct DefinedType { pub variants: Vec, } +#[derive(Debug)] pub struct Variant { pub name: String, pub members: VariantMembers, } +#[derive(Debug)] pub enum VariantMembers { Unit, // Should be non-empty @@ -124,6 +136,7 @@ impl From<&TypeIdentifier> for PendingType { } } +#[derive(Debug)] pub struct TypeBuilder { name: String, variants: Vec, @@ -139,6 +152,7 @@ impl TypeBuilder { } } +#[derive(Debug)] pub struct VariantBuilder { name: String, members: Vec,