diff --git a/schala-lang/src/typechecking.rs b/schala-lang/src/typechecking.rs index 1c55016..993a4fa 100644 --- a/schala-lang/src/typechecking.rs +++ b/schala-lang/src/typechecking.rs @@ -130,9 +130,17 @@ impl TypeEnvironment { self.map.get(name).map(|x| x.clone()) } - fn insert(&mut self, name: &TypeName, ty: PolyType) { + fn extend(&mut self, name: &TypeName, ty: PolyType) { self.map.insert(name.clone(), ty); } + + fn free_vars(&self) -> HashSet { + let mut free = HashSet::new(); + for (_, ptype) in self.map.iter() { + free = free.union(&ptype.free_vars()).cloned().collect() + } + free + } } pub struct TypeContext { @@ -156,7 +164,7 @@ impl TypeContext { } struct Infer<'a> { - env: &'a TypeEnvironment + env: &'a mut TypeEnvironment } #[derive(Debug)] @@ -170,13 +178,21 @@ enum InferError { type InferResult = Result; impl<'a> Infer<'a> { + + fn generalize(&mut self, ty: MonoType) -> PolyType { + let free_mtype = ty.free_vars(); + let free_env = self.env.free_vars(); + let diff: HashSet = free_mtype.difference(&free_env).cloned().collect(); + PolyType(diff, ty) + } + fn block(&mut self, block: &Vec) -> InferResult { let mut ret = MonoType::Const(TypeConst::Unit); for s in block { ret = match s { - parsing::Statement::ExpressionStatement(expr) => self.anno_expression(expr)?, + parsing::Statement::ExpressionStatement(expr) => self.infer_expression(expr)?, parsing::Statement::Declaration(decl) => { - self.declaration(decl)?; + self.infer_declaration(decl)?; MonoType::Const(TypeConst::Unit) } } @@ -184,19 +200,20 @@ impl<'a> Infer<'a> { Ok(ret) } - fn declaration(&mut self, decl: &parsing::Declaration) -> InferResult { + fn infer_declaration(&mut self, decl: &parsing::Declaration) -> InferResult { use parsing::Declaration::*; match decl { Binding { name, expr, .. } => { - let ty = self.anno_expression(&expr)?; - return Err(InferError::Custom(format!("This decl not yet supported"))) + let tau: MonoType = self.infer_expression(&expr)?; + let sigma = self.generalize(tau); + self.env.extend(name, sigma); }, _ => return Err(InferError::Custom(format!("This decl not yet supported"))) } Ok(MonoType::Const(TypeConst::Unit)) } - fn anno_expression(&mut self, expr: &parsing::Expression) -> InferResult { + fn infer_expression(&mut self, expr: &parsing::Expression) -> InferResult { match expr { parsing::Expression(e, Some(anno)) => { return Err(InferError::Custom(format!("Annotations not done yet"))) @@ -206,11 +223,11 @@ impl<'a> Infer<'a> { self.unify(ty, anno_ty) */ }, - parsing::Expression(e, None) => self.expression(e) + parsing::Expression(e, None) => self.infer_expression_type(e) } } - fn expression(&mut self, expr: &parsing::ExpressionType) -> InferResult { + fn infer_expression_type(&mut self, expr: &parsing::ExpressionType) -> InferResult { use self::parsing::ExpressionType::*; Ok(match expr { NatLiteral(_) => MonoType::Const(TypeConst::Nat),